Yeuoly 1 год назад
Родитель
Сommit
c69b51ad1a

+ 10 - 7
internal/core/plugin_daemon/basic.go

@@ -10,13 +10,16 @@ const (
 type PluginAccessAction string
 
 const (
-	PLUGIN_ACCESS_ACTION_INVOKE_TOOL           PluginAccessAction = "invoke_tool"
-	PLUGIN_ACCESS_ACTION_INVOKE_LLM            PluginAccessAction = "invoke_llm"
-	PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING PluginAccessAction = "invoke_text_embedding"
-	PLUGIN_ACCESS_ACTION_INVOKE_RERANK         PluginAccessAction = "invoke_rerank"
-	PLUGIN_ACCESS_ACTION_INVOKE_TTS            PluginAccessAction = "invoke_tts"
-	PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT    PluginAccessAction = "invoke_speech2text"
-	PLUGIN_ACCESS_ACTION_INVOKE_MODERATION     PluginAccessAction = "invoke_moderation"
+	PLUGIN_ACCESS_ACTION_INVOKE_TOOL                   PluginAccessAction = "invoke_tool"
+	PLUGIN_ACCESS_ACTION_VALIDATE_TOOL_CREDENTIALS     PluginAccessAction = "validate_tool_credentials"
+	PLUGIN_ACCESS_ACTION_INVOKE_LLM                    PluginAccessAction = "invoke_llm"
+	PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING         PluginAccessAction = "invoke_text_embedding"
+	PLUGIN_ACCESS_ACTION_INVOKE_RERANK                 PluginAccessAction = "invoke_rerank"
+	PLUGIN_ACCESS_ACTION_INVOKE_TTS                    PluginAccessAction = "invoke_tts"
+	PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT            PluginAccessAction = "invoke_speech2text"
+	PLUGIN_ACCESS_ACTION_INVOKE_MODERATION             PluginAccessAction = "invoke_moderation"
+	PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS PluginAccessAction = "validate_provider_credentials"
+	PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS    PluginAccessAction = "validate_model_credentials"
 )
 
 const (

+ 30 - 0
internal/core/plugin_daemon/model_service.go

@@ -183,3 +183,33 @@ func InvokeModeration(
 		PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
 	)
 }
+
+func ValidateProviderCredentials(
+	session *session_manager.Session,
+	request *requests.RequestValidateProviderCredentials,
+) (
+	*stream.StreamResponse[model_entities.ValidateCredentialsResult], error,
+) {
+	return genericInvokePlugin[requests.RequestValidateProviderCredentials, model_entities.ValidateCredentialsResult](
+		session,
+		request,
+		1,
+		PLUGIN_ACCESS_TYPE_MODEL,
+		PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
+	)
+}
+
+func ValidateModelCredentials(
+	session *session_manager.Session,
+	request *requests.RequestValidateModelCredentials,
+) (
+	*stream.StreamResponse[model_entities.ValidateCredentialsResult], error,
+) {
+	return genericInvokePlugin[requests.RequestValidateModelCredentials, model_entities.ValidateCredentialsResult](
+		session,
+		request,
+		1,
+		PLUGIN_ACCESS_TYPE_MODEL,
+		PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
+	)
+}

+ 18 - 3
internal/core/plugin_daemon/tool_service.go

@@ -2,8 +2,8 @@ package plugin_daemon
 
 import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
-	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
@@ -11,9 +11,9 @@ func InvokeTool(
 	session *session_manager.Session,
 	request *requests.RequestInvokeTool,
 ) (
-	*stream.StreamResponse[plugin_entities.ToolResponseChunk], error,
+	*stream.StreamResponse[tool_entities.ToolResponseChunk], error,
 ) {
-	return genericInvokePlugin[requests.RequestInvokeTool, plugin_entities.ToolResponseChunk](
+	return genericInvokePlugin[requests.RequestInvokeTool, tool_entities.ToolResponseChunk](
 		session,
 		request,
 		128,
@@ -21,3 +21,18 @@ func InvokeTool(
 		PLUGIN_ACCESS_ACTION_INVOKE_TOOL,
 	)
 }
+
+func ValidateToolCredentials(
+	session *session_manager.Session,
+	request *requests.RequestValidateToolCredentials,
+) (
+	*stream.StreamResponse[tool_entities.ValidateCredentialsResult], error,
+) {
+	return genericInvokePlugin[requests.RequestValidateToolCredentials, tool_entities.ValidateCredentialsResult](
+		session,
+		request,
+		1,
+		PLUGIN_ACCESS_TYPE_TOOL,
+		PLUGIN_ACCESS_ACTION_VALIDATE_TOOL_CREDENTIALS,
+	)
+}

+ 33 - 0
internal/server/controller.go

@@ -22,6 +22,17 @@ func InvokeTool(c *gin.Context) {
 	)
 }
 
+func ValidateToolCredentials(c *gin.Context) {
+	type request = plugin_entities.InvokePluginRequest[requests.RequestValidateToolCredentials]
+
+	BindRequest[request](
+		c,
+		func(itr request) {
+			service.ValidateToolCredentials(&itr, c)
+		},
+	)
+}
+
 func InvokeLLM(c *gin.Context) {
 	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM]
 
@@ -87,3 +98,25 @@ func InvokeModeration(c *gin.Context) {
 		},
 	)
 }
+
+func ValidateProviderCredentials(c *gin.Context) {
+	type request = plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials]
+
+	BindRequest[request](
+		c,
+		func(itr request) {
+			service.ValidateProviderCredentials(&itr, c)
+		},
+	)
+}
+
+func ValidateModelCredentials(c *gin.Context) {
+	type request = plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials]
+
+	BindRequest[request](
+		c,
+		func(itr request) {
+			service.ValidateModelCredentials(&itr, c)
+		},
+	)
+}

+ 3 - 0
internal/server/http.go

@@ -12,12 +12,15 @@ func server(config *app.Config) {
 
 	engine.GET("/health/check", HealthCheck)
 	engine.POST("/plugin/tool/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTool)
+	engine.POST("/plugin/tool/validate_credentials", CheckingKey(config.PluginInnerApiKey), ValidateToolCredentials)
 	engine.POST("/plugin/llm/invoke", CheckingKey(config.PluginInnerApiKey), InvokeLLM)
 	engine.POST("/plugin/text_embedding/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTextEmbedding)
 	engine.POST("/plugin/rerank/invoke", CheckingKey(config.PluginInnerApiKey), InvokeRerank)
 	engine.POST("/plugin/tts/invoke", CheckingKey(config.PluginInnerApiKey), InvokeTTS)
 	engine.POST("/plugin/speech2text/invoke", CheckingKey(config.PluginInnerApiKey), InvokeSpeech2Text)
 	engine.POST("/plugin/moderation/invoke", CheckingKey(config.PluginInnerApiKey), InvokeModeration)
+	engine.POST("/plugin/model/validate_provider_credentials", CheckingKey(config.PluginInnerApiKey), ValidateProviderCredentials)
+	engine.POST("/plugin/model/validate_model_credentials", CheckingKey(config.PluginInnerApiKey), ValidateModelCredentials)
 
 	engine.Run(fmt.Sprintf(":%d", config.SERVER_PORT))
 }

+ 32 - 0
internal/service/invoke_model.go

@@ -0,0 +1,32 @@
+package service
+
+import (
+	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
+)
+
+func InvokeTool(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTool], ctx *gin.Context) {
+	// create session
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	defer session.Close()
+
+	baseSSEService(r, func() (*stream.StreamResponse[tool_entities.ToolResponseChunk], error) {
+		return plugin_daemon.InvokeTool(session, &r.Data)
+	}, ctx)
+}
+
+func ValidateToolCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateToolCredentials], ctx *gin.Context) {
+	// create session
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	defer session.Close()
+
+	baseSSEService(r, func() (*stream.StreamResponse[tool_entities.ValidateCredentialsResult], error) {
+		return plugin_daemon.ValidateToolCredentials(session, &r.Data)
+	}, ctx)
+}

+ 20 - 10
internal/service/invoke.go

@@ -11,16 +11,6 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
-func InvokeTool(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTool], ctx *gin.Context) {
-	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
-	defer session.Close()
-
-	baseSSEService(r, func() (*stream.StreamResponse[plugin_entities.ToolResponseChunk], error) {
-		return plugin_daemon.InvokeTool(session, &r.Data)
-	}, ctx)
-}
-
 func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM], ctx *gin.Context) {
 	// create session
 	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
@@ -80,3 +70,23 @@ func InvokeModeration(r *plugin_entities.InvokePluginRequest[requests.RequestInv
 		return plugin_daemon.InvokeModeration(session, &r.Data)
 	}, ctx)
 }
+
+func ValidateProviderCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials], ctx *gin.Context) {
+	// create session
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	defer session.Close()
+
+	baseSSEService(r, func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {
+		return plugin_daemon.ValidateProviderCredentials(session, &r.Data)
+	}, ctx)
+}
+
+func ValidateModelCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials], ctx *gin.Context) {
+	// create session
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	defer session.Close()
+
+	baseSSEService(r, func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {
+		return plugin_daemon.ValidateModelCredentials(session, &r.Data)
+	}, ctx)
+}

+ 5 - 0
internal/types/entities/model_entities/validate.go

@@ -0,0 +1,5 @@
+package model_entities
+
+type ValidateCredentialsResult struct {
+	Result bool `json:"result"`
+}

+ 0 - 5
internal/types/entities/plugin_entities/event.go

@@ -39,11 +39,6 @@ const (
 	SESSION_MESSAGE_TYPE_INVOKE SESSION_MESSAGE_TYPE = "invoke"
 )
 
-type ToolResponseChunk struct {
-	Type    string         `json:"type"`
-	Message map[string]any `json:"message"`
-}
-
 type PluginResponseChunk struct {
 	Type string          `json:"type"`
 	Data json.RawMessage `json:"data"`

+ 12 - 0
internal/types/entities/requests/model.go

@@ -59,3 +59,15 @@ type RequestInvokeModeration struct {
 	ModelType model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=moderation"`
 	Text      string                   `json:"text" validate:"required"`
 }
+
+type RequestValidateProviderCredentials struct {
+	Provider    string         `json:"provider" validate:"required"`
+	Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
+}
+
+type RequestValidateModelCredentials struct {
+	Provider    string                   `json:"provider" validate:"required"`
+	ModelType   model_entities.ModelType `json:"model_type"  validate:"required,model_type"`
+	Model       string                   `json:"model" validate:"required"`
+	Credentials map[string]any           `json:"credentials" validate:"omitempty,dive,is_basic_type"`
+}

+ 5 - 0
internal/types/entities/requests/tool.go

@@ -6,3 +6,8 @@ type RequestInvokeTool struct {
 	ToolParameters map[string]any `json:"tool_parameters" validate:"omitempty,dive,is_basic_type"`
 	Credentials    map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
 }
+
+type RequestValidateToolCredentials struct {
+	Provider    string         `json:"provider" validate:"required"`
+	Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
+}

+ 6 - 0
internal/types/entities/tool_entities/tool.go

@@ -0,0 +1,6 @@
+package tool_entities
+
+type ToolResponseChunk struct {
+	Type    string         `json:"type"`
+	Message map[string]any `json:"message"`
+}

+ 5 - 0
internal/types/entities/tool_entities/validate.go

@@ -0,0 +1,5 @@
+package tool_entities
+
+type ValidateCredentialsResult struct {
+	Result bool `json:"result"`
+}