middleware.go 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. package server
import (
	"crypto/subtle"
	"errors"
	"io"

	"github.com/gin-gonic/gin"
	"github.com/langgenius/dify-plugin-daemon/internal/db"
	"github.com/langgenius/dify-plugin-daemon/internal/server/constants"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/exception"
	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
)
  13. func CheckingKey(key string) gin.HandlerFunc {
  14. return func(c *gin.Context) {
  15. // get header X-Api-Key
  16. if c.GetHeader(constants.X_API_KEY) != key {
  17. c.JSON(200, exception.UnauthorizedError().ToResponse())
  18. c.Abort()
  19. return
  20. }
  21. c.Next()
  22. }
  23. }
  24. func (app *App) FetchPluginInstallation() gin.HandlerFunc {
  25. return func(ctx *gin.Context) {
  26. pluginId := ctx.Request.Header.Get(constants.X_PLUGIN_ID)
  27. if pluginId == "" {
  28. ctx.AbortWithStatusJSON(400, exception.BadRequestError(errors.New("plugin_id is required")).ToResponse())
  29. return
  30. }
  31. tenantId := ctx.Param("tenant_id")
  32. if tenantId == "" {
  33. ctx.AbortWithStatusJSON(400, exception.BadRequestError(errors.New("tenant_id is required")).ToResponse())
  34. return
  35. }
  36. // fetch plugin installation
  37. installation, err := db.GetOne[models.PluginInstallation](
  38. db.Equal("tenant_id", tenantId),
  39. db.Equal("plugin_id", pluginId),
  40. )
  41. if err == db.ErrDatabaseNotFound {
  42. ctx.AbortWithStatusJSON(404, exception.ErrPluginNotFound().ToResponse())
  43. return
  44. }
  45. if err != nil {
  46. ctx.AbortWithStatusJSON(500, exception.InternalServerError(err).ToResponse())
  47. return
  48. }
  49. identity, err := plugin_entities.NewPluginUniqueIdentifier(installation.PluginUniqueIdentifier)
  50. if err != nil {
  51. ctx.AbortWithStatusJSON(400, exception.PluginUniqueIdentifierError(err).ToResponse())
  52. return
  53. }
  54. ctx.Set(constants.CONTEXT_KEY_PLUGIN_INSTALLATION, installation)
  55. ctx.Set(constants.CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER, identity)
  56. ctx.Next()
  57. }
  58. }
  59. // RedirectPluginInvoke redirects the request to the correct cluster node
  60. func (app *App) RedirectPluginInvoke() gin.HandlerFunc {
  61. return func(ctx *gin.Context) {
  62. // get plugin unique identifier
  63. identityAny, ok := ctx.Get(constants.CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER)
  64. if !ok {
  65. ctx.AbortWithStatusJSON(
  66. 500,
  67. exception.InternalServerError(errors.New("plugin unique identifier not found")).ToResponse(),
  68. )
  69. return
  70. }
  71. identity, ok := identityAny.(plugin_entities.PluginUniqueIdentifier)
  72. if !ok {
  73. ctx.AbortWithStatusJSON(
  74. 500,
  75. exception.InternalServerError(errors.New("failed to parse plugin unique identifier")).ToResponse(),
  76. )
  77. return
  78. }
  79. // check if plugin in current node
  80. if ok, originalError := app.cluster.IsPluginOnCurrentNode(identity); !ok {
  81. app.redirectPluginInvokeByPluginIdentifier(ctx, identity, originalError)
  82. ctx.Abort()
  83. } else {
  84. ctx.Next()
  85. }
  86. }
  87. }
  88. func (app *App) redirectPluginInvokeByPluginIdentifier(
  89. ctx *gin.Context,
  90. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  91. originalError error,
  92. ) {
  93. // try find the correct node
  94. nodes, err := app.cluster.FetchPluginAvailableNodesById(plugin_unique_identifier.String())
  95. if err != nil {
  96. ctx.AbortWithStatusJSON(
  97. 500,
  98. exception.InternalServerError(
  99. errors.New("failed to fetch plugin available nodes, "+originalError.Error()+", "+err.Error()),
  100. ).ToResponse(),
  101. )
  102. return
  103. } else if len(nodes) == 0 {
  104. ctx.AbortWithStatusJSON(
  105. 404,
  106. exception.InternalServerError(
  107. errors.New("no available node, "+originalError.Error()),
  108. ).ToResponse(),
  109. )
  110. return
  111. }
  112. // redirect to the correct node
  113. nodeId := nodes[0]
  114. statusCode, header, body, err := app.cluster.RedirectRequest(nodeId, ctx.Request)
  115. if err != nil {
  116. log.Error("redirect request failed: %s", err.Error())
  117. ctx.AbortWithStatusJSON(
  118. 500,
  119. exception.InternalServerError(errors.New("redirect request failed: "+err.Error())).ToResponse(),
  120. )
  121. return
  122. }
  123. // set status code
  124. ctx.Writer.WriteHeader(statusCode)
  125. // set header
  126. for key, values := range header {
  127. for _, value := range values {
  128. ctx.Writer.Header().Set(key, value)
  129. }
  130. }
  131. for {
  132. buf := make([]byte, 1024)
  133. n, err := body.Read(buf)
  134. if err != nil && err != io.EOF {
  135. break
  136. } else if err != nil {
  137. ctx.Writer.Write(buf[:n])
  138. break
  139. }
  140. if n > 0 {
  141. ctx.Writer.Write(buf[:n])
  142. }
  143. }
  144. }
  145. func (app *App) InitClusterID() gin.HandlerFunc {
  146. return func(ctx *gin.Context) {
  147. ctx.Set(constants.CONTEXT_KEY_CLUSTER_ID, app.cluster.ID())
  148. ctx.Next()
  149. }
  150. }