Bläddra i källkod

feat: add backwards invoke node api

Yeuoly 7 månader sedan
förälder
incheckning
68c10a1672

+ 47 - 23
api/controllers/inner_api/plugin/plugin.py

@@ -8,13 +8,15 @@ from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
 from controllers.inner_api.wraps import plugin_inner_api_only
 from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
 from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
+from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
 from core.plugin.encrypt import PluginEncrypter
 from core.plugin.entities.request import (
     RequestInvokeApp,
     RequestInvokeEncrypt,
     RequestInvokeLLM,
     RequestInvokeModeration,
-    RequestInvokeNode,
+    RequestInvokeParameterExtractorNode,
+    RequestInvokeQuestionClassifierNode,
     RequestInvokeRerank,
     RequestInvokeSpeech2Text,
     RequestInvokeTextEmbedding,
@@ -96,23 +98,46 @@ class PluginInvokeToolApi(Resource):
                 yield (
                     ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.TEXT,
-                        message=ToolInvokeMessage.TextMessage(text='helloworld'),
+                        message=ToolInvokeMessage.TextMessage(text="helloworld"),
                     )
                     .model_dump_json()
                     .encode()
-                    + b'\n\n'
+                    + b"\n\n"
                 )
 
         return compact_generate_response(generator())
 
 
-class PluginInvokeNodeApi(Resource):
+class PluginInvokeParameterExtractorNodeApi(Resource):
     @setup_required
     @plugin_inner_api_only
     @get_tenant
-    @plugin_data(payload_type=RequestInvokeNode)
-    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeNode):
-        pass
+    @plugin_data(payload_type=RequestInvokeParameterExtractorNode)
+    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
+        return PluginNodeBackwardsInvocation.invoke_parameter_extractor(
+            tenant_id=tenant_model.id,
+            user_id=user_id,
+            parameters=payload.parameters,
+            model_config=payload.model,
+            instruction=payload.instruction,
+            query=payload.query,
+        )
+
+
+class PluginInvokeQuestionClassifierNodeApi(Resource):
+    @setup_required
+    @plugin_inner_api_only
+    @get_tenant
+    @plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
+    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
+        return PluginNodeBackwardsInvocation.invoke_question_classifier(
+            tenant_id=tenant_model.id,
+            user_id=user_id,
+            query=payload.query,
+            model_config=payload.model,
+            classes=payload.classes,
+            instruction=payload.instruction,
+        )
 
 
 class PluginInvokeAppApi(Resource):
@@ -127,15 +152,13 @@ class PluginInvokeAppApi(Resource):
             tenant_id=tenant_model.id,
             conversation_id=payload.conversation_id,
             query=payload.query,
-            stream=payload.response_mode == 'streaming',
+            stream=payload.response_mode == "streaming",
             inputs=payload.inputs,
-            files=payload.files
-        )
-        
-        return compact_generate_response(
-            PluginAppBackwardsInvocation.convert_to_event_stream(response)
+            files=payload.files,
         )
 
+        return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
+
 
 class PluginInvokeEncryptApi(Resource):
     @setup_required
@@ -149,13 +172,14 @@ class PluginInvokeEncryptApi(Resource):
         return PluginEncrypter.invoke_encrypt(tenant_model, payload)
 
 
-api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
-api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
-api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')
-api.add_resource(PluginInvokeTTSApi, '/invoke/tts')
-api.add_resource(PluginInvokeSpeech2TextApi, '/invoke/speech2text')
-api.add_resource(PluginInvokeModerationApi, '/invoke/moderation')
-api.add_resource(PluginInvokeToolApi, '/invoke/tool')
-api.add_resource(PluginInvokeNodeApi, '/invoke/node')
-api.add_resource(PluginInvokeAppApi, '/invoke/app')
-api.add_resource(PluginInvokeEncryptApi, '/invoke/encrypt')
+api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
+api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
+api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
+api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
+api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
+api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
+api.add_resource(PluginInvokeToolApi, "/invoke/tool")
+api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
+api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
+api.add_resource(PluginInvokeAppApi, "/invoke/app")
+api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")

+ 114 - 0
api/core/plugin/backwards_invocation/node.py

@@ -0,0 +1,114 @@
+from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
+from core.workflow.nodes.parameter_extractor.entities import (
+    ModelConfig as ParameterExtractorModelConfig,
+)
+from core.workflow.nodes.parameter_extractor.entities import (
+    ParameterConfig,
+    ParameterExtractorNodeData,
+)
+from core.workflow.nodes.question_classifier.entities import (
+    ClassConfig,
+    QuestionClassifierNodeData,
+)
+from core.workflow.nodes.question_classifier.entities import (
+    ModelConfig as QuestionClassifierModelConfig,
+)
+from services.workflow_service import WorkflowService
+
+
+class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
+    @classmethod
+    def invoke_parameter_extractor(
+        cls,
+        tenant_id: str,
+        user_id: str,
+        parameters: list[ParameterConfig],
+        model_config: ParameterExtractorModelConfig,
+        instruction: str,
+        query: str,
+    ) -> dict:
+        """
+        Invoke parameter extractor node.
+
+        :param tenant_id: str
+        :param user_id: str
+        :param parameters: list[ParameterConfig]
+        :param model_config: ModelConfig
+        :param instruction: str
+        :param query: str
+        :return: dict with __reason, __is_success, and other parameters
+        """
+        workflow_service = WorkflowService()
+        node_id = "1919810"
+        node_data = ParameterExtractorNodeData(
+            title="parameter_extractor",
+            desc="parameter_extractor",
+            parameters=parameters,
+            reasoning_mode="function_call",
+            query=[node_id, "query"],
+            model=model_config,
+            instruction=instruction,  # instruct with variables are not supported
+        )
+        node_data_dict = node_data.model_dump()
+        execution = workflow_service.run_free_workflow_node(
+            node_data_dict,
+            tenant_id=tenant_id,
+            user_id=user_id,
+            node_id=node_id,
+            user_inputs={
+                f"{node_id}.query": query,
+            },
+        )
+
+        output = execution.outputs_dict
+        return output or {
+            "__reason": "No parameters extracted",
+            "__is_success": False,
+        }
+
+    @classmethod
+    def invoke_question_classifier(
+        cls,
+        tenant_id: str,
+        user_id: str,
+        model_config: QuestionClassifierModelConfig,
+        classes: list[ClassConfig],
+        instruction: str,
+        query: str,
+    ) -> dict:
+        """
+        Invoke question classifier node.
+
+        :param tenant_id: str
+        :param user_id: str
+        :param model_config: ModelConfig
+        :param classes: list[ClassConfig]
+        :param instruction: str
+        :param query: str
+        :return: dict with class_name
+        """
+        workflow_service = WorkflowService()
+        node_id = "1919810"
+        node_data = QuestionClassifierNodeData(
+            title="question_classifier",
+            desc="question_classifier",
+            query_variable_selector=[node_id, "query"],
+            model=model_config,
+            classes=classes,
+            instruction=instruction,  # instruct with variables are not supported
+        )
+        node_data_dict = node_data.model_dump()
+        execution = workflow_service.run_free_workflow_node(
+            node_data_dict,
+            tenant_id=tenant_id,
+            user_id=user_id,
+            node_id=node_id,
+            user_inputs={
+                f"{node_id}.query": query,
+            },
+        )
+
+        output = execution.outputs_dict
+        return output or {
+            "class_name": classes[0].name,
+        }

+ 28 - 2
api/core/plugin/entities/request.py

@@ -14,6 +14,16 @@ from core.model_runtime.entities.message_entities import (
     UserPromptMessage,
 )
 from core.model_runtime.entities.model_entities import ModelType
+from core.workflow.nodes.question_classifier.entities import (
+    ClassConfig,
+    ModelConfig as QuestionClassifierModelConfig,
+)
+from core.workflow.nodes.parameter_extractor.entities import (
+    ModelConfig as ParameterExtractorModelConfig,
+)
+from core.workflow.nodes.parameter_extractor.entities import (
+    ParameterConfig,
+)
 
 
 class RequestInvokeTool(BaseModel):
@@ -92,11 +102,27 @@ class RequestInvokeModeration(BaseModel):
     """
 
 
-class RequestInvokeNode(BaseModel):
+class RequestInvokeParameterExtractorNode(BaseModel):
     """
-    Request to invoke node
+    Request to invoke parameter extractor node
     """
 
+    parameters: list[ParameterConfig]
+    model: ParameterExtractorModelConfig
+    instruction: str
+    query: str
+
+
+class RequestInvokeQuestionClassifierNode(BaseModel):
+    """
+    Request to invoke question classifier node
+    """
+
+    query: str
+    model: QuestionClassifierModelConfig
+    classes: list[ClassConfig]
+    instruction: str
+
 
 class RequestInvokeApp(BaseModel):
     """

+ 82 - 0
api/core/workflow/workflow_entry.py

@@ -206,6 +206,88 @@ class WorkflowEntry:
             raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
 
     @classmethod
+    def run_free_node(
+        cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
+    ) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
+        """
+        Run free node
+
+        NOTE: only parameter_extractor/question_classifier are supported
+
+        :param node_data: node data
+        :param user_id: user id
+        :param user_inputs: user inputs
+        :return:
+        """
+        # generate a fake graph
+        node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data}
+        graph_dict = {
+            "nodes": [node_config],
+        }
+
+        node_type = NodeType.value_of(node_data.get("type", ""))
+        if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
+            raise ValueError(f"Node type {node_type} not supported")
+
+        node_cls = node_classes.get(node_type)
+        if not node_cls:
+            raise ValueError(f"Node class not found for node type {node_type}")
+
+        graph = Graph.init(graph_config=graph_dict)
+
+        # init variable pool
+        variable_pool = VariablePool(
+            system_variables={},
+            user_inputs={},
+            environment_variables=[],
+        )
+
+        node_cls = cast(type[BaseNode], node_cls)
+        # init workflow run state
+        node_instance: BaseNode = node_cls(
+            id=str(uuid.uuid4()),
+            config=node_config,
+            graph_init_params=GraphInitParams(
+                tenant_id=tenant_id,
+                app_id="",
+                workflow_type=WorkflowType.WORKFLOW,
+                workflow_id="",
+                graph_config=graph_dict,
+                user_id=user_id,
+                user_from=UserFrom.ACCOUNT,
+                invoke_from=InvokeFrom.DEBUGGER,
+                call_depth=0,
+            ),
+            graph=graph,
+            graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
+        )
+
+        try:
+            # variable selector to variable mapping
+            try:
+                variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
+                    graph_config=graph_dict, config=node_config
+                )
+            except NotImplementedError:
+                variable_mapping = {}
+
+            cls.mapping_user_inputs_to_variable_pool(
+                variable_mapping=variable_mapping,
+                user_inputs=user_inputs,
+                variable_pool=variable_pool,
+                tenant_id=tenant_id,
+                node_type=node_type,
+                node_data=node_instance.node_data,
+            )
+
+            # run node
+            generator = node_instance.run()
+
+            return node_instance, generator
+        except Exception as e:
+            raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
+
+    @classmethod
     def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
         """
         Handle special values

+ 64 - 17
api/services/workflow_service.py

@@ -1,8 +1,8 @@
 import json
 import time
-from collections.abc import Sequence
+from collections.abc import Callable, Generator, Sequence
 from datetime import datetime, timezone
-from typing import Optional
+from typing import Any, Optional
 
 from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@@ -10,7 +10,9 @@ from core.app.segments import Variable
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.workflow.entities.node_entities import NodeRunResult, NodeType
 from core.workflow.errors import WorkflowNodeRunFailedError
-from core.workflow.nodes.event import RunCompletedEvent
+from core.workflow.graph_engine.entities.event import InNodeEvent
+from core.workflow.nodes.base_node import BaseNode
+from core.workflow.nodes.event import RunCompletedEvent, RunEvent
 from core.workflow.nodes.node_mapping import node_classes
 from core.workflow.workflow_entry import WorkflowEntry
 from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@@ -216,13 +218,64 @@ class WorkflowService:
         # run draft workflow node
         start_at = time.perf_counter()
 
-        try:
-            node_instance, generator = WorkflowEntry.single_step_run(
+        workflow_node_execution = self._handle_node_run_result(
+            getter=lambda: WorkflowEntry.single_step_run(
                 workflow=draft_workflow,
                 node_id=node_id,
                 user_inputs=user_inputs,
                 user_id=account.id,
-            )
+            ),
+            start_at=start_at,
+            tenant_id=app_model.tenant_id,
+            node_id=node_id,
+        )
+
+        db.session.add(workflow_node_execution)
+        db.session.commit()
+
+        return workflow_node_execution
+    
+    def run_free_workflow_node(
+        self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
+    ) -> WorkflowNodeExecution:
+        """
+        Run draft workflow node
+        """
+        # run draft workflow node
+        start_at = time.perf_counter()
+
+        workflow_node_execution = self._handle_node_run_result(
+            getter=lambda: WorkflowEntry.run_free_node(
+                node_id=node_id,
+                node_data=node_data,
+                tenant_id=tenant_id,
+                user_id=user_id,
+                user_inputs=user_inputs,
+            ),
+            start_at=start_at,
+            tenant_id=tenant_id,
+            node_id=node_id
+        )
+
+        return workflow_node_execution
+
+    def _handle_node_run_result(
+        self,
+        getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]],
+        start_at: float,
+        tenant_id: str,
+        node_id: str,
+    ):
+        """
+        Handle node run result
+
+        :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
+        :param start_at: float
+        :param tenant_id: str
+        :param node_id: str
+        """
+        try:
+            node_instance, generator = getter()
 
             node_run_result: NodeRunResult | None = None
             for event in generator:
@@ -245,9 +298,7 @@ class WorkflowService:
             error = e.error
 
         workflow_node_execution = WorkflowNodeExecution()
-        workflow_node_execution.tenant_id = app_model.tenant_id
-        workflow_node_execution.app_id = app_model.id
-        workflow_node_execution.workflow_id = draft_workflow.id
+        workflow_node_execution.tenant_id = tenant_id
         workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
         workflow_node_execution.index = 1
         workflow_node_execution.node_id = node_id
@@ -255,7 +306,6 @@ class WorkflowService:
         workflow_node_execution.title = node_instance.node_data.title
         workflow_node_execution.elapsed_time = time.perf_counter() - start_at
         workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
-        workflow_node_execution.created_by = account.id
         workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
         workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
 
@@ -277,9 +327,6 @@ class WorkflowService:
             workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
             workflow_node_execution.error = error
 
-        db.session.add(workflow_node_execution)
-        db.session.commit()
-
         return workflow_node_execution
 
     def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
@@ -302,10 +349,10 @@ class WorkflowService:
         new_app = workflow_converter.convert_to_workflow(
             app_model=app_model,
             account=account,
-            name=args.get("name"),
-            icon_type=args.get("icon_type"),
-            icon=args.get("icon"),
-            icon_background=args.get("icon_background"),
+            name=args.get("name", ""),
+            icon_type=args.get("icon_type", ""),
+            icon=args.get("icon", ""),
+            icon_background=args.get("icon_background", ""),
         )
 
         return new_app