소스 검색

feat: replace next-read operation pair to wrap

Yeuoly 1 년 전
부모
커밋
bd9152b1ac

+ 91 - 15
internal/core/plugin_daemon/invoke_dify.go

@@ -7,6 +7,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
@@ -71,6 +72,47 @@ func prepareDifyInvocationArguments(session *session_manager.Session, request ma
 	), nil
 }
 
+var (
+	dispatchMapping = map[dify_invocation.InvokeType]func(handle *backwards_invocation.BackwardsInvocation){
+		dify_invocation.INVOKE_TYPE_TOOL: func(handle *backwards_invocation.BackwardsInvocation) {
+			genericDispatchTask[dify_invocation.InvokeToolRequest](handle, executeDifyInvocationToolTask)
+		},
+		dify_invocation.INVOKE_TYPE_LLM: func(handle *backwards_invocation.BackwardsInvocation) {
+			genericDispatchTask[dify_invocation.InvokeLLMRequest](handle, executeDifyInvocationLLMTask)
+		},
+		dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: func(handle *backwards_invocation.BackwardsInvocation) {
+			genericDispatchTask[dify_invocation.InvokeTextEmbeddingRequest](handle, executeDifyInvocationTextEmbeddingTask)
+		},
+		dify_invocation.INVOKE_TYPE_RERANK: func(handle *backwards_invocation.BackwardsInvocation) {
+			genericDispatchTask[dify_invocation.InvokeRerankRequest](handle, executeDifyInvocationRerankTask)
+		},
+		dify_invocation.INVOKE_TYPE_TTS: func(handle *backwards_invocation.BackwardsInvocation) {
+			genericDispatchTask[dify_invocation.InvokeTTSRequest](handle, executeDifyInvocationTTSTask)
+		},
+		dify_invocation.INVOKE_TYPE_SPEECH2TEXT: func(handle *backwards_invocation.BackwardsInvocation) {
+			genericDispatchTask[dify_invocation.InvokeSpeech2TextRequest](handle, executeDifyInvocationSpeech2TextTask)
+		},
+		dify_invocation.INVOKE_TYPE_MODERATION: func(handle *backwards_invocation.BackwardsInvocation) {
+			genericDispatchTask[dify_invocation.InvokeModerationRequest](handle, executeDifyInvocationModerationTask)
+		},
+	}
+)
+
+func genericDispatchTask[T any](
+	handle *backwards_invocation.BackwardsInvocation,
+	dispatch func(
+		handle *backwards_invocation.BackwardsInvocation,
+		request *T,
+	),
+) {
+	r, err := parser.MapToStruct[T](handle.RequestData())
+	if err != nil {
+		handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error()))
+		return
+	}
+	dispatch(handle, r)
+}
+
 func dispatchDifyInvocationTask(handle *backwards_invocation.BackwardsInvocation) {
 	request_data := handle.RequestData()
 	tenant_id, err := handle.TenantID()
@@ -86,17 +128,14 @@ func dispatchDifyInvocationTask(handle *backwards_invocation.BackwardsInvocation
 	}
 	request_data["user_id"] = user_id
 
-	switch handle.Type() {
-	case dify_invocation.INVOKE_TYPE_TOOL:
-		r, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData())
-		if err != nil {
-			handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error()))
+	for t, v := range dispatchMapping {
+		if t == handle.Type() {
+			v(handle)
 			return
 		}
-		executeDifyInvocationToolTask(handle, r)
-	default:
-		handle.WriteError(fmt.Errorf("unsupported invoke type: %s", handle.Type()))
 	}
+
+	handle.WriteError(fmt.Errorf("unsupported invoke type: %s", handle.Type()))
 }
 
 func executeDifyInvocationToolTask(
@@ -109,12 +148,49 @@ func executeDifyInvocationToolTask(
 		return
 	}
 
-	for response.Next() {
-		data, err := response.Read()
-		if err != nil {
-			return
-		}
+	response.Wrap(func(t tool_entities.ToolResponseChunk) {
+		handle.WriteResponse("stream", t)
+	})
+}
+
+func executeDifyInvocationLLMTask(
+	handle *backwards_invocation.BackwardsInvocation,
+	request *dify_invocation.InvokeLLMRequest,
+) {
+
+}
+
+func executeDifyInvocationTextEmbeddingTask(
+	handle *backwards_invocation.BackwardsInvocation,
+	request *dify_invocation.InvokeTextEmbeddingRequest,
+) {
+
+}
+
+func executeDifyInvocationRerankTask(
+	handle *backwards_invocation.BackwardsInvocation,
+	request *dify_invocation.InvokeRerankRequest,
+) {
+
+}
+
+func executeDifyInvocationTTSTask(
+	handle *backwards_invocation.BackwardsInvocation,
+	request *dify_invocation.InvokeTTSRequest,
+) {
+
+}
+
+func executeDifyInvocationSpeech2TextTask(
+	handle *backwards_invocation.BackwardsInvocation,
+	request *dify_invocation.InvokeSpeech2TextRequest,
+) {
+
+}
+
+func executeDifyInvocationModerationTask(
+	handle *backwards_invocation.BackwardsInvocation,
+	request *dify_invocation.InvokeModerationRequest,
+) {
 
-		handle.WriteResponse("stream", data)
-	}
 }

+ 4 - 9
internal/core/plugin_manager/remote_manager/run.go

@@ -55,16 +55,11 @@ func (r *RemotePluginRuntime) StartPlugin() error {
 		}
 	})
 
-	for r.response.Next() {
-		data, err := r.response.Read()
-		if err != nil {
-			return err
-		}
-
+	r.response.Wrap(func(data []byte) {
 		// handle event
 		event, err := parser.UnmarshalJsonBytes[plugin_entities.PluginUniversalEvent](data)
 		if err != nil {
-			continue
+			return
 		}
 
 		session_id := event.SessionId
@@ -77,7 +72,7 @@ func (r *RemotePluginRuntime) StartPlugin() error {
 				)
 				if err != nil {
 					log.Error("unmarshal json failed: %s", err.Error())
-					continue
+					return
 				}
 
 				log.Info("plugin %s: %s", r.Configuration().Identity(), log_event.Message)
@@ -96,7 +91,7 @@ func (r *RemotePluginRuntime) StartPlugin() error {
 		case plugin_entities.PLUGIN_EVENT_HEARTBEAT:
 			r.last_active_at = time.Now()
 		}
-	}
+	})
 
 	return exit_error
 }

+ 5 - 0
internal/core/plugin_manager/remote_manager/server.go

@@ -41,6 +41,11 @@ func (r *RemotePluginServer) Next() bool {
 	return r.server.response.Next()
 }
 
+// Wrap wraps the wrap method of stream response
+func (r *RemotePluginServer) Wrap(f func(*RemotePluginRuntime)) {
+	r.server.response.Wrap(f)
+}
+
 // Stop stops the server
 func (r *RemotePluginServer) Stop() error {
 	if r.server.response == nil {

+ 3 - 8
internal/core/plugin_manager/watcher.go

@@ -38,14 +38,9 @@ func startRemoteWatcher(config *app.Config) {
 			}
 		}()
 		go func() {
-			for server.Next() {
-				plugin, err := server.Read()
-				if err != nil {
-					log.Error("encounter error: %s", err.Error())
-					continue
-				}
-				lifetime(config, plugin)
-			}
+			server.Wrap(func(rpr *remote_manager.RemotePluginRuntime) {
+				lifetime(config, rpr)
+			})
 		}()
 	}
 }

+ 20 - 0
internal/utils/stream/response.go

@@ -76,6 +76,26 @@ func (r *StreamResponse[T]) Read() (T, error) {
 	}
 }
 
+// Wrap wraps the stream with a new stream, and allows customized operations
+func (r *StreamResponse[T]) Wrap(fn func(T)) error {
+	r.l.Lock()
+	if r.closed {
+		r.l.Unlock()
+		return errors.New("stream is closed")
+	}
+	r.l.Unlock()
+
+	for r.Next() {
+		data, err := r.Read()
+		if err != nil {
+			return err
+		}
+		fn(data)
+	}
+
+	return nil
+}
+
 // Write writes data to the stream
 // returns error if the buffer is full
 func (r *StreamResponse[T]) Write(data T) error {

+ 22 - 0
internal/utils/stream/response_test.go

@@ -53,3 +53,25 @@ func TestStreamGeneratorErrorMessage(t *testing.T) {
 		}
 	}
 }
+
+func TestStreamGeneratorWrapper(t *testing.T) {
+	response := NewStreamResponse[int](512)
+
+	nums := 0
+
+	go func() {
+		for i := 0; i < 10000; i++ {
+			response.Write(i)
+			time.Sleep(time.Microsecond)
+		}
+		response.Close()
+	}()
+
+	response.Wrap(func(t int) {
+		nums += 1
+	})
+
+	if nums != 10000 {
+		t.Errorf("Expected 10000 messages, got %d", nums)
+	}
+}