浏览代码

refactor: types

Yeuoly 1 年之前
父节点
当前提交
82df72e20b

+ 24 - 6
internal/core/dify_invocation/http_request.go

@@ -1,6 +1,8 @@
 package dify_invocation
 
 import (
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/http_requests"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
@@ -29,14 +31,30 @@ func StreamResponse[T any](method string, path string, options ...http_requests.
 	return http_requests.RequestAndParseStream[T](client, difyPath(path), method, options...)
 }
 
-func InvokeModel(payload *InvokeModelRequest) (*stream.StreamResponse[InvokeModelResponseChunk], error) {
-	return StreamResponse[InvokeModelResponseChunk]("POST", "invoke/model", http_requests.HttpPayloadJson(payload))
+func InvokeLLM(payload *InvokeLLMRequest) (*stream.StreamResponse[model_entities.LLMResultChunk], error) {
+	return StreamResponse[model_entities.LLMResultChunk]("POST", "invoke/llm", http_requests.HttpPayloadJson(payload))
 }
 
-func InvokeTool(payload *InvokeToolRequest) (*stream.StreamResponse[InvokeToolResponseChunk], error) {
-	return StreamResponse[InvokeToolResponseChunk]("POST", "invoke/tool", http_requests.HttpPayloadJson(payload))
+func InvokeTextEmbedding(payload *InvokeTextEmbeddingRequest) (*model_entities.TextEmbeddingResult, error) {
+	return Request[model_entities.TextEmbeddingResult]("POST", "invoke/text_embedding", http_requests.HttpPayloadJson(payload))
 }
 
-func InvokeNode[T WorkflowNodeData](payload *InvokeNodeRequest[T]) (*InvokeNodeResponse, error) {
-	return Request[InvokeNodeResponse]("POST", "invoke/node", http_requests.HttpPayloadJson(payload))
+func InvokeRerank(payload *InvokeRerankRequest) (*model_entities.RerankResult, error) {
+	return Request[model_entities.RerankResult]("POST", "invoke/rerank", http_requests.HttpPayloadJson(payload))
+}
+
+func InvokeTTS(payload *InvokeTTSRequest) (*stream.StreamResponse[model_entities.TTSResult], error) {
+	return StreamResponse[model_entities.TTSResult]("POST", "invoke/tts", http_requests.HttpPayloadJson(payload))
+}
+
+func InvokeSpeech2Text(payload *InvokeSpeech2TextRequest) (*model_entities.Speech2TextResult, error) {
+	return Request[model_entities.Speech2TextResult]("POST", "invoke/speech2text", http_requests.HttpPayloadJson(payload))
+}
+
+func InvokeModeration(payload *InvokeModerationRequest) (*model_entities.ModerationResult, error) {
+	return Request[model_entities.ModerationResult]("POST", "invoke/moderation", http_requests.HttpPayloadJson(payload))
+}
+
+func InvokeTool(payload *InvokeToolRequest) (*stream.StreamResponse[tool_entities.ToolResponseChunk], error) {
+	return StreamResponse[tool_entities.ToolResponseChunk]("POST", "invoke/tool", http_requests.HttpPayloadJson(payload))
 }

+ 51 - 45
internal/core/dify_invocation/types.go

@@ -1,9 +1,7 @@
 package dify_invocation
 
 import (
-	"encoding/json"
-
-	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
 )
 
 type BaseInvokeDifyRequest struct {
@@ -15,66 +13,74 @@ type BaseInvokeDifyRequest struct {
 type InvokeType string
 
 const (
-	INVOKE_TYPE_MODEL InvokeType = "model"
-	INVOKE_TYPE_TOOL  InvokeType = "tool"
-	INVOKE_TYPE_NODE  InvokeType = "node"
+	INVOKE_TYPE_LLM            InvokeType = "LLM"
+	INVOKE_TYPE_TEXT_EMBEDDING InvokeType = "text_embedding"
+	INVOKE_TYPE_RERANK         InvokeType = "rerank"
+	INVOKE_TYPE_TTS            InvokeType = "tts"
+	INVOKE_TYPE_SPEECH2TEXT    InvokeType = "speech2text"
+	INVOKE_TYPE_MODERATION     InvokeType = "moderation"
+	INVOKE_TYPE_TOOL           InvokeType = "tool"
+	INVOKE_TYPE_NODE           InvokeType = "node"
 )
 
-type InvokeModelRequest struct {
+type InvokeLLMRequest struct {
 	BaseInvokeDifyRequest
-	Provider   string                   `json:"provider"`
-	Model      string                   `json:"model"`
-	ModelType  model_entities.ModelType `json:"model_type"`
-	Parameters map[string]any           `json:"parameters"`
-}
-
-func (r InvokeModelRequest) MarshalJSON() ([]byte, error) {
-	flattened := make(map[string]any)
-	flattened["tenant_id"] = r.TenantId
-	flattened["user_id"] = r.UserId
-	flattened["provider"] = r.Provider
-	flattened["model"] = r.Model
-	flattened["parameters"] = r.Parameters
-	return json.Marshal(flattened)
+	Data struct {
+		requests.BaseRequestInvokeModel
+		requests.InvokeLLMSchema
+	} `json:"data" validate:"required"`
 }
 
-type InvokeModelResponseChunk struct {
+type InvokeTextEmbeddingRequest struct {
+	BaseInvokeDifyRequest
+	Data struct {
+		requests.BaseRequestInvokeModel
+		requests.InvokeTextEmbeddingSchema
+	} `json:"data" validate:"required"`
 }
 
-type InvokeToolRequest struct {
+type InvokeRerankRequest struct {
 	BaseInvokeDifyRequest
-	Provider   string         `json:"provider"`
-	Tool       string         `json:"tool"`
-	Parameters map[string]any `json:"parameters"`
+	Data struct {
+		requests.BaseRequestInvokeModel
+		requests.InvokeRerankSchema
+	} `json:"data" validate:"required"`
 }
 
-func (r InvokeToolRequest) MarshalJSON() ([]byte, error) {
-	flattened := make(map[string]any)
-	flattened["tenant_id"] = r.TenantId
-	flattened["user_id"] = r.UserId
-	flattened["provider"] = r.Provider
-	flattened["tool"] = r.Tool
-	flattened["parameters"] = r.Parameters
-	return json.Marshal(flattened)
+type InvokeTTSRequest struct {
+	BaseInvokeDifyRequest
+	Data struct {
+		requests.BaseRequestInvokeModel
+		requests.InvokeTTSSchema
+	} `json:"data" validate:"required"`
 }
 
-type InvokeToolResponseChunk struct {
+type InvokeSpeech2TextRequest struct {
+	BaseInvokeDifyRequest
+	Data struct {
+		requests.BaseRequestInvokeModel
+		requests.InvokeSpeech2TextSchema
+	} `json:"data" validate:"required"`
 }
 
-type InvokeNodeRequest[T WorkflowNodeData] struct {
+type InvokeModerationRequest struct {
 	BaseInvokeDifyRequest
-	NodeType NodeType `json:"node_type"`
-	NodeData T        `json:"node_data"`
+	Data struct {
+		requests.BaseRequestInvokeModel
+		requests.InvokeModerationSchema
+	} `json:"data" validate:"required"`
 }
 
-func (r InvokeNodeRequest[T]) MarshalJSON() ([]byte, error) {
-	flattened := make(map[string]any)
-	flattened["tenant_id"] = r.TenantId
-	flattened["user_id"] = r.UserId
-	flattened["node_type"] = r.NodeType
-	flattened["node_data"] = r.NodeData
-	return json.Marshal(flattened)
+type InvokeToolRequest struct {
+	BaseInvokeDifyRequest
+	Data struct {
+		requests.RequestInvokeTool
+	} `json:"data" validate:"required"`
 }
 
 type InvokeNodeResponse struct {
+	ProcessData      map[string]any `json:"process_data"`
+	Output           map[string]any `json:"output"`
+	Input            map[string]any `json:"input"`
+	EdgeSourceHandle []string       `json:"edge_source_handle"`
 }

+ 2 - 112
internal/core/plugin_daemon/invoke_dify.go

@@ -7,9 +7,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
-	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
-	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
 
 func invokeDify(
@@ -69,60 +67,14 @@ func prepareDifyInvocationArguments(session *session_manager.Session, request ma
 func dispatchDifyInvocationTask(handle *backwards_invocation.BackwardsInvocation) {
 	switch handle.Type() {
 	case dify_invocation.INVOKE_TYPE_TOOL:
-		r, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData())
+		_, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData())
 		if err != nil {
 			handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error()))
 			return
 		}
 
-		submitToolTask(runtime, session, backwards_request_id, &r)
-	case dify_invocation.INVOKE_TYPE_MODEL:
-		r, err := parser.MapToStruct[dify_invocation.InvokeModelRequest](handle.RequestData())
-		if err != nil {
-			handle.WriteError(fmt.Errorf("unmarshal invoke model request failed: %s", err.Error()))
-			return
-		}
-
-		submitModelTask(runtime, session, backwards_request_id, &r)
-	case dify_invocation.INVOKE_TYPE_NODE:
-		node_type, ok := detailed_request["node_type"].(dify_invocation.NodeType)
-		if !ok {
-			return fmt.Errorf("invoke request missing node_type: %s", data)
-		}
-		node_data, ok := detailed_request["data"].(map[string]any)
-		if !ok {
-			return fmt.Errorf("invoke request missing data: %s", data)
-		}
-		switch node_type {
-		case dify_invocation.QUESTION_CLASSIFIER:
-			d := dify_invocation.InvokeNodeRequest[dify_invocation.QuestionClassifierNodeData]{
-				NodeType: dify_invocation.QUESTION_CLASSIFIER,
-			}
-			if err := d.FromMap(node_data); err != nil {
-				return fmt.Errorf("unmarshal question classifier node data failed: %s", err.Error())
-			}
-			submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d)
-		case dify_invocation.KNOWLEDGE_RETRIEVAL:
-			d := dify_invocation.InvokeNodeRequest[dify_invocation.KnowledgeRetrievalNodeData]{
-				NodeType: dify_invocation.KNOWLEDGE_RETRIEVAL,
-			}
-			if err := d.FromMap(node_data); err != nil {
-				return fmt.Errorf("unmarshal knowledge retrieval node data failed: %s", err.Error())
-			}
-			submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d)
-		case dify_invocation.PARAMETER_EXTRACTOR:
-			d := dify_invocation.InvokeNodeRequest[dify_invocation.ParameterExtractorNodeData]{
-				NodeType: dify_invocation.PARAMETER_EXTRACTOR,
-			}
-			if err := d.FromMap(node_data); err != nil {
-				return fmt.Errorf("unmarshal parameter extractor node data failed: %s", err.Error())
-			}
-			submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d)
-		default:
-			return fmt.Errorf("unknown node type: %s", node_type)
-		}
 	default:
-		return fmt.Errorf("unknown invoke type: %s", typ)
+		handle.WriteError(fmt.Errorf("unsupported invoke type: %s", handle.Type()))
 	}
 }
 
@@ -130,65 +82,3 @@ func setTaskContext(session *session_manager.Session, r *dify_invocation.BaseInv
 	r.TenantId = session.TenantID()
 	r.UserId = session.UserID()
 }
-
-func submitModelTask(
-	runtime entities.PluginRuntimeInterface,
-	session *session_manager.Session,
-	request_id string,
-	t *dify_invocation.InvokeModelRequest,
-) {
-	setTaskContext(session, &t.BaseInvokeDifyRequest)
-	routine.Submit(func() {
-		response, err := dify_invocation.InvokeModel(t)
-		if err != nil {
-			log.Error("invoke model failed: %s", err.Error())
-			return
-		}
-		defer response.Close()
-
-		for response.Next() {
-			chunk, _ := response.Read()
-			fmt.Println(chunk)
-		}
-	})
-}
-
-func submitToolTask(
-	runtime entities.PluginRuntimeInterface,
-	session *session_manager.Session,
-	request_id string,
-	t *dify_invocation.InvokeToolRequest,
-) {
-	setTaskContext(session, &t.BaseInvokeDifyRequest)
-	routine.Submit(func() {
-		response, err := dify_invocation.InvokeTool(t)
-		if err != nil {
-			log.Error("invoke tool failed: %s", err.Error())
-			return
-		}
-		defer response.Close()
-
-		for response.Next() {
-			chunk, _ := response.Read()
-			fmt.Println(chunk)
-		}
-	})
-}
-
-func submitNodeInvocationRequestTask[W dify_invocation.WorkflowNodeData](
-	runtime entities.PluginRuntimeInterface,
-	session *session_manager.Session,
-	request_id string,
-	t *dify_invocation.InvokeNodeRequest[W],
-) {
-	setTaskContext(session, &t.BaseInvokeDifyRequest)
-	routine.Submit(func() {
-		response, err := dify_invocation.InvokeNode(t)
-		if err != nil {
-			log.Error("invoke node failed: %s", err.Error())
-			return
-		}
-
-		fmt.Println(response)
-	})
-}

+ 49 - 17
internal/types/entities/requests/model.go

@@ -4,16 +4,16 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
 )
 
-type BaseRequestInvokeModel struct {
-	Provider    string         `json:"provider" validate:"required"`
-	Model       string         `json:"model" validate:"required"`
+type Credentials struct {
 	Credentials map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
 }
 
-type RequestInvokeLLM struct {
-	BaseRequestInvokeModel
+type BaseRequestInvokeModel struct {
+	Provider string `json:"provider" validate:"required"`
+	Model    string `json:"model" validate:"required"`
+}
 
-	ModelType       model_entities.ModelType           `json:"model_type"  validate:"required,model_type,eq=llm"`
+type InvokeLLMSchema struct {
 	ModelParameters map[string]any                     `json:"model_parameters"  validate:"omitempty,dive,is_basic_type"`
 	PromptMessages  []model_entities.PromptMessage     `json:"prompt_messages"  validate:"omitempty,dive"`
 	Tools           []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"`
@@ -21,43 +21,75 @@ type RequestInvokeLLM struct {
 	Stream          bool                               `json:"stream" `
 }
 
+type RequestInvokeLLM struct {
+	BaseRequestInvokeModel
+	Credentials
+	InvokeLLMSchema
+
+	ModelType model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=llm"`
+}
+
+type InvokeTextEmbeddingSchema struct {
+	Texts []string `json:"texts" validate:"required,dive"`
+}
+
 type RequestInvokeTextEmbedding struct {
 	BaseRequestInvokeModel
+	Credentials
+	InvokeTextEmbeddingSchema
 
 	ModelType model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=text-embedding"`
-	Texts     []string                 `json:"texts" validate:"required,dive"`
+}
+
+type InvokeRerankSchema struct {
+	Query          string   `json:"query" validate:"required"`
+	Docs           []string `json:"docs" validate:"required,dive"`
+	ScoreThreshold float64  `json:"score_threshold" `
+	TopN           int      `json:"top_n" `
 }
 
 type RequestInvokeRerank struct {
 	BaseRequestInvokeModel
+	Credentials
+	InvokeRerankSchema
+
+	ModelType model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=rerank"`
+}
 
-	ModelType      model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=rerank"`
-	Query          string                   `json:"query" validate:"required"`
-	Docs           []string                 `json:"docs" validate:"required,dive"`
-	ScoreThreshold float64                  `json:"score_threshold" `
-	TopN           int                      `json:"top_n" `
+type InvokeTTSSchema struct {
+	ContentText string `json:"content_text"  validate:"required"`
+	Voice       string `json:"voice" validate:"required"`
 }
 
 type RequestInvokeTTS struct {
 	BaseRequestInvokeModel
+	Credentials
+
+	ModelType model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=tts"`
+}
 
-	ModelType   model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=tts"`
-	ContentText string                   `json:"content_text"  validate:"required"`
-	Voice       string                   `json:"voice" validate:"required"`
+type InvokeSpeech2TextSchema struct {
+	File string `json:"file" validate:"required"` // hexing encoded voice file
 }
 
 type RequestInvokeSpeech2Text struct {
 	BaseRequestInvokeModel
+	Credentials
+	InvokeSpeech2TextSchema
 
 	ModelType model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=speech2text"`
-	File      string                   `json:"file" validate:"required"` // hexing encoded voice file
+}
+
+type InvokeModerationSchema struct {
+	Text string `json:"text" validate:"required"`
 }
 
 type RequestInvokeModeration struct {
 	BaseRequestInvokeModel
+	Credentials
+	InvokeModerationSchema
 
 	ModelType model_entities.ModelType `json:"model_type"  validate:"required,model_type,eq=moderation"`
-	Text      string                   `json:"text" validate:"required"`
 }
 
 type RequestValidateProviderCredentials struct {

+ 6 - 2
internal/core/dify_invocation/workflow_node_data.go

@@ -1,4 +1,4 @@
-package dify_invocation
+package requests
 
 type WorkflowNodeData interface {
 	KnowledgeRetrievalNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
@@ -10,7 +10,6 @@ const (
 	KNOWLEDGE_RETRIEVAL NodeType = "knowledge_retrieval"
 	QUESTION_CLASSIFIER NodeType = "question_classifier"
 	PARAMETER_EXTRACTOR NodeType = "parameter_extractor"
-	CODE                NodeType = "code"
 )
 
 type KnowledgeRetrievalNodeData struct {
@@ -21,3 +20,8 @@ type QuestionClassifierNodeData struct {
 
 type ParameterExtractorNodeData struct {
 }
+
+type InvokeNodeRequest[T WorkflowNodeData] struct {
+	NodeType NodeType `json:"node_type"`
+	NodeData T        `json:"node_data"`
+}

+ 6 - 2
internal/types/entities/requests/tool.go

@@ -1,10 +1,14 @@
 package requests
 
-type RequestInvokeTool struct {
+type InvokeToolSchema struct {
 	Provider       string         `json:"provider" validate:"required"`
 	Tool           string         `json:"tool" validate:"required"`
 	ToolParameters map[string]any `json:"tool_parameters" validate:"omitempty,dive,is_basic_type"`
-	Credentials    map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
+}
+
+type RequestInvokeTool struct {
+	InvokeToolSchema
+	Credentials
 }
 
 type RequestValidateToolCredentials struct {