浏览代码

refactor: session manager

Yeuoly 11 月之前
父节点
当前提交
8294a61d4c

+ 4 - 0
internal/cluster/cluster.go

@@ -77,6 +77,10 @@ func (c *Cluster) Close() error {
 	return nil
 }
 
+func (c *Cluster) ID() string {
+	return c.id
+}
+
 // trigger for master event
 func (c *Cluster) notifyBecomeMaster() {
 	if atomic.LoadInt32(&c.stopped) == 1 {

+ 1 - 0
internal/core/plugin_manager/aws_manager/init.go

@@ -0,0 +1 @@
+package aws_manager

+ 1 - 0
internal/core/plugin_manager/aws_manager/io.go

@@ -47,6 +47,7 @@ func (r *AWSPluginRuntime) Write(session_id string, data []byte) {
 	}
 	req.Header.Set("Content-Type", "application/json")
 	req.Header.Set("Accept", "text/event-stream")
+	req.Header.Set("Dify-Plugin-Session-ID", session_id)
 
 	routine.Submit(func() {
 		// remove the session from listeners

+ 38 - 7
internal/core/session_manager/session.go

@@ -6,6 +6,8 @@ import (
 
 	"github.com/google/uuid"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 )
 
@@ -22,20 +24,43 @@ type Session struct {
 	tenant_id       string
 	user_id         string
 	plugin_identity string
+	cluster_id      string
 }
 
-func NewSession(tenant_id string, user_id string, plugin_identity string) *Session {
+type SessionInfo struct {
+	TenantID       string `json:"tenant_id"`
+	UserID         string `json:"user_id"`
+	PluginIdentity string `json:"plugin_identity"`
+	ClusterID      string `json:"cluster_id"`
+}
+
+const (
+	SESSION_INFO_MAP_KEY = "session_info"
+)
+
+func NewSession(tenant_id string, user_id string, plugin_identity string, cluster_id string) *Session {
 	s := &Session{
 		id:              uuid.New().String(),
 		tenant_id:       tenant_id,
 		user_id:         user_id,
 		plugin_identity: plugin_identity,
+		cluster_id:      cluster_id,
 	}
 
 	session_lock.Lock()
-	defer session_lock.Unlock()
-
 	sessions[s.id] = s
+	session_lock.Unlock()
+
+	session_info := &SessionInfo{
+		TenantID:       tenant_id,
+		UserID:         user_id,
+		PluginIdentity: plugin_identity,
+		ClusterID:      cluster_id,
+	}
+
+	if err := cache.SetMapOneField(SESSION_INFO_MAP_KEY, s.id, session_info); err != nil {
+		log.Error("set session info to cache failed, %s", err)
+	}
 
 	return s
 }
@@ -49,16 +74,22 @@ func GetSession(id string) *Session {
 
 func DeleteSession(id string) {
 	session_lock.Lock()
-	defer session_lock.Unlock()
-
 	delete(sessions, id)
+	session_lock.Unlock()
+
+	if err := cache.DelMapField(SESSION_INFO_MAP_KEY, id); err != nil {
+		log.Error("delete session info from cache failed, %s", err)
+	}
 }
 
 func (s *Session) Close() {
 	session_lock.Lock()
-	defer session_lock.Unlock()
-
 	delete(sessions, s.id)
+	session_lock.Unlock()
+
+	if err := cache.DelMapField(SESSION_INFO_MAP_KEY, s.id); err != nil {
+		log.Error("delete session info from cache failed, %s", err)
+	}
 }
 
 func (s *Session) ID() string {

+ 37 - 78
internal/server/http_server.go

@@ -13,86 +13,11 @@ import (
 
 func (app *App) server(config *app.Config) func() {
 	engine := gin.Default()
-
 	engine.GET("/health/check", controllers.HealthCheck)
 
-	engine.POST(
-		"/plugin/tool/invoke",
-		CheckingKey(config.PluginInnerApiKey),
-		app.RedirectPluginInvoke(),
-		controllers.InvokeTool,
-	)
-	engine.POST(
-		"/plugin/tool/validate_credentials",
-		CheckingKey(config.PluginInnerApiKey),
-		app.RedirectPluginInvoke(),
-		controllers.ValidateToolCredentials,
-	)
-	engine.POST(
-		"/plugin/llm/invoke",
-		CheckingKey(config.PluginInnerApiKey),
-		app.RedirectPluginInvoke(),
-		controllers.InvokeLLM,
-	)
-	engine.POST(
-		"/plugin/text_embedding/invoke",
-		CheckingKey(config.PluginInnerApiKey),
-		app.RedirectPluginInvoke(),
-		controllers.InvokeTextEmbedding,
-	)
-	engine.POST(
-		"/plugin/rerank/invoke",
-		CheckingKey(config.PluginInnerApiKey),
-		app.RedirectPluginInvoke(),
-		controllers.InvokeRerank,
-	)
-	engine.POST(
-		"/plugin/tts/invoke",
-		CheckingKey(config.PluginInnerApiKey),
-		app.RedirectPluginInvoke(),
-		controllers.InvokeTTS,
-	)
-	engine.POST(
-		"/plugin/speech2text/invoke",
-		CheckingKey(config.PluginInnerApiKey),
-		app.RedirectPluginInvoke(),
-		controllers.InvokeSpeech2Text,
-	)
-	engine.POST(
-		"/plugin/moderation/invoke",
-		CheckingKey(config.PluginInnerApiKey),
-		app.RedirectPluginInvoke(),
-		controllers.InvokeModeration,
-	)
-	engine.POST(
-		"/plugin/model/validate_provider_credentials",
-		CheckingKey(config.PluginInnerApiKey),
-		app.RedirectPluginInvoke(),
-		controllers.ValidateProviderCredentials,
-	)
-	engine.POST(
-		"/plugin/model/validate_model_credentials",
-		CheckingKey(config.PluginInnerApiKey),
-		app.RedirectPluginInvoke(),
-		controllers.ValidateModelCredentials,
-	)
-
-	if config.PluginRemoteInstallingEnabled {
-		engine.POST(
-			"/plugin/debugging/key",
-			CheckingKey(config.PluginInnerApiKey),
-			controllers.GetRemoteDebuggingKey,
-		)
-	}
-
-	if config.PluginWebhookEnabled {
-		engine.HEAD("/webhook/:hook_id/*path", app.Webhook())
-		engine.POST("/webhook/:hook_id/*path", app.Webhook())
-		engine.GET("/webhook/:hook_id/*path", app.Webhook())
-		engine.PUT("/webhook/:hook_id/*path", app.Webhook())
-		engine.DELETE("/webhook/:hook_id/*path", app.Webhook())
-		engine.OPTIONS("/webhook/:hook_id/*path", app.Webhook())
-	}
+	app.pluginInvokeGroup(engine.Group("/plugin"), config)
+	app.remoteDebuggingGroup(engine.Group("/plugin/debugging"), config)
+	app.webhookGroup(engine.Group("/webhook"), config)
 
 	srv := &http.Server{
 		Addr:    fmt.Sprintf(":%d", config.ServerPort),
@@ -111,3 +36,37 @@ func (app *App) server(config *app.Config) func() {
 		}
 	}
 }
+
+func (app *App) pluginInvokeGroup(group *gin.RouterGroup, config *app.Config) {
+	group.Use(CheckingKey(config.PluginInnerApiKey))
+	group.Use(app.RedirectPluginInvoke())
+	group.Use(app.InitClusterID())
+
+	group.POST("/tool/invoke", controllers.InvokeTool)
+	group.POST("/tool/validate_credentials", controllers.ValidateToolCredentials)
+	group.POST("/llm/invoke", controllers.InvokeLLM)
+	group.POST("/text_embedding/invoke", controllers.InvokeTextEmbedding)
+	group.POST("/rerank/invoke", controllers.InvokeRerank)
+	group.POST("/tts/invoke", controllers.InvokeTTS)
+	group.POST("/speech2text/invoke", controllers.InvokeSpeech2Text)
+	group.POST("/moderation/invoke", controllers.InvokeModeration)
+	group.POST("/model/validate_provider_credentials", controllers.ValidateProviderCredentials)
+	group.POST("/model/validate_model_credentials", controllers.ValidateModelCredentials)
+}
+
+func (app *App) remoteDebuggingGroup(group *gin.RouterGroup, config *app.Config) {
+	if config.PluginRemoteInstallingEnabled {
+		group.POST("/key", CheckingKey(config.PluginInnerApiKey), controllers.GetRemoteDebuggingKey)
+	}
+}
+
+func (app *App) webhookGroup(group *gin.RouterGroup, config *app.Config) {
+	if config.PluginWebhookEnabled {
+		group.HEAD("/:hook_id/*path", app.Webhook())
+		group.POST("/:hook_id/*path", app.Webhook())
+		group.GET("/:hook_id/*path", app.Webhook())
+		group.PUT("/:hook_id/*path", app.Webhook())
+		group.DELETE("/:hook_id/*path", app.Webhook())
+		group.OPTIONS("/:hook_id/*path", app.Webhook())
+	}
+}

+ 7 - 0
internal/server/middleware.go

@@ -117,3 +117,10 @@ func (app *App) Redirect(ctx *gin.Context, plugin_id string) {
 		}
 	}
 }
+
+func (app *App) InitClusterID() gin.HandlerFunc {
+	return func(ctx *gin.Context) {
+		ctx.Set("cluster_id", app.cluster.ID())
+		ctx.Next()
+	}
+}

+ 2 - 2
internal/service/invoke_model.go

@@ -11,7 +11,7 @@ import (
 
 func InvokeTool(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTool], ctx *gin.Context) {
 	// create session
-	session := createSession(r)
+	session := createSession(r, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[tool_entities.ToolResponseChunk], error) {
@@ -21,7 +21,7 @@ func InvokeTool(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeToo
 
 func ValidateToolCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateToolCredentials], ctx *gin.Context) {
 	// create session
-	session := createSession(r)
+	session := createSession(r, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[tool_entities.ValidateCredentialsResult], error) {

+ 10 - 10
internal/service/invoke_tool.go

@@ -12,8 +12,8 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
-func createSession[T any](r *plugin_entities.InvokePluginRequest[T]) *session_manager.Session {
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+func createSession[T any](r *plugin_entities.InvokePluginRequest[T], cluster_id string) *session_manager.Session {
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion), cluster_id)
 	runtime := plugin_manager.GetGlobalPluginManager().Get(session.PluginIdentity())
 	session.BindRuntime(runtime)
 	return session
@@ -21,7 +21,7 @@ func createSession[T any](r *plugin_entities.InvokePluginRequest[T]) *session_ma
 
 func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM], ctx *gin.Context) {
 	// create session
-	session := createSession(r)
+	session := createSession(r, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.LLMResultChunk], error) {
@@ -31,7 +31,7 @@ func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM]
 
 func InvokeTextEmbedding(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding], ctx *gin.Context) {
 	// create session
-	session := createSession(r)
+	session := createSession(r, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.TextEmbeddingResult], error) {
@@ -41,7 +41,7 @@ func InvokeTextEmbedding(r *plugin_entities.InvokePluginRequest[requests.Request
 
 func InvokeRerank(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank], ctx *gin.Context) {
 	// create session
-	session := createSession(r)
+	session := createSession(r, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.RerankResult], error) {
@@ -51,7 +51,7 @@ func InvokeRerank(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeR
 
 func InvokeTTS(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS], ctx *gin.Context) {
 	// create session
-	session := createSession(r)
+	session := createSession(r, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.TTSResult], error) {
@@ -61,7 +61,7 @@ func InvokeTTS(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS]
 
 func InvokeSpeech2Text(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text], ctx *gin.Context) {
 	// create session
-	session := createSession(r)
+	session := createSession(r, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.Speech2TextResult], error) {
@@ -71,7 +71,7 @@ func InvokeSpeech2Text(r *plugin_entities.InvokePluginRequest[requests.RequestIn
 
 func InvokeModeration(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration], ctx *gin.Context) {
 	// create session
-	session := createSession(r)
+	session := createSession(r, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.ModerationResult], error) {
@@ -81,7 +81,7 @@ func InvokeModeration(r *plugin_entities.InvokePluginRequest[requests.RequestInv
 
 func ValidateProviderCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials], ctx *gin.Context) {
 	// create session
-	session := createSession(r)
+	session := createSession(r, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {
@@ -91,7 +91,7 @@ func ValidateProviderCredentials(r *plugin_entities.InvokePluginRequest[requests
 
 func ValidateModelCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials], ctx *gin.Context) {
 	// create session
-	session := createSession(r)
+	session := createSession(r, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {

+ 1 - 1
internal/service/webhook.go

@@ -35,7 +35,7 @@ func Webhook(ctx *gin.Context, webhook *models.Webhook, path string) {
 		return
 	}
 
-	session := session_manager.NewSession(webhook.TenantID, "", webhook.PluginID)
+	session := session_manager.NewSession(webhook.TenantID, "", webhook.PluginID, ctx.GetString("cluster_id"))
 	defer session.Close()
 
 	session.BindRuntime(runtime)