ソースを参照

feat: validate json schema

Yeuoly 10 ヶ月 前
コミット
cdfdbad7eb

+ 1 - 1
internal/core/plugin_daemon/endpoint_service.go

@@ -29,7 +29,7 @@ func InvokeEndpoint(
 
 	status_code := http.StatusContinue
 	headers := &http.Header{}
-	response := stream.NewStreamResponse[[]byte](128)
+	response := stream.NewStream[[]byte](128)
 	response.OnClose(func() {
 		// add close callback, ensure resources are released
 		resp.Close()

+ 1 - 1
internal/core/plugin_daemon/generic.go

@@ -23,7 +23,7 @@ func genericInvokePlugin[Req any, Rsp any](
 		return nil, errors.New("plugin not found")
 	}
 
-	response := stream.NewStreamResponse[Rsp](response_buffer_size)
+	response := stream.NewStream[Rsp](response_buffer_size)
 
 	listener := runtime.Listen(session.ID)
 	listener.Listen(func(chunk plugin_entities.SessionMessage) {

+ 10 - 3
internal/core/plugin_daemon/tool_service.go

@@ -47,6 +47,16 @@ func InvokeTool(
 		}
 	}
 
+	// bind json schema validator
+	bindValidator(response, tool_output_schema)
+
+	return response, nil
+}
+
+func bindValidator(
+	response *stream.Stream[tool_entities.ToolResponseChunk],
+	tool_output_schema plugin_entities.ToolOutputSchema,
+) {
 	// check if the tool_output_schema is valid
 	variables := make(map[string]any)
 
@@ -108,9 +118,6 @@ func InvokeTool(
 			return
 		}
 	})
-
-	return response, nil
-
 }
 
 func ValidateToolCredentials(

+ 78 - 0
internal/core/plugin_daemon/tool_service_test.go

@@ -0,0 +1,78 @@
+package plugin_daemon
+
+import (
+	"testing"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
+)
+
+func TestToolInvokeJSONSchemaValidator(t *testing.T) {
+	response := stream.NewStream[tool_entities.ToolResponseChunk](128)
+
+	bindValidator(response, map[string]any{
+		"output_schema": map[string]any{
+			"type": "object",
+			"properties": map[string]any{
+				"name": map[string]any{
+					"type": "string",
+				},
+			},
+		},
+	})
+
+	response.Write(tool_entities.ToolResponseChunk{
+		Type: tool_entities.ToolResponseChunkTypeVariable,
+		Message: map[string]any{
+			"variable_name":  "name",
+			"variable_value": "1",
+			"stream":         true,
+		},
+	})
+	response.Close()
+
+	for response.Next() {
+		data, err := response.Read()
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		t.Log(data)
+	}
+}
+
+func TestToolInvokeJSONSchemaValidatorWithInvalidSchema(t *testing.T) {
+	response := stream.NewStream[tool_entities.ToolResponseChunk](128)
+
+	bindValidator(response, map[string]any{
+		"output_schema": map[string]any{
+			"type": "object",
+			"properties": map[string]any{
+				"name": map[string]any{
+					"type": "string",
+				},
+			},
+		},
+	})
+
+	response.Write(tool_entities.ToolResponseChunk{
+		Type: tool_entities.ToolResponseChunkTypeVariable,
+		Message: map[string]any{
+			"variable_name":  "name",
+			"variable_value": 1,
+			"stream":         false,
+		},
+	})
+
+	response.Close()
+
+	_, err := response.Read()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	_, err = response.Read()
+	if err == nil {
+		t.Fatal("expected error, got nil")
+	}
+}

+ 1 - 1
internal/core/plugin_manager/remote_manager/hooks.go

@@ -60,7 +60,7 @@ func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
 		),
 
 		conn:           c,
-		response:       stream.NewStreamResponse[[]byte](512),
+		response:       stream.NewStream[[]byte](512),
 		callbacks:      make(map[string][]func([]byte)),
 		callbacks_lock: &sync.RWMutex{},
 

+ 1 - 1
internal/core/plugin_manager/remote_manager/server.go

@@ -85,7 +85,7 @@ func NewRemotePluginServer(config *app.Config, media_manager *media_manager.Medi
 		config.PluginRemoteInstallingPort,
 	)
 
-	response := stream.NewStreamResponse[*RemotePluginRuntime](
+	response := stream.NewStream[*RemotePluginRuntime](
 		config.PluginRemoteInstallingMaxConn,
 	)
 

+ 1 - 1
internal/core/plugin_manager/serverless/upload.go

@@ -41,7 +41,7 @@ func UploadPlugin(decoder decoder.PluginDecoder) (*stream.Stream[LaunchAWSLambda
 		}
 	} else {
 		// found, return directly
-		response := stream.NewStreamResponse[LaunchAWSLambdaFunctionResponse](2)
+		response := stream.NewStream[LaunchAWSLambdaFunctionResponse](2)
 		response.Write(LaunchAWSLambdaFunctionResponse{
 			Event:   LambdaUrl,
 			Message: function.FunctionURL,

+ 1 - 1
internal/utils/http_requests/http_warpper.go

@@ -86,7 +86,7 @@ func RequestAndParseStream[T any](client *http.Client, url string, method string
 		return nil, fmt.Errorf("request failed with status code: %d and respond with: %s", resp.StatusCode, error_text)
 	}
 
-	ch := stream.NewStreamResponse[T](1024)
+	ch := stream.NewStream[T](1024)
 
 	// get read timeout
 	read_timeout := int64(60000)

+ 1 - 1
internal/utils/stream/response.go

@@ -23,7 +23,7 @@ type Stream[T any] struct {
 	err error
 }
 
-func NewStreamResponse[T any](max int) *Stream[T] {
+func NewStream[T any](max int) *Stream[T] {
 	return &Stream[T]{
 		l:   &sync.Mutex{},
 		sig: make(chan bool),

+ 3 - 3
internal/utils/stream/response_test.go

@@ -8,7 +8,7 @@ import (
 )
 
 func TestStreamGenerator(t *testing.T) {
-	response := NewStreamResponse[int](512)
+	response := NewStream[int](512)
 
 	wg := sync.WaitGroup{}
 	wg.Add(2)
@@ -50,7 +50,7 @@ func TestStreamGenerator(t *testing.T) {
 }
 
 func TestStreamGeneratorErrorMessage(t *testing.T) {
-	response := NewStreamResponse[int](512)
+	response := NewStream[int](512)
 
 	go func() {
 		for i := 0; i < 10000; i++ {
@@ -72,7 +72,7 @@ func TestStreamGeneratorErrorMessage(t *testing.T) {
 }
 
 func TestStreamGeneratorWrapper(t *testing.T) {
-	response := NewStreamResponse[int](512)
+	response := NewStream[int](512)
 	nums := 0
 
 	go func() {

+ 1 - 1
tests/benchmark/stream/stream_response_test.go

@@ -11,7 +11,7 @@ func BenchmarkStreamResponse(b *testing.B) {
 	b.Run("Read", func(b *testing.B) {
 		wg_started := sync.WaitGroup{}
 		wg_started.Add(8)
-		resp := stream.NewStreamResponse[int](1024)
+		resp := stream.NewStream[int](1024)
 
 		for i := 0; i < 8; i++ {
 			go func() {