Browse Source

feat: backwards invoke tools

Yeuoly 8 months ago
parent
commit
118fa66567

+ 12 - 14
api/controllers/inner_api/plugin/plugin.py

@@ -1,5 +1,3 @@
-import time
-
 from flask_restful import Resource
 
 from controllers.console.setup import setup_required
@@ -10,6 +8,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
 from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
 from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
 from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
+from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
 from core.plugin.encrypt import PluginEncrypter
 from core.plugin.entities.request import (
     RequestInvokeApp,
@@ -24,7 +23,7 @@ from core.plugin.entities.request import (
     RequestInvokeTool,
     RequestInvokeTTS,
 )
-from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.entities.tool_entities import ToolProviderType
 from libs.helper import compact_generate_response
 from models.account import Tenant
 
@@ -138,17 +137,16 @@ class PluginInvokeToolApi(Resource):
     @plugin_data(payload_type=RequestInvokeTool)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool):
         def generator():
-            for i in range(10):
-                time.sleep(0.1)
-                yield (
-                    ToolInvokeMessage(
-                        type=ToolInvokeMessage.MessageType.TEXT,
-                        message=ToolInvokeMessage.TextMessage(text="helloworld"),
-                    )
-                    .model_dump_json()
-                    .encode()
-                    + b"\n\n"
-                )
+            return PluginToolBackwardsInvocation.convert_to_event_stream(
+                PluginToolBackwardsInvocation.invoke_tool(
+                    tenant_id=tenant_model.id,
+                    user_id=user_id,
+                    tool_type=ToolProviderType.value_of(payload.tool_type),
+                    provider=payload.provider,
+                    tool_name=payload.tool,
+                    tool_parameters=payload.tool_parameters,
+                ),
+            )
 
         return compact_generate_response(generator())
 

+ 45 - 0
api/core/plugin/backwards_invocation/tool.py

@@ -0,0 +1,45 @@
+from collections.abc import Generator
+from typing import Any
+
+from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
+from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
+from core.tools.tool_engine import ToolEngine
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.message_transformer import ToolFileMessageTransformer
+
+
+class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
+    """
+    Backwards invocation for plugin tools.
+    """
+
+    @classmethod
+    def invoke_tool(
+        cls,
+        tenant_id: str,
+        user_id: str,
+        tool_type: ToolProviderType,
+        provider: str,
+        tool_name: str,
+        tool_parameters: dict[str, Any],
+    ) -> Generator[ToolInvokeMessage, None, None]:
+        """
+        invoke tool
+        """
+        # get tool runtime
+        try:
+            tool_runtime = ToolManager.get_tool_runtime_from_plugin(
+                tool_type, tenant_id, provider, tool_name, tool_parameters
+            )
+            response = ToolEngine.generic_invoke(
+                tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
+            )
+
+            response = ToolFileMessageTransformer.transform_tool_invoke_messages(
+                response, user_id=user_id, tenant_id=tenant_id
+            )
+
+            return response
+        except Exception as e:
+            raise e

+ 5 - 0
api/core/plugin/entities/request.py

@@ -32,6 +32,11 @@ class RequestInvokeTool(BaseModel):
     Request to invoke a tool
     """
 
+    tool_type: Literal["builtin", "workflow", "api"]
+    provider: str
+    tool: str
+    tool_parameters: dict
+
 
 class BaseRequestInvokeModel(BaseModel):
     provider: str

+ 1 - 0
api/core/tools/entities/tool_entities.py

@@ -378,6 +378,7 @@ class ToolInvokeFrom(Enum):
 
     WORKFLOW = "workflow"
     AGENT = "agent"
+    PLUGIN = "plugin"
 
 
 class ToolProviderID:

+ 1 - 1
api/core/tools/tool_engine.py

@@ -131,7 +131,7 @@ class ToolEngine:
         return error_response, [], ToolInvokeMeta.error_instance(error_response)
 
     @staticmethod
-    def workflow_invoke(
+    def generic_invoke(
         tool: Tool,
         tool_parameters: dict[str, Any],
         user_id: str,

+ 34 - 0
api/core/tools/tool_manager.py

@@ -366,6 +366,40 @@ class ToolManager:
         return tool_runtime
 
     @classmethod
+    def get_tool_runtime_from_plugin(
+        cls,
+        tool_type: ToolProviderType,
+        tenant_id: str,
+        provider: str,
+        tool_name: str,
+        tool_parameters: dict[str, Any],
+    ) -> Tool:
+        """
+        get tool runtime from plugin
+        """
+        tool_entity = cls.get_tool_runtime(
+            provider_type=tool_type,
+            provider_id=provider,
+            tool_name=tool_name,
+            tenant_id=tenant_id,
+            invoke_from=InvokeFrom.SERVICE_API,
+            tool_invoke_from=ToolInvokeFrom.PLUGIN,
+        )
+        runtime_parameters = {}
+        parameters = tool_entity.get_merged_runtime_parameters()
+        for parameter in parameters:
+            if parameter.form == ToolParameter.ToolParameterForm.FORM:
+                # save tool parameter to tool entity memory
+                value = cls._init_runtime_parameter(parameter, tool_parameters)
+                runtime_parameters[parameter.name] = value
+
+        if not tool_entity.runtime:
+            raise Exception("tool missing runtime")
+
+        tool_entity.runtime.runtime_parameters.update(runtime_parameters)
+        return tool_entity
+
+    @classmethod
     def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
         """
         get the absolute path of the icon of the builtin provider

+ 1 - 1
api/core/workflow/nodes/tool/tool_node.py

@@ -66,7 +66,7 @@ class ToolNode(BaseNode):
         )
 
         try:
-            message_stream = ToolEngine.workflow_invoke(
+            message_stream = ToolEngine.generic_invoke(
                 tool=tool_runtime,
                 tool_parameters=parameters,
                 user_id=self.user_id,