middleware.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. package server
import (
	"bytes"
	"crypto/subtle"
	"io"

	"github.com/gin-gonic/gin"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
)
  10. func CheckingKey(key string) gin.HandlerFunc {
  11. return func(c *gin.Context) {
  12. // get header X-Api-Key
  13. if c.GetHeader("X-Api-Key") != key {
  14. c.JSON(401, gin.H{"error": "Unauthorized"})
  15. c.Abort()
  16. return
  17. }
  18. c.Next()
  19. }
  20. }
  21. type ginContextReader struct {
  22. reader *bytes.Reader
  23. }
  24. func (g *ginContextReader) Read(p []byte) (n int, err error) {
  25. return g.reader.Read(p)
  26. }
  27. func (g *ginContextReader) Close() error {
  28. return nil
  29. }
  30. // RedirectPluginInvoke redirects the request to the correct cluster node
  31. func (app *App) RedirectPluginInvoke() gin.HandlerFunc {
  32. return func(ctx *gin.Context) {
  33. // get plugin identity
  34. raw, err := ctx.GetRawData()
  35. if err != nil {
  36. ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request"})
  37. return
  38. }
  39. ctx.Request.Body = &ginContextReader{
  40. reader: bytes.NewReader(raw),
  41. }
  42. identity, err := parser.UnmarshalJsonBytes[plugin_entities.InvokePluginPluginIdentity](raw)
  43. if err != nil {
  44. ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request"})
  45. return
  46. }
  47. plugin_id := parser.MarshalPluginIdentity(identity.PluginName, identity.PluginVersion)
  48. // check if plugin in current node
  49. if !app.cluster.IsPluginNoCurrentNode(
  50. plugin_id,
  51. ) {
  52. app.Redirect(ctx, plugin_id)
  53. ctx.Abort()
  54. } else {
  55. ctx.Next()
  56. }
  57. }
  58. }
  59. func (app *App) Redirect(ctx *gin.Context, plugin_id string) {
  60. // try find the correct node
  61. nodes, err := app.cluster.FetchPluginAvailableNodesById(plugin_id)
  62. if err != nil {
  63. ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error"})
  64. log.Error("fetch plugin available nodes failed: %s", err.Error())
  65. return
  66. } else if len(nodes) == 0 {
  67. ctx.AbortWithStatusJSON(404, gin.H{"error": "No available node"})
  68. log.Error("no available node")
  69. return
  70. }
  71. // redirect to the correct node
  72. node_id := nodes[0]
  73. status_code, header, body, err := app.cluster.RedirectRequest(node_id, ctx.Request)
  74. if err != nil {
  75. log.Error("redirect request failed: %s", err.Error())
  76. ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error"})
  77. return
  78. }
  79. // set status code
  80. ctx.Writer.WriteHeader(status_code)
  81. // set header
  82. for key, values := range header {
  83. for _, value := range values {
  84. ctx.Writer.Header().Set(key, value)
  85. }
  86. }
  87. for {
  88. buf := make([]byte, 1024)
  89. n, err := body.Read(buf)
  90. if err != nil && err != io.EOF {
  91. break
  92. } else if err != nil {
  93. ctx.Writer.Write(buf[:n])
  94. break
  95. }
  96. if n > 0 {
  97. ctx.Writer.Write(buf[:n])
  98. }
  99. }
  100. }