Переглянути джерело

feat: support validate json schema

Yeuoly 10 місяців тому
батько
коміт
fb27adfd48

+ 4 - 4
internal/core/plugin_daemon/backwards_invocation/task.go

@@ -258,7 +258,7 @@ func executeDifyInvocationToolTask(
 		return
 	}
 
-	response.Wrap(func(t tool_entities.ToolResponseChunk) {
+	response.Async(func(t tool_entities.ToolResponseChunk) {
 		handle.WriteResponse("stream", t)
 	})
 }
@@ -273,7 +273,7 @@ func executeDifyInvocationLLMTask(
 		return
 	}
 
-	response.Wrap(func(t model_entities.LLMResultChunk) {
+	response.Async(func(t model_entities.LLMResultChunk) {
 		handle.WriteResponse("stream", t)
 	})
 }
@@ -314,7 +314,7 @@ func executeDifyInvocationTTSTask(
 		return
 	}
 
-	response.Wrap(func(t model_entities.TTSResult) {
+	response.Async(func(t model_entities.TTSResult) {
 		handle.WriteResponse("struct", t)
 	})
 }
@@ -363,7 +363,7 @@ func executeDifyInvocationAppTask(
 
 	request.User = user_id
 
-	response.Wrap(func(t map[string]any) {
+	response.Async(func(t map[string]any) {
 		handle.WriteResponse("stream", t)
 	})
 }

+ 94 - 1
internal/core/plugin_daemon/tool_service.go

@@ -1,10 +1,15 @@
 package plugin_daemon
 
 import (
+	"errors"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
+	"github.com/xeipuuv/gojsonschema"
 )
 
 func InvokeTool(
@@ -13,11 +18,99 @@ func InvokeTool(
 ) (
 	*stream.Stream[tool_entities.ToolResponseChunk], error,
 ) {
-	return genericInvokePlugin[requests.RequestInvokeTool, tool_entities.ToolResponseChunk](
+	runtime := plugin_manager.Manager().Get(session.PluginUniqueIdentifier)
+	if runtime == nil {
+		return nil, errors.New("plugin not found")
+	}
+
+	response, err := genericInvokePlugin[
+		requests.RequestInvokeTool, tool_entities.ToolResponseChunk,
+	](
 		session,
 		request,
 		128,
 	)
+
+	if err != nil {
+		return nil, err
+	}
+
+	tool_declaration := runtime.Configuration().Tool
+	if tool_declaration == nil {
+		return nil, errors.New("tool declaration not found")
+	}
+
+	var tool_output_schema plugin_entities.ToolOutputSchema
+	for _, v := range tool_declaration.Tools {
+		if v.Identity.Name == request.Tool {
+			tool_output_schema = v.OutputSchema
+		}
+	}
+
+	// check if the tool_output_schema is valid
+	variables := make(map[string]any)
+
+	response.Filter(func(trc tool_entities.ToolResponseChunk) error {
+		if trc.Type == tool_entities.ToolResponseChunkTypeVariable {
+			variable_name, ok := trc.Message["variable_name"].(string)
+			if !ok {
+				return errors.New("variable name is not a string")
+			}
+			stream, ok := trc.Message["stream"].(bool)
+			if !ok {
+				return errors.New("stream is not a boolean")
+			}
+
+			if stream {
+				// ensure variable_value is a string
+				variable_value, ok := trc.Message["variable_value"].(string)
+				if !ok {
+					return errors.New("variable value is not a string")
+				}
+
+				// create it if not exists
+				if _, ok := variables[variable_name]; !ok {
+					variables[variable_name] = ""
+				}
+
+				original_value, ok := variables[variable_name].(string)
+				if !ok {
+					return errors.New("variable value is not a string")
+				}
+
+				// add the variable value to the variable
+				variables[variable_name] = original_value + variable_value
+			} else {
+				variables[variable_name] = trc.Message["variable_value"]
+			}
+		}
+
+		return nil
+	})
+
+	response.BeforeClose(func() {
+		// validate the variables
+		schema, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(tool_output_schema))
+		if err != nil {
+			response.WriteError(err)
+			return
+		}
+
+		// validate the variables
+		result, err := schema.Validate(gojsonschema.NewGoLoader(variables))
+		if err != nil {
+			response.WriteError(err)
+			return
+		}
+
+		if !result.Valid() {
+			response.WriteError(errors.New("tool output schema is not valid"))
+			return
+		}
+	})
+
+	return response, nil
+
 }
 
 func ValidateToolCredentials(

+ 1 - 1
internal/core/plugin_manager/remote_manager/run.go

@@ -54,7 +54,7 @@ func (r *RemotePluginRuntime) StartPlugin() error {
 		}
 	})
 
-	r.response.Wrap(func(data []byte) {
+	r.response.Async(func(data []byte) {
 		// handle event
 		event, err := parser.UnmarshalJsonBytes[plugin_entities.PluginUniversalEvent](data)
 		if err != nil {

+ 1 - 1
internal/core/plugin_manager/remote_manager/server.go

@@ -40,7 +40,7 @@ func (r *RemotePluginServer) Next() bool {
 
 // Wrap wraps the wrap method of stream response
 func (r *RemotePluginServer) Wrap(f func(*RemotePluginRuntime)) {
-	r.server.response.Wrap(f)
+	r.server.response.Async(f)
 }
 
 // Stop stops the server

+ 174 - 13
internal/service/invoke_model.go

@@ -4,22 +4,22 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
-	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
-func InvokeTool(
-	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTool],
+func InvokeLLM(
+	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM],
 	ctx *gin.Context,
 	max_timeout_seconds int,
 ) {
 	// create session
 	session, err := createSession(
 		r,
-		access_types.PLUGIN_ACCESS_TYPE_TOOL,
-		access_types.PLUGIN_ACCESS_ACTION_INVOKE_TOOL,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
 		ctx.GetString("cluster_id"),
 	)
 	if err != nil {
@@ -29,24 +29,185 @@ func InvokeTool(
 	defer session.Close()
 
 	baseSSEService(
-		func() (*stream.Stream[tool_entities.ToolResponseChunk], error) {
-			return plugin_daemon.InvokeTool(session, &r.Data)
+		func() (*stream.Stream[model_entities.LLMResultChunk], error) {
+			return plugin_daemon.InvokeLLM(session, &r.Data)
 		},
 		ctx,
 		max_timeout_seconds,
 	)
 }
 
-func ValidateToolCredentials(
-	r *plugin_entities.InvokePluginRequest[requests.RequestValidateToolCredentials],
+func InvokeTextEmbedding(
+	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding],
 	ctx *gin.Context,
 	max_timeout_seconds int,
 ) {
 	// create session
 	session, err := createSession(
 		r,
-		access_types.PLUGIN_ACCESS_TYPE_TOOL,
-		access_types.PLUGIN_ACCESS_ACTION_VALIDATE_TOOL_CREDENTIALS,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
+		ctx.GetString("cluster_id"))
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
+	defer session.Close()
+
+	baseSSEService(
+		func() (*stream.Stream[model_entities.TextEmbeddingResult], error) {
+			return plugin_daemon.InvokeTextEmbedding(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}
+
+func InvokeRerank(
+	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	// create session
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
+		ctx.GetString("cluster_id"),
+	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
+	defer session.Close()
+
+	baseSSEService(
+		func() (*stream.Stream[model_entities.RerankResult], error) {
+			return plugin_daemon.InvokeRerank(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}
+
+func InvokeTTS(
+	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	// create session
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
+		ctx.GetString("cluster_id"),
+	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
+	defer session.Close()
+
+	baseSSEService(
+		func() (*stream.Stream[model_entities.TTSResult], error) {
+			return plugin_daemon.InvokeTTS(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}
+
+func InvokeSpeech2Text(
+	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	// create session
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
+		ctx.GetString("cluster_id"),
+	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
+	defer session.Close()
+
+	baseSSEService(
+		func() (*stream.Stream[model_entities.Speech2TextResult], error) {
+			return plugin_daemon.InvokeSpeech2Text(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}
+
+func InvokeModeration(
+	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	// create session
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
+		ctx.GetString("cluster_id"),
+	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
+	defer session.Close()
+
+	baseSSEService(
+		func() (*stream.Stream[model_entities.ModerationResult], error) {
+			return plugin_daemon.InvokeModeration(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}
+
+func ValidateProviderCredentials(
+	r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	// create session
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
+		ctx.GetString("cluster_id"),
+	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
+	defer session.Close()
+
+	baseSSEService(
+		func() (*stream.Stream[model_entities.ValidateCredentialsResult], error) {
+			return plugin_daemon.ValidateProviderCredentials(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}
+
+func ValidateModelCredentials(
+	r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	// create session
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_MODEL,
+		access_types.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
 		ctx.GetString("cluster_id"),
 	)
 	if err != nil {
@@ -56,8 +217,8 @@ func ValidateToolCredentials(
 	defer session.Close()
 
 	baseSSEService(
-		func() (*stream.Stream[tool_entities.ValidateCredentialsResult], error) {
-			return plugin_daemon.ValidateToolCredentials(session, &r.Data)
+		func() (*stream.Stream[model_entities.ValidateCredentialsResult], error) {
+			return plugin_daemon.ValidateModelCredentials(session, &r.Data)
 		},
 		ctx,
 		max_timeout_seconds,

+ 13 - 174
internal/service/invoke_tool.go

@@ -6,9 +6,9 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
-	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
@@ -34,16 +34,16 @@ func createSession[T any](
 	return session, nil
 }
 
-func InvokeLLM(
-	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM],
+func InvokeTool(
+	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTool],
 	ctx *gin.Context,
 	max_timeout_seconds int,
 ) {
 	// create session
 	session, err := createSession(
 		r,
-		access_types.PLUGIN_ACCESS_TYPE_MODEL,
-		access_types.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
+		access_types.PLUGIN_ACCESS_TYPE_TOOL,
+		access_types.PLUGIN_ACCESS_ACTION_INVOKE_TOOL,
 		ctx.GetString("cluster_id"),
 	)
 	if err != nil {
@@ -53,185 +53,24 @@ func InvokeLLM(
 	defer session.Close()
 
 	baseSSEService(
-		func() (*stream.Stream[model_entities.LLMResultChunk], error) {
-			return plugin_daemon.InvokeLLM(session, &r.Data)
+		func() (*stream.Stream[tool_entities.ToolResponseChunk], error) {
+			return plugin_daemon.InvokeTool(session, &r.Data)
 		},
 		ctx,
 		max_timeout_seconds,
 	)
 }
 
-func InvokeTextEmbedding(
-	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding],
+func ValidateToolCredentials(
+	r *plugin_entities.InvokePluginRequest[requests.RequestValidateToolCredentials],
 	ctx *gin.Context,
 	max_timeout_seconds int,
 ) {
 	// create session
 	session, err := createSession(
 		r,
-		access_types.PLUGIN_ACCESS_TYPE_MODEL,
-		access_types.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
-		ctx.GetString("cluster_id"))
-	if err != nil {
-		ctx.JSON(500, gin.H{"error": err.Error()})
-		return
-	}
-	defer session.Close()
-
-	baseSSEService(
-		func() (*stream.Stream[model_entities.TextEmbeddingResult], error) {
-			return plugin_daemon.InvokeTextEmbedding(session, &r.Data)
-		},
-		ctx,
-		max_timeout_seconds,
-	)
-}
-
-func InvokeRerank(
-	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank],
-	ctx *gin.Context,
-	max_timeout_seconds int,
-) {
-	// create session
-	session, err := createSession(
-		r,
-		access_types.PLUGIN_ACCESS_TYPE_MODEL,
-		access_types.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
-		ctx.GetString("cluster_id"),
-	)
-	if err != nil {
-		ctx.JSON(500, gin.H{"error": err.Error()})
-		return
-	}
-	defer session.Close()
-
-	baseSSEService(
-		func() (*stream.Stream[model_entities.RerankResult], error) {
-			return plugin_daemon.InvokeRerank(session, &r.Data)
-		},
-		ctx,
-		max_timeout_seconds,
-	)
-}
-
-func InvokeTTS(
-	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS],
-	ctx *gin.Context,
-	max_timeout_seconds int,
-) {
-	// create session
-	session, err := createSession(
-		r,
-		access_types.PLUGIN_ACCESS_TYPE_MODEL,
-		access_types.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
-		ctx.GetString("cluster_id"),
-	)
-	if err != nil {
-		ctx.JSON(500, gin.H{"error": err.Error()})
-		return
-	}
-	defer session.Close()
-
-	baseSSEService(
-		func() (*stream.Stream[model_entities.TTSResult], error) {
-			return plugin_daemon.InvokeTTS(session, &r.Data)
-		},
-		ctx,
-		max_timeout_seconds,
-	)
-}
-
-func InvokeSpeech2Text(
-	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text],
-	ctx *gin.Context,
-	max_timeout_seconds int,
-) {
-	// create session
-	session, err := createSession(
-		r,
-		access_types.PLUGIN_ACCESS_TYPE_MODEL,
-		access_types.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
-		ctx.GetString("cluster_id"),
-	)
-	if err != nil {
-		ctx.JSON(500, gin.H{"error": err.Error()})
-		return
-	}
-	defer session.Close()
-
-	baseSSEService(
-		func() (*stream.Stream[model_entities.Speech2TextResult], error) {
-			return plugin_daemon.InvokeSpeech2Text(session, &r.Data)
-		},
-		ctx,
-		max_timeout_seconds,
-	)
-}
-
-func InvokeModeration(
-	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration],
-	ctx *gin.Context,
-	max_timeout_seconds int,
-) {
-	// create session
-	session, err := createSession(
-		r,
-		access_types.PLUGIN_ACCESS_TYPE_MODEL,
-		access_types.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
-		ctx.GetString("cluster_id"),
-	)
-	if err != nil {
-		ctx.JSON(500, gin.H{"error": err.Error()})
-		return
-	}
-	defer session.Close()
-
-	baseSSEService(
-		func() (*stream.Stream[model_entities.ModerationResult], error) {
-			return plugin_daemon.InvokeModeration(session, &r.Data)
-		},
-		ctx,
-		max_timeout_seconds,
-	)
-}
-
-func ValidateProviderCredentials(
-	r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials],
-	ctx *gin.Context,
-	max_timeout_seconds int,
-) {
-	// create session
-	session, err := createSession(
-		r,
-		access_types.PLUGIN_ACCESS_TYPE_MODEL,
-		access_types.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
-		ctx.GetString("cluster_id"),
-	)
-	if err != nil {
-		ctx.JSON(500, gin.H{"error": err.Error()})
-		return
-	}
-	defer session.Close()
-
-	baseSSEService(
-		func() (*stream.Stream[model_entities.ValidateCredentialsResult], error) {
-			return plugin_daemon.ValidateProviderCredentials(session, &r.Data)
-		},
-		ctx,
-		max_timeout_seconds,
-	)
-}
-
-func ValidateModelCredentials(
-	r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials],
-	ctx *gin.Context,
-	max_timeout_seconds int,
-) {
-	// create session
-	session, err := createSession(
-		r,
-		access_types.PLUGIN_ACCESS_TYPE_MODEL,
-		access_types.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
+		access_types.PLUGIN_ACCESS_TYPE_TOOL,
+		access_types.PLUGIN_ACCESS_ACTION_VALIDATE_TOOL_CREDENTIALS,
 		ctx.GetString("cluster_id"),
 	)
 	if err != nil {
@@ -241,8 +80,8 @@ func ValidateModelCredentials(
 	defer session.Close()
 
 	baseSSEService(
-		func() (*stream.Stream[model_entities.ValidateCredentialsResult], error) {
-			return plugin_daemon.ValidateModelCredentials(session, &r.Data)
+		func() (*stream.Stream[tool_entities.ValidateCredentialsResult], error) {
+			return plugin_daemon.ValidateToolCredentials(session, &r.Data)
 		},
 		ctx,
 		max_timeout_seconds,

+ 44 - 2
internal/types/entities/tool_entities/tool.go

@@ -1,6 +1,48 @@
 package tool_entities
 
+import (
+	"github.com/go-playground/validator/v10"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
+)
+
+type ToolResponseChunkType string
+
+const (
+	ToolResponseChunkTypeText      ToolResponseChunkType = "text"
+	ToolResponseChunkTypeFile      ToolResponseChunkType = "file"
+	ToolResponseChunkTypeBlob      ToolResponseChunkType = "blob"
+	ToolResponseChunkTypeJson      ToolResponseChunkType = "json"
+	ToolResponseChunkTypeLink      ToolResponseChunkType = "link"
+	ToolResponseChunkTypeImage     ToolResponseChunkType = "image"
+	ToolResponseChunkTypeImageLink ToolResponseChunkType = "image_link"
+	ToolResponseChunkTypeVariable  ToolResponseChunkType = "variable"
+)
+
+func IsValidToolResponseChunkType(fl validator.FieldLevel) bool {
+	t := fl.Field().String()
+	switch ToolResponseChunkType(t) {
+	case ToolResponseChunkTypeText,
+		ToolResponseChunkTypeFile,
+		ToolResponseChunkTypeBlob,
+		ToolResponseChunkTypeJson,
+		ToolResponseChunkTypeLink,
+		ToolResponseChunkTypeImage,
+		ToolResponseChunkTypeImageLink,
+		ToolResponseChunkTypeVariable:
+		return true
+	default:
+		return false
+	}
+}
+
+func init() {
+	validators.GlobalEntitiesValidator.RegisterValidation(
+		"is_valid_tool_response_chunk_type",
+		IsValidToolResponseChunkType,
+	)
+}
+
 type ToolResponseChunk struct {
-	Type    string         `json:"type"`
-	Message map[string]any `json:"message"`
+	Type    ToolResponseChunkType `json:"type" validate:"required,is_valid_tool_response_chunk_type"`
+	Message map[string]any        `json:"message"`
 }

+ 5 - 0
internal/utils/parser/map2struct.go

@@ -3,6 +3,7 @@ package parser
 import (
 	"fmt"
 
+	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
 	"github.com/mitchellh/mapstructure"
 )
 
@@ -25,5 +26,9 @@ func MapToStruct[T any](m map[string]any) (*T, error) {
 		return nil, fmt.Errorf("error decoding map: %s", err)
 	}
 
+	if err := validators.GlobalEntitiesValidator.Struct(s); err != nil {
+		return nil, fmt.Errorf("error validating struct: %s", err)
+	}
+
 	return &s, nil
 }

+ 36 - 7
internal/utils/stream/response.go

@@ -15,8 +15,12 @@ type Stream[T any] struct {
 	closed    int32
 	max       int
 	listening bool
-	onClose   func()
-	err       error
+
+	onClose     []func()
+	beforeClose []func()
+	filter      []func(T) error
+
+	err error
 }
 
 func NewStreamResponse[T any](max int) *Stream[T] {
@@ -27,8 +31,20 @@ func NewStreamResponse[T any](max int) *Stream[T] {
 	}
 }
 
+// Filter filters the stream with a function
+// if the function returns an error, the stream will be closed
+func (r *Stream[T]) Filter(f func(T) error) {
+	r.filter = append(r.filter, f)
+}
+
+// OnClose adds a function to be called when the stream is closed
 func (r *Stream[T]) OnClose(f func()) {
-	r.onClose = f
+	r.onClose = append(r.onClose, f)
+}
+
+// BeforeClose adds a function to be called before the stream is closed
+func (r *Stream[T]) BeforeClose(f func()) {
+	r.beforeClose = append(r.beforeClose, f)
 }
 
 // Next returns true if there are more data to be read
@@ -64,6 +80,14 @@ func (r *Stream[T]) Read() (T, error) {
 
 	if r.q.Len() > 0 {
 		data := r.q.PopFront()
+		for _, f := range r.filter {
+			err := f(data)
+			if err != nil {
+				// close the stream
+				r.Close()
+				return data, err
+			}
+		}
 		return data, nil
 	} else {
 		var data T
@@ -77,8 +101,8 @@ func (r *Stream[T]) Read() (T, error) {
 	}
 }
 
-// Wrap wraps the stream with a new stream, and allows customized operations
-func (r *Stream[T]) Wrap(fn func(T)) error {
+// Async wraps the stream with a new stream, and allows customized operations
+func (r *Stream[T]) Async(fn func(T)) error {
 	if atomic.LoadInt32(&r.closed) == 1 {
 		return errors.New("stream is closed")
 	}
@@ -124,13 +148,18 @@ func (r *Stream[T]) Close() {
 		return
 	}
 
+	for _, f := range r.beforeClose {
+		f()
+	}
+
 	select {
 	case r.sig <- false:
 	default:
 	}
 	close(r.sig)
-	if r.onClose != nil {
-		r.onClose()
+
+	for _, f := range r.onClose {
+		f()
 	}
 }
 

+ 1 - 1
internal/utils/stream/response_test.go

@@ -83,7 +83,7 @@ func TestStreamGeneratorWrapper(t *testing.T) {
 		response.Close()
 	}()
 
-	response.Wrap(func(t int) {
+	response.Async(func(t int) {
 		nums += 1
 	})