Browse Source

fix(workflow): Take back LLM streaming output after IF-ELSE (#9875)

-LAN- 5 months ago
parent
commit
72ea3d6b98

+ 6 - 7
api/core/workflow/graph_engine/graph_engine.py

@@ -130,15 +130,14 @@ class GraphEngine:
         yield GraphRunStartedEvent()
 
         try:
-            stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
             if self.init_params.workflow_type == WorkflowType.CHAT:
-                stream_processor_cls = AnswerStreamProcessor
+                stream_processor = AnswerStreamProcessor(
+                    graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
+                )
             else:
-                stream_processor_cls = EndStreamProcessor
-
-            stream_processor = stream_processor_cls(
-                graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
-            )
+                stream_processor = EndStreamProcessor(
+                    graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool
+                )
 
             # run graph
             generator = stream_processor.process(self._run(start_node_id=self.graph.root_node_id))

+ 4 - 4
api/core/workflow/nodes/answer/answer_stream_generate_router.py

@@ -149,10 +149,10 @@ class AnswerStreamGeneratorRouter:
             source_node_id = edge.source_node_id
             source_node_type = node_id_config_mapping[source_node_id].get("data", {}).get("type")
             if source_node_type in {
-                NodeType.ANSWER.value,
-                NodeType.IF_ELSE.value,
-                NodeType.QUESTION_CLASSIFIER.value,
-                NodeType.ITERATION.value,
+                NodeType.ANSWER,
+                NodeType.IF_ELSE,
+                NodeType.QUESTION_CLASSIFIER,
+                NodeType.ITERATION,
             }:
                 answer_dependencies[answer_node_id].append(source_node_id)
             else:

+ 1 - 1
api/core/workflow/nodes/answer/answer_stream_processor.py

@@ -22,7 +22,7 @@ class AnswerStreamProcessor(StreamProcessor):
         super().__init__(graph, variable_pool)
         self.generate_routes = graph.answer_stream_generate_routes
         self.route_position = {}
-        for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
+        for answer_node_id in self.generate_routes.answer_generate_route:
             self.route_position[answer_node_id] = 0
         self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
 

+ 0 - 1
api/core/workflow/nodes/answer/base_stream_processor.py

@@ -41,7 +41,6 @@ class StreamProcessor(ABC):
                     continue
                 else:
                     unreachable_first_node_ids.append(edge.target_node_id)
-                    unreachable_first_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
 
             for node_id in unreachable_first_node_ids:
                 self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)

+ 2 - 1
api/core/workflow/nodes/answer/entities.py

@@ -1,3 +1,4 @@
+from collections.abc import Sequence
 from enum import Enum
 
 from pydantic import BaseModel, Field
@@ -32,7 +33,7 @@ class VarGenerateRouteChunk(GenerateRouteChunk):
 
     type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR
     """generate route chunk type"""
-    value_selector: list[str] = Field(..., description="value selector")
+    value_selector: Sequence[str] = Field(..., description="value selector")
 
 
 class TextGenerateRouteChunk(GenerateRouteChunk):