Ver código fonte

refactor: using plugin id to dispatch request instead

Yeuoly 10 meses atrás
pai
commit
b92d497ced

+ 6 - 2
internal/server/constants/constants.go

@@ -1,6 +1,10 @@
 package constants
 
 const (
-	X_PLUGIN_IDENTIFIER = "X-Plugin-Identifier"
-	X_API_KEY           = "X-Api-Key"
+	X_PLUGIN_ID = "X-Plugin-ID"
+	X_API_KEY   = "X-Api-Key"
+
+	CONTEXT_KEY_PLUGIN_INSTALLATION      = "plugin_installation"
+	CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER = "plugin_unique_identifier"
+	CONTEXT_KEY_CLUSTER_ID               = "cluster_id"
 )

+ 12 - 9
internal/server/controllers/base.go

@@ -30,24 +30,27 @@ func BindRequest[T any](r *gin.Context, success func(T)) {
 	success(request)
 }
 
-func BindRequestWithPluginUniqueIdentifier[T any](r *gin.Context, success func(
-	T, plugin_entities.PluginUniqueIdentifier,
+func BindPluginDispatchRequest[T any](r *gin.Context, success func(
+	plugin_entities.InvokePluginRequest[T],
 )) {
-	BindRequest(r, func(req T) {
-		plugin_unique_identifier := r.GetHeader(constants.X_PLUGIN_IDENTIFIER)
-		if plugin_unique_identifier == "" {
+	BindRequest(r, func(req plugin_entities.InvokePluginRequest[T]) {
+		plugin_unique_identifier_any, exists := r.Get(constants.CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER)
+		if !exists {
 			resp := entities.NewErrorResponse(-400, "Plugin unique identifier is required")
 			r.JSON(400, resp)
 			return
 		}
 
-		identifier, err := plugin_entities.NewPluginUniqueIdentifier(plugin_unique_identifier)
-		if err != nil {
-			resp := entities.NewErrorResponse(-400, err.Error())
+		plugin_unique_identifier, ok := plugin_unique_identifier_any.(plugin_entities.PluginUniqueIdentifier)
+		if !ok {
+			resp := entities.NewErrorResponse(-400, "Plugin unique identifier is required")
 			r.JSON(400, resp)
 			return
 		}
 
-		success(req, identifier)
+		// set plugin unique identifier
+		req.UniqueIdentifier = plugin_unique_identifier
+
+		success(req)
 	})
 }

+ 12 - 12
internal/server/controllers/model.go

@@ -12,7 +12,7 @@ func InvokeLLM(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.InvokeLLM(&itr, c, config.PluginMaxExecutionTimeout)
@@ -25,7 +25,7 @@ func InvokeTextEmbedding(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.InvokeTextEmbedding(&itr, c, config.PluginMaxExecutionTimeout)
@@ -38,7 +38,7 @@ func InvokeRerank(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.InvokeRerank(&itr, c, config.PluginMaxExecutionTimeout)
@@ -51,7 +51,7 @@ func InvokeTTS(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.InvokeTTS(&itr, c, config.PluginMaxExecutionTimeout)
@@ -64,7 +64,7 @@ func InvokeSpeech2Text(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.InvokeSpeech2Text(&itr, c, config.PluginMaxExecutionTimeout)
@@ -77,7 +77,7 @@ func InvokeModeration(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.InvokeModeration(&itr, c, config.PluginMaxExecutionTimeout)
@@ -90,7 +90,7 @@ func ValidateProviderCredentials(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.ValidateProviderCredentials(&itr, c, config.PluginMaxExecutionTimeout)
@@ -103,7 +103,7 @@ func ValidateModelCredentials(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.ValidateModelCredentials(&itr, c, config.PluginMaxExecutionTimeout)
@@ -116,7 +116,7 @@ func GetTTSModelVoices(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestGetTTSModelVoices]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.GetTTSModelVoices(&itr, c, config.PluginMaxExecutionTimeout)
@@ -129,7 +129,7 @@ func GetTextEmbeddingNumTokens(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestGetTextEmbeddingNumTokens]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.GetTextEmbeddingNumTokens(&itr, c, config.PluginMaxExecutionTimeout)
@@ -142,7 +142,7 @@ func GetLLMNumTokens(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestGetLLMNumTokens]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.GetLLMNumTokens(&itr, c, config.PluginMaxExecutionTimeout)
@@ -155,7 +155,7 @@ func GetAIModelSchema(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestGetAIModelSchema]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.GetAIModelSchema(&itr, c, config.PluginMaxExecutionTimeout)

+ 2 - 2
internal/server/controllers/tool.go

@@ -12,7 +12,7 @@ func InvokeTool(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeTool]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.InvokeTool(&itr, c, config.PluginMaxExecutionTimeout)
@@ -25,7 +25,7 @@ func ValidateToolCredentials(config *app.Config) gin.HandlerFunc {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestValidateToolCredentials]
 
 	return func(c *gin.Context) {
-		BindRequest(
+		BindPluginDispatchRequest(
 			c,
 			func(itr request) {
 				service.ValidateToolCredentials(&itr, c, config.PluginMaxExecutionTimeout)

+ 3 - 1
internal/server/endpoint.go

@@ -54,7 +54,9 @@ func (app *App) EndpointHandler(ctx *gin.Context, hook_id string, path string) {
 		return
 	}
 
-	plugin_unique_identifier, err := plugin_entities.NewPluginUniqueIdentifier(plugin_installation.PluginUniqueIdentifier)
+	plugin_unique_identifier, err := plugin_entities.NewPluginUniqueIdentifier(
+		plugin_installation.PluginUniqueIdentifier,
+	)
 	if err != nil {
 		ctx.JSON(400, gin.H{"error": "invalid plugin unique identifier"})
 		return

+ 1 - 0
internal/server/http_server.go

@@ -52,6 +52,7 @@ func (app *App) pluginGroup(group *gin.RouterGroup, config *app.Config) {
 }
 
 func (app *App) pluginDispatchGroup(group *gin.RouterGroup, config *app.Config) {
+	group.Use(app.FetchPluginInstallation())
 	group.Use(app.RedirectPluginInvoke())
 	group.Use(app.InitClusterID())
 

+ 35 - 21
internal/server/middleware.go

@@ -1,13 +1,14 @@
 package server
 
 import (
-	"bytes"
 	"io"
 
 	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/db"
 	"github.com/langgenius/dify-plugin-daemon/internal/server/constants"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 )
 
@@ -24,35 +25,48 @@ func CheckingKey(key string) gin.HandlerFunc {
 	}
 }
 
-type ginContextReader struct {
-	reader *bytes.Reader
-}
+func (app *App) FetchPluginInstallation() gin.HandlerFunc {
+	return func(ctx *gin.Context) {
+		plugin_id := ctx.Request.Header.Get(constants.X_PLUGIN_ID)
+		if plugin_id == "" {
+			ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request, plugin_id is required"})
+			return
+		}
 
-func (g *ginContextReader) Read(p []byte) (n int, err error) {
-	return g.reader.Read(p)
-}
+		// fetch plugin installation
+		installation, err := db.GetOne[models.PluginInstallation](
+			db.Equal("plugin_id", plugin_id),
+		)
+		if err != nil {
+			ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request, " + err.Error()})
+			return
+		}
 
-func (g *ginContextReader) Close() error {
-	return nil
+		identity, err := plugin_entities.NewPluginUniqueIdentifier(installation.PluginUniqueIdentifier)
+		if err != nil {
+			ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request, " + err.Error()})
+			return
+		}
+
+		ctx.Set(constants.CONTEXT_KEY_PLUGIN_INSTALLATION, installation)
+		ctx.Set(constants.CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER, identity)
+		ctx.Next()
+	}
 }
 
 // RedirectPluginInvoke redirects the request to the correct cluster node
 func (app *App) RedirectPluginInvoke() gin.HandlerFunc {
 	return func(ctx *gin.Context) {
-		// get plugin identity
-		raw, err := ctx.GetRawData()
-		if err != nil {
-			ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request"})
+		// get plugin unique identifier
+		identity_any, ok := ctx.Get(constants.CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER)
+		if !ok {
+			ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error, plugin unique identifier not found"})
 			return
 		}
 
-		ctx.Request.Body = &ginContextReader{
-			reader: bytes.NewReader(raw),
-		}
-
-		identity, err := plugin_entities.NewPluginUniqueIdentifier(ctx.Request.Header.Get(constants.X_PLUGIN_IDENTIFIER))
-		if err != nil {
-			ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request, " + err.Error()})
+		identity, ok := identity_any.(plugin_entities.PluginUniqueIdentifier)
+		if !ok {
+			ctx.AbortWithStatusJSON(500, gin.H{"error": "Internal server error, failed to parse plugin unique identifier"})
 			return
 		}
 
@@ -119,7 +133,7 @@ func (app *App) redirectPluginInvokeByPluginIdentifier(
 
 func (app *App) InitClusterID() gin.HandlerFunc {
 	return func(ctx *gin.Context) {
-		ctx.Set("cluster_id", app.cluster.ID())
+		ctx.Set(constants.CONTEXT_KEY_CLUSTER_ID, app.cluster.ID())
 		ctx.Next()
 	}
 }

+ 5 - 2
internal/service/invoke_tool.go

@@ -24,7 +24,10 @@ func createSession[T any](
 	if manager == nil {
 		return nil, errors.New("failed to get plugin manager")
 	}
-	runtime := manager.Get(r.PluginUniqueIdentifier)
+
+	// try fetch plugin identifier from plugin id
+
+	runtime := manager.Get(r.UniqueIdentifier)
 	if runtime == nil {
 		return nil, errors.New("failed to get plugin runtime")
 	}
@@ -33,7 +36,7 @@ func createSession[T any](
 		session_manager.NewSessionPayload{
 			TenantID:               r.TenantId,
 			UserID:                 r.UserId,
-			PluginUniqueIdentifier: r.PluginUniqueIdentifier,
+			PluginUniqueIdentifier: r.UniqueIdentifier,
 			ClusterID:              cluster_id,
 			InvokeFrom:             access_type,
 			Action:                 access_action,

+ 3 - 2
internal/types/entities/plugin_entities/request.go

@@ -6,12 +6,13 @@ type InvokePluginUserIdentity struct {
 }
 
 type BasePluginIdentifier struct {
-	PluginUniqueIdentifier PluginUniqueIdentifier `json:"plugin_unique_identifier"`
+	PluginID string `json:"plugin_id"`
 }
 
 type InvokePluginRequest[T any] struct {
 	InvokePluginUserIdentity
 	BasePluginIdentifier
 
-	Data T `json:"data" validate:"required"`
+	UniqueIdentifier PluginUniqueIdentifier `json:"unique_identifier"`
+	Data             T                      `json:"data" validate:"required"`
 }