소스 검색

refactor: validator

Yeuoly 1 년 전
부모
커밋
cbd3189e06

+ 1 - 8
internal/core/dify_invocation/workflow_node_data.go

@@ -3,7 +3,7 @@ package dify_invocation
 type WorkflowNodeData interface {
 	FromMap(map[string]any) error
 
-	*KnowledgeRetrievalNodeData | *QuestionClassifierNodeData | *ParameterExtractorNodeData | *CodeNodeData
+	*KnowledgeRetrievalNodeData | *QuestionClassifierNodeData | *ParameterExtractorNodeData
 }
 
 type NodeType string
@@ -35,10 +35,3 @@ type ParameterExtractorNodeData struct {
 func (r *ParameterExtractorNodeData) FromMap(data map[string]any) error {
 	return nil
 }
-
-type CodeNodeData struct {
-}
-
-func (r *CodeNodeData) FromMap(data map[string]any) error {
-	return nil
-}

+ 0 - 9
internal/core/plugin_daemon/invoke_dify.go

@@ -89,15 +89,6 @@ func invokeDify(
 				return fmt.Errorf("unmarshal parameter extractor node data failed: %s", err.Error())
 			}
 			submitNodeInvocationRequestTask(runtime, session, request_id, &d)
-		case dify_invocation.CODE:
-			d := dify_invocation.InvokeNodeRequest[*dify_invocation.CodeNodeData]{
-				NodeType: dify_invocation.CODE,
-				NodeData: &dify_invocation.CodeNodeData{},
-			}
-			if err := d.FromMap(node_data); err != nil {
-				return fmt.Errorf("unmarshal code node data failed: %s", err.Error())
-			}
-			submitNodeInvocationRequestTask(runtime, session, request_id, &d)
 		default:
 			return fmt.Errorf("unknown node type: %s", node_type)
 		}

+ 0 - 59
internal/types/entities/model_entities/llm.go

@@ -156,11 +156,6 @@ func (p *PromptMessage) UnmarshalJSON(data []byte) error {
 		}
 	}
 
-	// Validate the struct
-	if err := validators.GlobalEntitiesValidator.Struct(p); err != nil {
-		return err
-	}
-
 	// validate tool call id
 	if p.Role == PROMPT_MESSAGE_ROLE_TOOL && p.ToolCallId == "" {
 		return errors.New("tool call id is required")
@@ -175,24 +170,6 @@ type PromptMessageTool struct {
 	Parameters  map[string]any `json:"parameters"`
 }
 
-func (p *PromptMessageTool) UnmarshalJSON(data []byte) error {
-	type Alias PromptMessageTool
-	aux := &struct {
-		*Alias
-	}{
-		Alias: (*Alias)(p),
-	}
-	if err := json.Unmarshal(data, &aux); err != nil {
-		return err
-	}
-
-	if err := validators.GlobalEntitiesValidator.Struct(p); err != nil {
-		return err
-	}
-
-	return nil
-}
-
 type LLMResultChunk struct {
 	Model             LLMModel            `json:"model" validate:"required"`
 	PromptMessages    []PromptMessage     `json:"prompt_messages" validate:"required,dive"`
@@ -200,24 +177,6 @@ type LLMResultChunk struct {
 	Delta             LLMResultChunkDelta `json:"delta" validate:"required"`
 }
 
-func (l *LLMResultChunk) UnmarshalJSON(data []byte) error {
-	type Alias LLMResultChunk
-	aux := &struct {
-		*Alias
-	}{
-		Alias: (*Alias)(l),
-	}
-	if err := json.Unmarshal(data, &aux); err != nil {
-		return err
-	}
-
-	if err := validators.GlobalEntitiesValidator.Struct(l); err != nil {
-		return err
-	}
-
-	return nil
-}
-
 type LLMUsage struct {
 	PromptTokens        *int            `json:"prompt_tokens" validate:"required"`
 	PromptUnitPrice     decimal.Decimal `json:"prompt_unit_price" validate:"required"`
@@ -233,24 +192,6 @@ type LLMUsage struct {
 	Latency             *float64        `json:"latency" validate:"required"`
 }
 
-func (l *LLMUsage) UnmarshalJSON(data []byte) error {
-	type Alias LLMUsage
-	aux := &struct {
-		*Alias
-	}{
-		Alias: (*Alias)(l),
-	}
-	if err := json.Unmarshal(data, &aux); err != nil {
-		return err
-	}
-
-	if err := validators.GlobalEntitiesValidator.Struct(l); err != nil {
-		return err
-	}
-
-	return nil
-}
-
 type LLMResultChunkDelta struct {
 	Index        *int          `json:"index" validate:"required"`
 	Message      PromptMessage `json:"message" validate:"required"`

+ 14 - 29
internal/types/entities/model_entities/llm_test.go

@@ -1,8 +1,9 @@
 package model_entities
 
 import (
-	"encoding/json"
 	"testing"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 )
 
 func TestFullFunctionPromptMessage(t *testing.T) {
@@ -42,9 +43,7 @@ func TestFullFunctionPromptMessage(t *testing.T) {
 		`
 	)
 
-	var prompt_message PromptMessage
-
-	err := json.Unmarshal([]byte(system_message), &prompt_message)
+	prompt_message, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(system_message))
 	if err != nil {
 		t.Error(err)
 	}
@@ -52,7 +51,7 @@ func TestFullFunctionPromptMessage(t *testing.T) {
 		t.Error("role is not system")
 	}
 
-	err = json.Unmarshal([]byte(user_message), &prompt_message)
+	prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(user_message))
 	if err != nil {
 		t.Error(err)
 	}
@@ -60,7 +59,7 @@ func TestFullFunctionPromptMessage(t *testing.T) {
 		t.Error("role is not user")
 	}
 
-	err = json.Unmarshal([]byte(assistant_message), &prompt_message)
+	prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(assistant_message))
 	if err != nil {
 		t.Error(err)
 	}
@@ -68,7 +67,7 @@ func TestFullFunctionPromptMessage(t *testing.T) {
 		t.Error("role is not assistant")
 	}
 
-	err = json.Unmarshal([]byte(image_message), &prompt_message)
+	prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(image_message))
 	if err != nil {
 		t.Error(err)
 	}
@@ -79,7 +78,7 @@ func TestFullFunctionPromptMessage(t *testing.T) {
 		t.Error("type is not image")
 	}
 
-	err = json.Unmarshal([]byte(tool_message), &prompt_message)
+	prompt_message, err = parser.UnmarshalJsonBytes[PromptMessage]([]byte(tool_message))
 	if err != nil {
 		t.Error(err)
 	}
@@ -101,9 +100,7 @@ func TestWrongRole(t *testing.T) {
 		`
 	)
 
-	var prompt_message PromptMessage
-
-	err := json.Unmarshal([]byte(wrong_role), &prompt_message)
+	_, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_role))
 	if err == nil {
 		t.Error("error is nil")
 	}
@@ -119,9 +116,7 @@ func TestWrongContent(t *testing.T) {
 		`
 	)
 
-	var prompt_message PromptMessage
-
-	err := json.Unmarshal([]byte(wrong_content), &prompt_message)
+	_, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content))
 	if err == nil {
 		t.Error("error is nil")
 	}
@@ -142,9 +137,7 @@ func TestWrongContentArray(t *testing.T) {
 		`
 	)
 
-	var prompt_message PromptMessage
-
-	err := json.Unmarshal([]byte(wrong_content_array), &prompt_message)
+	_, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content_array))
 	if err == nil {
 		t.Error("error is nil")
 	}
@@ -164,9 +157,7 @@ func TestWrongContentArray2(t *testing.T) {
 		`
 	)
 
-	var prompt_message PromptMessage
-
-	err := json.Unmarshal([]byte(wrong_content_array2), &prompt_message)
+	_, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content_array2))
 	if err == nil {
 		t.Error("error is nil")
 	}
@@ -191,9 +182,7 @@ func TestWrongContentArray3(t *testing.T) {
 		`
 	)
 
-	var prompt_message PromptMessage
-
-	err := json.Unmarshal([]byte(wrong_content_array3), &prompt_message)
+	_, err := parser.UnmarshalJsonBytes[PromptMessage]([]byte(wrong_content_array3))
 	if err == nil {
 		t.Error("error is nil")
 	}
@@ -241,9 +230,7 @@ func TestFullFunctionLLMResultChunk(t *testing.T) {
 		`
 	)
 
-	var c LLMResultChunk
-
-	err := json.Unmarshal([]byte(llm_result_chunk), &c)
+	_, err := parser.UnmarshalJsonBytes[LLMResultChunk]([]byte(llm_result_chunk))
 	if err != nil {
 		t.Error(err)
 	}
@@ -269,9 +256,7 @@ func TestZeroLLMUsage(t *testing.T) {
 		`
 	)
 
-	var u LLMUsage
-
-	err := json.Unmarshal([]byte(llm_usage), &u)
+	_, err := parser.UnmarshalJsonBytes[LLMUsage]([]byte(llm_usage))
 	if err != nil {
 		t.Error(err)
 	}

+ 1 - 0
internal/types/entities/model_entities/moderation.go

@@ -0,0 +1 @@
+package model_entities

+ 12 - 0
internal/types/entities/model_entities/rerank.go

@@ -0,0 +1,12 @@
+package model_entities
+
+type RerankDocument struct {
+	Index *int     `json:"index" validate:"required"`
+	Text  *string  `json:"text" validate:"required"`
+	Score *float64 `json:"score" validate:"required"`
+}
+
+type RerankResult struct {
+	Model string           `json:"model" validate:"required"`
+	Docs  []RerankDocument `json:"docs" validate:"required,dive"`
+}

+ 1 - 0
internal/types/entities/model_entities/speech2text.go

@@ -0,0 +1 @@
+package model_entities

+ 1 - 0
internal/types/entities/model_entities/text_embedding.go

@@ -0,0 +1 @@
+package model_entities

+ 1 - 0
internal/types/entities/model_entities/tts.go

@@ -0,0 +1 @@
+package model_entities

+ 0 - 17
internal/types/entities/plugin_entities/model_configuration.go

@@ -1,8 +1,6 @@
 package plugin_entities
 
 import (
-	"encoding/json"
-
 	"github.com/go-playground/locales/en"
 	ut "github.com/go-playground/universal-translator"
 	"github.com/go-playground/validator/v10"
@@ -278,18 +276,3 @@ func init() {
 
 	validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType)
 }
-
-func UnmarshalModelProviderConfiguration(data []byte) (*ModelProviderConfiguration, error) {
-	var modelProviderConfiguration ModelProviderConfiguration
-	err := json.Unmarshal(data, &modelProviderConfiguration)
-	if err != nil {
-		return nil, err
-	}
-
-	err = validators.GlobalEntitiesValidator.Struct(modelProviderConfiguration)
-	if err != nil {
-		return nil, err
-	}
-
-	return &modelProviderConfiguration, nil
-}

+ 2 - 1
internal/types/entities/plugin_entities/model_configuration_test.go

@@ -4,6 +4,7 @@ import (
 	"encoding/json"
 	"testing"
 
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"gopkg.in/yaml.v3"
 )
 
@@ -156,7 +157,7 @@ func TestFullFunctionModelProvider_Validate(t *testing.T) {
 		t.Error(err)
 	}
 
-	_, err = UnmarshalModelProviderConfiguration(json_data)
+	_, err = parser.UnmarshalJsonBytes[ModelProviderConfiguration](json_data)
 	if err != nil {
 		t.Errorf("UnmarshalModelProviderConfiguration() error = %v", err)
 	}

+ 0 - 19
internal/types/entities/plugin_entities/tool_configuration.go

@@ -1,7 +1,6 @@
 package plugin_entities
 
 import (
-	"encoding/json"
 	"fmt"
 
 	"github.com/go-playground/locales/en"
@@ -253,24 +252,6 @@ func init() {
 	validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isGenericType)
 }
 
-func (t *ToolProviderConfiguration) UnmarshalJSON(data []byte) error {
-	type Alias ToolProviderConfiguration
-	aux := &struct {
-		*Alias
-	}{
-		Alias: (*Alias)(t),
-	}
-	if err := json.Unmarshal(data, &aux); err != nil {
-		return err
-	}
-
-	if err := validators.GlobalEntitiesValidator.Struct(t); err != nil {
-		return err
-	}
-
-	return nil
-}
-
 func UnmarshalToolProviderConfiguration(data []byte) (*ToolProviderConfiguration, error) {
 	obj, err := parser.UnmarshalJsonBytes[ToolProviderConfiguration](data)
 	if err != nil {

+ 0 - 21
internal/types/entities/requests/model.go

@@ -1,10 +1,7 @@
 package requests
 
 import (
-	"encoding/json"
-
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
-	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
 )
 
 type RequestInvokeLLM struct {
@@ -18,21 +15,3 @@ type RequestInvokeLLM struct {
 	Stream          bool                               `json:"stream"`
 	Credentials     map[string]any                     `json:"credentials" validate:"omitempty,dive,is_basic_type"`
 }
-
-func (r *RequestInvokeLLM) UnmarshalJSON(data []byte) error {
-	type Alias RequestInvokeLLM
-	aux := &struct {
-		*Alias
-	}{
-		Alias: (*Alias)(r),
-	}
-	if err := json.Unmarshal(data, &aux); err != nil {
-		return err
-	}
-
-	if err := validators.GlobalEntitiesValidator.Struct(r); err != nil {
-		return err
-	}
-
-	return nil
-}

+ 0 - 24
internal/types/entities/requests/tool.go

@@ -1,32 +1,8 @@
 package requests
 
-import (
-	"encoding/json"
-
-	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
-)
-
 type RequestInvokeTool struct {
 	Provider       string         `json:"provider" validate:"required"`
 	Tool           string         `json:"tool" validate:"required"`
 	ToolParameters map[string]any `json:"tool_parameters" validate:"omitempty,dive,is_basic_type"`
 	Credentials    map[string]any `json:"credentials" validate:"omitempty,dive,is_basic_type"`
 }
-
-func (r *RequestInvokeTool) UnmarshalJSON(data []byte) error {
-	type Alias RequestInvokeTool
-	aux := &struct {
-		*Alias
-	}{
-		Alias: (*Alias)(r),
-	}
-	if err := json.Unmarshal(data, &aux); err != nil {
-		return err
-	}
-
-	if err := validators.GlobalEntitiesValidator.Struct(r); err != nil {
-		return err
-	}
-
-	return nil
-}

+ 13 - 1
internal/utils/parser/json.go

@@ -1,6 +1,10 @@
 package parser
 
-import "encoding/json"
+import (
+	"encoding/json"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
+)
 
 func UnmarshalJson[T any](text string) (T, error) {
 	return UnmarshalJsonBytes[T]([]byte(text))
@@ -9,6 +13,14 @@ func UnmarshalJson[T any](text string) (T, error) {
 func UnmarshalJsonBytes[T any](data []byte) (T, error) {
 	var result T
 	err := json.Unmarshal(data, &result)
+	if err != nil {
+		return result, err
+	}
+
+	if err := validators.GlobalEntitiesValidator.Struct(&result); err != nil {
+		return result, err
+	}
+
 	return result, err
 }
 

+ 13 - 3
internal/utils/parser/yaml.go

@@ -1,16 +1,26 @@
 package parser
 
 import (
+	"github.com/go-playground/validator/v10"
 	"gopkg.in/yaml.v3"
 )
 
-func UnmarshalYaml[T any](text string) (T, error) {
-	return UnmarshalYamlBytes[T]([]byte(text))
+func UnmarshalYaml[T any](text string, validator ...validator.Validate) (T, error) {
+	return UnmarshalYamlBytes[T]([]byte(text), validator...)
 }
 
-func UnmarshalYamlBytes[T any](data []byte) (T, error) {
+func UnmarshalYamlBytes[T any](data []byte, validator ...validator.Validate) (T, error) {
 	var result T
 	err := yaml.Unmarshal(data, &result)
+	if err != nil {
+		return result, err
+	}
+
+	if len(validator) > 0 {
+		if err := validator[0].Struct(result); err != nil {
+			return result, err
+		}
+	}
 	return result, err
 }