Ver código fonte

feat: invoke node

Yeuoly 7 meses atrás
pai
commit
a91951b374

+ 13 - 10
api/core/plugin/backwards_invocation/node.py

@@ -1,4 +1,5 @@
 from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
+from core.workflow.entities.node_entities import NodeType
 from core.workflow.nodes.parameter_extractor.entities import (
     ModelConfig as ParameterExtractorModelConfig,
 )
@@ -36,7 +37,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
         :param model_config: ModelConfig
         :param instruction: str
         :param query: str
-        :return: dict with __reason, __is_success, and other parameters
+        :return: dict
         """
         workflow_service = WorkflowService()
         node_id = "1919810"
@@ -50,6 +51,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
             instruction=instruction,  # instruct with variables are not supported
         )
         node_data_dict = node_data.model_dump()
+        node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value
         execution = workflow_service.run_free_workflow_node(
             node_data_dict,
             tenant_id=tenant_id,
@@ -60,10 +62,10 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
             },
         )
 
-        output = execution.outputs_dict
-        return output or {
-            "__reason": "No parameters extracted",
-            "__is_success": False,
+        return {
+            "inputs": execution.inputs_dict,
+            "outputs": execution.outputs_dict,
+            "process_data": execution.process_data_dict,
         }
 
     @classmethod
@@ -85,7 +87,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
         :param classes: list[ClassConfig]
         :param instruction: str
         :param query: str
-        :return: dict with class_name
+        :return: dict
         """
         workflow_service = WorkflowService()
         node_id = "1919810"
@@ -108,7 +110,8 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
             },
         )
 
-        output = execution.outputs_dict
-        return output or {
-            "class_name": classes[0].name,
-        }
+        return {
+            "inputs": execution.inputs_dict,
+            "outputs": execution.outputs_dict,
+            "process_data": execution.process_data_dict,
+        }

+ 6 - 4
api/core/plugin/entities/request.py

@@ -14,16 +14,18 @@ from core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import ModelType
-from core.workflow.nodes.question_classifier.entities import (
-    ClassConfig,
-    ModelConfig as QuestionClassifierModelConfig,
-)
 from core.workflow.nodes.parameter_extractor.entities import (
     ModelConfig as ParameterExtractorModelConfig,
 )
 from core.workflow.nodes.parameter_extractor.entities import (
     ParameterConfig,
 )
+from core.workflow.nodes.question_classifier.entities import (
+    ClassConfig,
+)
+from core.workflow.nodes.question_classifier.entities import (
+    ModelConfig as QuestionClassifierModelConfig,
+)
 
 
 class RequestInvokeTool(BaseModel):

+ 20 - 1
api/core/workflow/workflow_entry.py

@@ -221,8 +221,27 @@ class WorkflowEntry:
         """
         # generate a fake graph
         node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data}
+        start_node_config = {
+            "id": "start",
+            "width": 114,
+            "height": 514,
+            "type": "custom",
+            "data": {
+                "type": NodeType.START.value,
+                "title": "Start",
+                "desc": "Start",
+            },
+        }
         graph_dict = {
-            "nodes": [node_config],
+            "nodes": [start_node_config, node_config],
+            "edges": [
+                {
+                    "source": "start",
+                    "target": node_id,
+                    "sourceHandle": "source",
+                    "targetHandle": "target",
+                }
+            ],
         }
 
         node_type = NodeType.value_of(node_data.get("type", ""))

+ 4 - 0
api/services/workflow_service.py

@@ -230,6 +230,10 @@ class WorkflowService:
             node_id=node_id,
         )
 
+        workflow_node_execution.app_id = app_model.id
+        workflow_node_execution.created_by = account.id
+        workflow_node_execution.workflow_id = draft_workflow.id
+
         db.session.add(workflow_node_execution)
         db.session.commit()