middleware.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. package server
import (
	"bytes"
	"crypto/subtle"
	"io"

	"github.com/gin-gonic/gin"
	"github.com/langgenius/dify-plugin-daemon/internal/server/constants"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
)
  11. func CheckingKey(key string) gin.HandlerFunc {
  12. return func(c *gin.Context) {
  13. // get header X-Api-Key
  14. if c.GetHeader(constants.X_API_KEY) != key {
  15. c.JSON(200, entities.NewErrorResponse(-401, "Unauthorized"))
  16. c.Abort()
  17. return
  18. }
  19. c.Next()
  20. }
  21. }
  22. type ginContextReader struct {
  23. reader *bytes.Reader
  24. }
  25. func (g *ginContextReader) Read(p []byte) (n int, err error) {
  26. return g.reader.Read(p)
  27. }
  28. func (g *ginContextReader) Close() error {
  29. return nil
  30. }
  31. // RedirectPluginInvoke redirects the request to the correct cluster node
  32. func (app *App) RedirectPluginInvoke() gin.HandlerFunc {
  33. return func(ctx *gin.Context) {
  34. // get plugin identity
  35. raw, err := ctx.GetRawData()
  36. if err != nil {
  37. ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request"})
  38. return
  39. }
  40. ctx.Request.Body = &ginContextReader{
  41. reader: bytes.NewReader(raw),
  42. }
  43. identity, err := plugin_entities.NewPluginUniqueIdentifier(ctx.Request.Header.Get(constants.X_PLUGIN_IDENTIFIER))
  44. if err != nil {
  45. ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request, " + err.Error()})
  46. return
  47. }
  48. // check if plugin in current node
  49. if !app.cluster.IsPluginNoCurrentNode(
  50. identity,
  51. ) {
  52. app.redirectPluginInvokeByPluginIdentifier(ctx, identity)
  53. ctx.Abort()
  54. } else {
  55. ctx.Next()
  56. }
  57. }
  58. }
  59. func (app *App) redirectPluginInvokeByPluginIdentifier(
  60. ctx *gin.Context,
  61. plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
  62. ) {
  63. // try find the correct node
  64. nodes, err := app.cluster.FetchPluginAvailableNodesById(plugin_unique_identifier.String())
  65. if err != nil {
  66. ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error"})
  67. return
  68. } else if len(nodes) == 0 {
  69. ctx.AbortWithStatusJSON(404, gin.H{"error": "No available node"})
  70. return
  71. }
  72. // redirect to the correct node
  73. node_id := nodes[0]
  74. status_code, header, body, err := app.cluster.RedirectRequest(node_id, ctx.Request)
  75. if err != nil {
  76. log.Error("redirect request failed: %s", err.Error())
  77. ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error"})
  78. return
  79. }
  80. // set status code
  81. ctx.Writer.WriteHeader(status_code)
  82. // set header
  83. for key, values := range header {
  84. for _, value := range values {
  85. ctx.Writer.Header().Set(key, value)
  86. }
  87. }
  88. for {
  89. buf := make([]byte, 1024)
  90. n, err := body.Read(buf)
  91. if err != nil && err != io.EOF {
  92. break
  93. } else if err != nil {
  94. ctx.Writer.Write(buf[:n])
  95. break
  96. }
  97. if n > 0 {
  98. ctx.Writer.Write(buf[:n])
  99. }
  100. }
  101. }
  102. func (app *App) InitClusterID() gin.HandlerFunc {
  103. return func(ctx *gin.Context) {
  104. ctx.Set("cluster_id", app.cluster.ID())
  105. ctx.Next()
  106. }
  107. }