Browse Source

feat: redirect requests

Yeuoly 11 months ago
parent
commit
c86d0060c4

+ 2 - 3
internal/cluster/cluster.go

@@ -19,13 +19,12 @@ type Cluster struct {
 	// port is the health check port of the cluster
 	port uint16
 
-	// plugins stores all the plugin life time of the cluster
+	// plugins stores the lifetime of every plugin on the current node
 	plugins     mapping.Map[string, *pluginLifeTime]
 	plugin_lock sync.RWMutex
 
 	// nodes stores all the nodes of the cluster
-	nodes     mapping.Map[string, node]
-	node_lock sync.RWMutex
+	nodes mapping.Map[string, node]
 
 	// signals for waiting for the cluster to stop
 	stop_chan chan bool

+ 7 - 7
internal/cluster/node.go

@@ -82,9 +82,6 @@ func (c *Cluster) updateNodeStatus() error {
 	}
 
 	// update self nodes map
-	c.node_lock.Lock()
-	defer c.node_lock.Unlock()
-
 	c.nodes.Clear()
 	for node_id, node := range nodes {
 		c.nodes.Store(node_id, node)
@@ -109,8 +106,8 @@ func (c *Cluster) GetNodes() (map[string]node, error) {
 	return nodes, nil
 }
 
-// FetchPluginAvailableNodes fetches the available nodes of the given plugin
-func (c *Cluster) FetchPluginAvailableNodes(hashed_plugin_id string) ([]string, error) {
+// FetchPluginAvailableNodesByHashedId fetches the available nodes of the given plugin
+func (c *Cluster) FetchPluginAvailableNodesByHashedId(hashed_plugin_id string) ([]string, error) {
 	states, err := cache.ScanMap[entities.PluginRuntimeState](
 		PLUGIN_STATE_MAP_KEY, c.getScanPluginsByIdKey(hashed_plugin_id),
 	)
@@ -132,6 +129,11 @@ func (c *Cluster) FetchPluginAvailableNodes(hashed_plugin_id string) ([]string,
 	return nodes, nil
 }
 
+// FetchPluginAvailableNodesById fetches the available nodes of the plugin
+// given by its plain id. The id is hashed with entities.HashedIdentity and the
+// lookup is delegated to FetchPluginAvailableNodesByHashedId.
+func (c *Cluster) FetchPluginAvailableNodesById(plugin_id string) ([]string, error) {
+	return c.FetchPluginAvailableNodesByHashedId(entities.HashedIdentity(plugin_id))
+}
+
 // IsMaster reports the value of the cluster's i_am_master flag.
 func (c *Cluster) IsMaster() bool {
 	return c.i_am_master
 }
@@ -191,9 +193,7 @@ func (c *Cluster) gcNode(node_id string) error {
 	}
 
 	// remove the node from the cluster
-	c.node_lock.Lock()
 	c.nodes.Delete(node_id)
-	c.node_lock.Unlock()
 
 	if err := c.LockNodeStatus(node_id); err != nil {
 		return err

+ 5 - 0
internal/cluster/plugin.go

@@ -240,3 +240,8 @@ func (c *Cluster) autoGCPlugins() error {
 		},
 	)
 }
+
+// IsPluginNoCurrentNode reports whether the given plugin identity has an entry
+// in c.plugins.
+//
+// NOTE(review): despite the "No" in the name (presumably a typo for "On"), the
+// result is true when the plugin IS present, not when it is absent.
+func (c *Cluster) IsPluginNoCurrentNode(identity string) bool {
+	_, found := c.plugins.Load(identity)
+	return found
+}

+ 4 - 4
internal/cluster/plugin_test.go

@@ -74,7 +74,7 @@ func TestPluginScheduleLifetime(t *testing.T) {
 		return
 	}
 
-	nodes, err := cluster[0].FetchPluginAvailableNodes(hashed_identity)
+	nodes, err := cluster[0].FetchPluginAvailableNodesByHashedId(hashed_identity)
 	if err != nil {
 		t.Errorf("fetch plugin available nodes failed: %v", err)
 		return
@@ -97,7 +97,7 @@ func TestPluginScheduleLifetime(t *testing.T) {
 	time.Sleep(time.Second * 1)
 
 	// check if the plugin is stopped
-	nodes, err = cluster[0].FetchPluginAvailableNodes(hashed_identity)
+	nodes, err = cluster[0].FetchPluginAvailableNodesByHashedId(hashed_identity)
 	if err != nil {
 		t.Errorf("fetch plugin available nodes failed: %v", err)
 		return
@@ -188,7 +188,7 @@ func TestPluginScheduleWhenMasterClusterShutdown(t *testing.T) {
 	for !done {
 		select {
 		case <-ticker.C:
-			nodes, err := cluster[master_idx].FetchPluginAvailableNodes(hashed_identity)
+			nodes, err := cluster[master_idx].FetchPluginAvailableNodesByHashedId(hashed_identity)
 			if err != nil {
 				t.Errorf("fetch plugin available nodes failed: %v", err)
 				return
@@ -209,7 +209,7 @@ func TestPluginScheduleWhenMasterClusterShutdown(t *testing.T) {
 		return
 	}
 
-	nodes, err := cluster[1-master_idx].FetchPluginAvailableNodes(hashed_identity)
+	nodes, err := cluster[1-master_idx].FetchPluginAvailableNodesByHashedId(hashed_identity)
 	if err != nil {
 		t.Errorf("fetch plugin available nodes failed: %v", err)
 		return

+ 48 - 0
internal/cluster/redirect.go

@@ -0,0 +1,48 @@
+package cluster
+
+import (
+	"errors"
+	"io"
+	"net/http"
+	"strconv"
+)
+
+// RedirectRequest forwards the request to the specified node and returns that
+// node's response.
+//
+// The target is the first address returned by c.SortIps for the node, on
+// c.port. The forwarded request carries the original method, escaped path,
+// query string, a copy of the headers and the original body. It inherits the
+// original request's context, so it is cancelled when that context is done.
+//
+// It returns the response status code, headers and body. On success the caller
+// owns the returned body and must close it. An error is returned when the node
+// is unknown, when no address is available, or when the HTTP round trip fails.
+func (c *Cluster) RedirectRequest(
+	node_id string, request *http.Request,
+) (int, http.Header, io.ReadCloser, error) {
+	node, ok := c.nodes.Load(node_id)
+	if !ok {
+		return 0, nil, nil, errors.New("node not found")
+	}
+
+	ips := c.SortIps(node)
+	if len(ips) == 0 {
+		return 0, nil, nil, errors.New("no available ip found")
+	}
+
+	ip := ips[0]
+
+	// RequestURI yields the escaped path plus the query string, so query
+	// parameters and percent-encoded path segments survive the hop
+	target := "http://" + ip.Address + ":" +
+		strconv.FormatUint(uint64(c.port), 10) + request.URL.RequestURI()
+
+	// create a new request bound to the caller's context
+	redirected_request, err := http.NewRequestWithContext(
+		request.Context(),
+		request.Method,
+		target,
+		request.Body,
+	)
+	if err != nil {
+		return 0, nil, nil, err
+	}
+
+	// copy headers; cloning keeps the two requests from sharing one map
+	redirected_request.Header = request.Header.Clone()
+
+	resp, err := http.DefaultClient.Do(redirected_request)
+	if err != nil {
+		return 0, nil, nil, err
+	}
+
+	return resp.StatusCode, resp.Header, resp.Body, nil
+}

+ 2 - 1
internal/core/plugin_manager/manager.go

@@ -19,10 +19,11 @@ var (
 	manager *PluginManager
 )
 
-func InitGlobalPluginManager(cluster *cluster.Cluster) {
+func InitGlobalPluginManager(cluster *cluster.Cluster, configuration *app.Config) {
 	manager = &PluginManager{
 		cluster: cluster,
 	}
+	manager.Init(configuration)
 }
 
 func GetGlobalPluginManager() *PluginManager {

+ 1 - 3
internal/server/app.go

@@ -2,10 +2,8 @@ package server
 
 import (
 	"github.com/langgenius/dify-plugin-daemon/internal/cluster"
-	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 )
 
 type App struct {
-	plugin_manager *plugin_manager.PluginManager
-	cluster        *cluster.Cluster
+	cluster *cluster.Cluster
 }

+ 67 - 12
internal/server/http.go

@@ -8,24 +8,79 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 )
 
-func server(config *app.Config) {
+func (app *App) server(config *app.Config) {
 	engine := gin.Default()
 
 	engine.GET("/health/check", controllers.HealthCheck)
 
-	engine.POST("/plugin/tool/invoke", CheckingKey(config.PluginInnerApiKey), controllers.InvokeTool)
-	engine.POST("/plugin/tool/validate_credentials", CheckingKey(config.PluginInnerApiKey), controllers.ValidateToolCredentials)
-	engine.POST("/plugin/llm/invoke", CheckingKey(config.PluginInnerApiKey), controllers.InvokeLLM)
-	engine.POST("/plugin/text_embedding/invoke", CheckingKey(config.PluginInnerApiKey), controllers.InvokeTextEmbedding)
-	engine.POST("/plugin/rerank/invoke", CheckingKey(config.PluginInnerApiKey), controllers.InvokeRerank)
-	engine.POST("/plugin/tts/invoke", CheckingKey(config.PluginInnerApiKey), controllers.InvokeTTS)
-	engine.POST("/plugin/speech2text/invoke", CheckingKey(config.PluginInnerApiKey), controllers.InvokeSpeech2Text)
-	engine.POST("/plugin/moderation/invoke", CheckingKey(config.PluginInnerApiKey), controllers.InvokeModeration)
-	engine.POST("/plugin/model/validate_provider_credentials", CheckingKey(config.PluginInnerApiKey), controllers.ValidateProviderCredentials)
-	engine.POST("/plugin/model/validate_model_credentials", CheckingKey(config.PluginInnerApiKey), controllers.ValidateModelCredentials)
+	engine.POST(
+		"/plugin/tool/invoke",
+		CheckingKey(config.PluginInnerApiKey),
+		app.Redirect(),
+		controllers.InvokeTool,
+	)
+	engine.POST(
+		"/plugin/tool/validate_credentials",
+		CheckingKey(config.PluginInnerApiKey),
+		app.Redirect(),
+		controllers.ValidateToolCredentials,
+	)
+	engine.POST(
+		"/plugin/llm/invoke",
+		CheckingKey(config.PluginInnerApiKey),
+		app.Redirect(),
+		controllers.InvokeLLM,
+	)
+	engine.POST(
+		"/plugin/text_embedding/invoke",
+		CheckingKey(config.PluginInnerApiKey),
+		app.Redirect(),
+		controllers.InvokeTextEmbedding,
+	)
+	engine.POST(
+		"/plugin/rerank/invoke",
+		CheckingKey(config.PluginInnerApiKey),
+		app.Redirect(),
+		controllers.InvokeRerank,
+	)
+	engine.POST(
+		"/plugin/tts/invoke",
+		CheckingKey(config.PluginInnerApiKey),
+		app.Redirect(),
+		controllers.InvokeTTS,
+	)
+	engine.POST(
+		"/plugin/speech2text/invoke",
+		CheckingKey(config.PluginInnerApiKey),
+		app.Redirect(),
+		controllers.InvokeSpeech2Text,
+	)
+	engine.POST(
+		"/plugin/moderation/invoke",
+		CheckingKey(config.PluginInnerApiKey),
+		app.Redirect(),
+		controllers.InvokeModeration,
+	)
+	engine.POST(
+		"/plugin/model/validate_provider_credentials",
+		CheckingKey(config.PluginInnerApiKey),
+		app.Redirect(),
+		controllers.ValidateProviderCredentials,
+	)
+	engine.POST(
+		"/plugin/model/validate_model_credentials",
+		CheckingKey(config.PluginInnerApiKey),
+		app.Redirect(),
+		controllers.ValidateModelCredentials,
+	)
 
 	if config.PluginRemoteInstallingEnabled {
-		engine.POST("/plugin/debugging/key", CheckingKey(config.PluginInnerApiKey), controllers.GetRemoteDebuggingKey)
+		engine.POST(
+			"/plugin/debugging/key",
+			CheckingKey(config.PluginInnerApiKey),
+			app.Redirect(),
+			controllers.GetRemoteDebuggingKey,
+		)
 	}
 
 	engine.Run(fmt.Sprintf(":%d", config.ServerPort))

+ 99 - 1
internal/server/middleware.go

@@ -1,6 +1,14 @@
 package server
 
-import "github.com/gin-gonic/gin"
+import (
+	"bytes"
+	"io"
+
+	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+)
 
 func CheckingKey(key string) gin.HandlerFunc {
 	return func(c *gin.Context) {
@@ -14,3 +22,93 @@ func CheckingKey(key string) gin.HandlerFunc {
 		c.Next()
 	}
 }
+
+// ginContextReader adapts a bytes.Reader to io.ReadCloser so that an already
+// buffered payload can be installed as a request body.
+type ginContextReader struct {
+	reader *bytes.Reader
+}
+
+// Read reads from the wrapped bytes.Reader.
+func (r *ginContextReader) Read(dst []byte) (int, error) {
+	return r.reader.Read(dst)
+}
+
+// Close does nothing and always returns nil.
+func (r *ginContextReader) Close() error {
+	return nil
+}
+
+// Redirect returns a middleware that routes the request to the correct
+// cluster node.
+//
+// The middleware reads the whole request body, restores it for later readers
+// and decodes the plugin name and version from it. When the plugin has an
+// entry on the current node the request continues down the local handler
+// chain. Otherwise the request is forwarded to the first node reported by
+// FetchPluginAvailableNodesById, that node's status, headers and body are
+// relayed to the client, and the local handler chain is aborted.
+//
+// Responses produced by the middleware itself:
+//   - 400 when the body cannot be read or decoded
+//   - 404 when no node serves the plugin
+//   - 500 when the node lookup or the forwarded request fails
+func (app *App) Redirect() gin.HandlerFunc {
+	return func(ctx *gin.Context) {
+		// get plugin identity
+		raw, err := ctx.GetRawData()
+		if err != nil {
+			ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request"})
+			return
+		}
+
+		// GetRawData consumed the body; put the bytes back so that the local
+		// controller or the forwarded request can read them again
+		ctx.Request.Body = &ginContextReader{
+			reader: bytes.NewReader(raw),
+		}
+
+		identity, err := parser.UnmarshalJsonBytes[plugin_entities.InvokePluginPluginIdentity](raw)
+		if err != nil {
+			ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request"})
+			return
+		}
+
+		plugin_id := parser.MarshalPluginIdentity(identity.PluginName, identity.PluginVersion)
+
+		// IsPluginNoCurrentNode returns true when the plugin IS registered on
+		// the current node, in which case the request is served locally
+		if app.cluster.IsPluginNoCurrentNode(plugin_id) {
+			ctx.Next()
+			return
+		}
+
+		// try find the correct node
+		nodes, err := app.cluster.FetchPluginAvailableNodesById(plugin_id)
+		if err != nil {
+			log.Error("fetch plugin available nodes failed: %s", err.Error())
+			ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error"})
+			return
+		}
+		if len(nodes) == 0 {
+			log.Error("no available node")
+			ctx.AbortWithStatusJSON(404, gin.H{"error": "No available node"})
+			return
+		}
+
+		// redirect to the correct node
+		status_code, header, body, err := app.cluster.RedirectRequest(nodes[0], ctx.Request)
+		if err != nil {
+			log.Error("redirect request failed: %s", err.Error())
+			ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error"})
+			return
+		}
+		// release the upstream connection once the response has been relayed
+		defer body.Close()
+
+		// the response is produced here; without Abort gin would still run the
+		// local controller after this middleware returns
+		ctx.Abort()
+
+		// set header, keeping every value of multi-value headers
+		response_header := ctx.Writer.Header()
+		for key, values := range header {
+			response_header.Del(key)
+			for _, value := range values {
+				response_header.Add(key, value)
+			}
+		}
+
+		// set status code
+		ctx.Writer.WriteHeader(status_code)
+
+		// relay the body; a failure here can only be logged because the status
+		// line has already been chosen
+		if _, err := io.Copy(ctx.Writer, body); err != nil {
+			log.Error("relay redirected response failed: %s", err.Error())
+		}
+	}
+}

+ 2 - 4
internal/server/server.go

@@ -11,8 +11,6 @@ import (
 
 func (a *App) Run(config *app.Config) {
 	a.cluster = cluster.NewCluster(config)
-	plugin_manager.InitGlobalPluginManager(a.cluster)
-	a.plugin_manager = plugin_manager.GetGlobalPluginManager()
 
 	// init routine pool
 	routine.InitPool(config.RoutinePoolSize)
@@ -24,11 +22,11 @@ func (a *App) Run(config *app.Config) {
 	process.Init(config)
 
 	// init plugin daemon
-	a.plugin_manager.Init(config)
+	plugin_manager.InitGlobalPluginManager(a.cluster, config)
 
 	// launch cluster
 	a.cluster.Launch()
 
 	// start http server
-	server(config)
+	a.server(config)
 }

+ 13 - 4
internal/types/entities/plugin_entities/request.go

@@ -1,9 +1,18 @@
 package plugin_entities
 
-type InvokePluginRequest[T any] struct {
+type InvokePluginPluginIdentity struct {
 	PluginName    string `json:"plugin_name" binding:"required"`
 	PluginVersion string `json:"plugin_version" binding:"required"`
-	TenantId      string `json:"tenant_id" binding:"required"`
-	UserId        string `json:"user_id" binding:"required"`
-	Data          T      `json:"data" binding:"required"`
+}
+
+// InvokePluginUserIdentity holds the tenant_id and user_id fields of a plugin
+// invoke request. Both fields are tagged as required for binding.
+type InvokePluginUserIdentity struct {
+	TenantId string `json:"tenant_id" binding:"required"`
+	UserId   string `json:"user_id" binding:"required"`
+}
+
+// InvokePluginRequest is a plugin invoke request carrying a payload of type T.
+// It embeds InvokePluginPluginIdentity and InvokePluginUserIdentity; with
+// encoding/json the embedded fields are flattened, so their keys sit at the
+// top level of the JSON object next to "data".
+type InvokePluginRequest[T any] struct {
+	InvokePluginPluginIdentity
+	InvokePluginUserIdentity
+
+	// Data is the request-specific payload.
+	Data T `json:"data" binding:"required"`
+}