Parcourir la source

fix(workflow): fix answer node stream processing in conditional branches (#12510)

Kevin9703 il y a 3 mois
Parent
commit
54b5b80a07
1 fichiers modifiés avec 18 ajouts et 7 suppressions
  1. 18 7
      api/core/workflow/nodes/answer/base_stream_processor.py

+ 18 - 7
api/core/workflow/nodes/answer/base_stream_processor.py

@@ -1,6 +1,7 @@
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Generator
+from typing import Optional
 
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
@@ -48,25 +49,35 @@ class StreamProcessor(ABC):
                     # we remove the node maybe shortcut the answer node, so comment this code for now
                     # there is not effect on the answer node and the workflow, when we have a better solution
                     # we can open this code. Issues: #11542 #9560 #10638 #10564
-                    ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
-                    if "answer" in ids:
-                        continue
-                    else:
-                        reachable_node_ids.extend(ids)
+                    # ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
+                    # if "answer" in ids:
+                    #     continue
+                    # else:
+                    #     reachable_node_ids.extend(ids)
+
+                    # The branch_identify parameter is added to ensure that
+                    # only nodes in the correct logical branch are included.
+                    ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
+                    reachable_node_ids.extend(ids)
                 else:
                     unreachable_first_node_ids.append(edge.target_node_id)
 
             for node_id in unreachable_first_node_ids:
                 self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
 
-    def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
+    def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
         node_ids = []
         for edge in self.graph.edge_mapping.get(node_id, []):
             if edge.target_node_id == self.graph.root_node_id:
                 continue
 
+            # Only follow edges that match the branch_identify or have no run_condition
+            if edge.run_condition and edge.run_condition.branch_identify:
+                if not branch_identify or edge.run_condition.branch_identify != branch_identify:
+                    continue
+
             node_ids.append(edge.target_node_id)
-            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
+            node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id, branch_identify))
         return node_ids
 
     def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None: