middleware.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. package server
import (
	"crypto/subtle"
	"io"

	"github.com/gin-gonic/gin"
	"github.com/langgenius/dify-plugin-daemon/internal/db"
	"github.com/langgenius/dify-plugin-daemon/internal/server/constants"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
)
  12. func CheckingKey(key string) gin.HandlerFunc {
  13. return func(c *gin.Context) {
  14. // get header X-Api-Key
  15. if c.GetHeader(constants.X_API_KEY) != key {
  16. c.JSON(200, entities.NewErrorResponse(-401, "Unauthorized"))
  17. c.Abort()
  18. return
  19. }
  20. c.Next()
  21. }
  22. }
  23. func (app *App) FetchPluginInstallation() gin.HandlerFunc {
  24. return func(ctx *gin.Context) {
  25. plugin_id := ctx.Request.Header.Get(constants.X_PLUGIN_ID)
  26. if plugin_id == "" {
  27. ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request, plugin_id is required"})
  28. return
  29. }
  30. tenant_id := ctx.Param("tenant_id")
  31. if tenant_id == "" {
  32. ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request, tenant_id is required"})
  33. return
  34. }
  35. // fetch plugin installation
  36. installation, err := db.GetOne[models.PluginInstallation](
  37. db.Equal("tenant_id", tenant_id),
  38. db.Equal("plugin_id", plugin_id),
  39. )
  40. if err != nil {
  41. ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request, " + err.Error()})
  42. return
  43. }
  44. identity, err := plugin_entities.NewPluginUniqueIdentifier(installation.PluginUniqueIdentifier)
  45. if err != nil {
  46. ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request, " + err.Error()})
  47. return
  48. }
  49. ctx.Set(constants.CONTEXT_KEY_PLUGIN_INSTALLATION, installation)
  50. ctx.Set(constants.CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER, identity)
  51. ctx.Next()
  52. }
  53. }
  54. // RedirectPluginInvoke redirects the request to the correct cluster node
  55. func (app *App) RedirectPluginInvoke() gin.HandlerFunc {
  56. return func(ctx *gin.Context) {
  57. // get plugin unique identifier
  58. identity_any, ok := ctx.Get(constants.CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER)
  59. if !ok {
  60. ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error, plugin unique identifier not found"})
  61. return
  62. }
  63. identity, ok := identity_any.(plugin_entities.PluginUniqueIdentifier)
  64. if !ok {
  65. ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error, failed to parse plugin unique identifier"})
  66. return
  67. }
  68. // check if plugin in current node
  69. if ok, original_error := app.cluster.IsPluginOnCurrentNode(identity); !ok {
  70. app.redirectPluginInvokeByPluginIdentifier(ctx, identity, original_error)
  71. ctx.Abort()
  72. } else {
  73. ctx.Next()
  74. }
  75. }
  76. }
  77. func (app *App) redirectPluginInvokeByPluginIdentifier(
  78. ctx *gin.Context,
  79. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  80. original_error error,
  81. ) {
  82. // try find the correct node
  83. nodes, err := app.cluster.FetchPluginAvailableNodesById(plugin_unique_identifier.String())
  84. if err != nil {
  85. ctx.AbortWithStatusJSON(
  86. 500,
  87. gin.H{"error": "Internal server error, " + original_error.Error() + ", " + err.Error()},
  88. )
  89. return
  90. } else if len(nodes) == 0 {
  91. ctx.AbortWithStatusJSON(
  92. 404,
  93. gin.H{"error": "No available node, " + original_error.Error()},
  94. )
  95. return
  96. }
  97. // redirect to the correct node
  98. node_id := nodes[0]
  99. status_code, header, body, err := app.cluster.RedirectRequest(node_id, ctx.Request)
  100. if err != nil {
  101. log.Error("redirect request failed: %s", err.Error())
  102. ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error"})
  103. return
  104. }
  105. // set status code
  106. ctx.Writer.WriteHeader(status_code)
  107. // set header
  108. for key, values := range header {
  109. for _, value := range values {
  110. ctx.Writer.Header().Set(key, value)
  111. }
  112. }
  113. for {
  114. buf := make([]byte, 1024)
  115. n, err := body.Read(buf)
  116. if err != nil && err != io.EOF {
  117. break
  118. } else if err != nil {
  119. ctx.Writer.Write(buf[:n])
  120. break
  121. }
  122. if n > 0 {
  123. ctx.Writer.Write(buf[:n])
  124. }
  125. }
  126. }
  127. func (app *App) InitClusterID() gin.HandlerFunc {
  128. return func(ctx *gin.Context) {
  129. ctx.Set(constants.CONTEXT_KEY_CLUSTER_ID, app.cluster.ID())
  130. ctx.Next()
  131. }
  132. }