Browse Source

refactor: tool response to generator

Yeuoly 9 months ago
parent
commit
563d81277b

+ 2 - 0
api/controllers/inner_api/plugin/plugin.py

@@ -21,6 +21,8 @@ class PluginInvokeModelApi(Resource):
         parser.add_argument('parameters', type=dict, required=True, location='json')
 
         args = parser.parse_args()
+
+        
         
 
 class PluginInvokeToolApi(Resource):

+ 4 - 2
api/core/agent/entities.py

@@ -1,14 +1,16 @@
 from enum import Enum
-from typing import Any, Literal, Optional, Union
+from typing import Any, Optional, Union
 
 from pydantic import BaseModel
 
+from core.tools.entities.tool_entities import ToolProviderType
+
 
 class AgentToolEntity(BaseModel):
     """
     Agent Tool Entity.
     """
-    provider_type: Literal["builtin", "api", "workflow"]
+    provider_type: ToolProviderType
     provider_id: str
     tool_name: str
     tool_parameters: dict[str, Any] = {}

+ 5 - 0
api/core/callback_handler/plugin_tool_callback_handler.py

@@ -0,0 +1,5 @@
+from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
+
+
+class DifyPluginCallbackHandler(DifyAgentCallbackHandler):
+    """Callback Handler that prints to std out."""

+ 3 - 2
api/core/tools/tool/api_tool.py

@@ -1,4 +1,5 @@
 import json
+from collections.abc import Generator
 from os import getenv
 from typing import Any
 from urllib.parse import urlencode
@@ -269,7 +270,7 @@ class ApiTool(Tool):
         except ValueError as e:
             return value
 
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
         """
         invoke http request
         """
@@ -283,4 +284,4 @@ class ApiTool(Tool):
         response = self.validate_and_parse_response(response)
 
         # assemble invoke message
-        return self.create_text_message(response)
+        yield self.create_text_message(response)

+ 3 - 2
api/core/tools/tool/dataset_retriever_tool.py

@@ -1,3 +1,4 @@
+from collections.abc import Generator
 from typing import Any
 
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
@@ -86,7 +87,7 @@ class DatasetRetrieverTool(Tool):
     def tool_provider_type(self) -> ToolProviderType:
         return ToolProviderType.DATASET_RETRIEVAL
 
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
         """
         invoke dataset retriever tool
         """
@@ -97,7 +98,7 @@ class DatasetRetrieverTool(Tool):
         # invoke dataset retriever tool
         result = self.retrival_tool._run(query=query)
 
-        return self.create_text_message(text=result)
+        yield self.create_text_message(text=result)
 
     def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
         """

+ 3 - 5
api/core/tools/tool/tool.py

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 from copy import deepcopy
 from enum import Enum
 from typing import Any, Optional, Union
@@ -190,7 +191,7 @@ class Tool(BaseModel, ABC):
 
         return result
 
-    def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
+    def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
         # update tool_parameters
         if self.runtime.runtime_parameters:
             tool_parameters.update(self.runtime.runtime_parameters)
@@ -203,9 +204,6 @@ class Tool(BaseModel, ABC):
             tool_parameters=tool_parameters,
         )
 
-        if not isinstance(result, list):
-            result = [result]
-
         return result
 
     def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
@@ -221,7 +219,7 @@ class Tool(BaseModel, ABC):
         return result
 
     @abstractmethod
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
         pass
     
     def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:

+ 5 - 7
api/core/tools/tool/workflow_tool.py

@@ -1,5 +1,6 @@
 import json
 import logging
+from collections.abc import Generator
 from copy import deepcopy
 from typing import Any, Union
 
@@ -34,7 +35,7 @@ class WorkflowTool(Tool):
 
     def _invoke(
         self, user_id: str, tool_parameters: dict[str, Any]
-    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    ) -> Generator[ToolInvokeMessage, None, None]:
         """
             invoke the tool
         """
@@ -46,6 +47,7 @@ class WorkflowTool(Tool):
 
         from core.app.apps.workflow.app_generator import WorkflowAppGenerator
         generator = WorkflowAppGenerator()
+
         result = generator.generate(
             app_model=app, 
             workflow=workflow, 
@@ -64,16 +66,12 @@ class WorkflowTool(Tool):
         if data.get('error'):
             raise Exception(data.get('error'))
         
-        result = []
-
         outputs = data.get('outputs', {})
         outputs, files = self._extract_files(outputs)
         for file in files:
-            result.append(self.create_file_var_message(file))
+            yield self.create_file_var_message(file)
         
-        result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
-
-        return result
+        yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
 
     def _get_user(self, user_id: str) -> Union[EndUser, Account]:
         """

+ 53 - 13
api/core/tools/tool_engine.py

@@ -1,4 +1,5 @@
 import json
+from collections.abc import Generator
 from copy import deepcopy
 from datetime import datetime, timezone
 from mimetypes import guess_type
@@ -8,6 +9,7 @@ from yarl import URL
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
+from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
 from core.file.file_obj import FileTransferMethod
 from core.ops.ops_trace_manager import TraceQueueManager
@@ -64,16 +66,25 @@ class ToolEngine:
                 tool_inputs=tool_parameters
             )
 
-            meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
-            response = ToolFileMessageTransformer.transform_tool_invoke_messages(
-                messages=response, 
-                user_id=user_id, 
-                tenant_id=tenant_id, 
+            messages = ToolEngine._invoke(tool, tool_parameters, user_id)
+            invocation_meta_dict = {'meta': None}
+
+            def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]):
+                for message in messages:
+                    if isinstance(message, ToolInvokeMeta):
+                        invocation_meta_dict['meta'] = message
+                    else:
+                        yield message
+
+            messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
+                messages=message_callback(invocation_meta_dict, messages),
+                user_id=user_id,
+                tenant_id=tenant_id,
                 conversation_id=message.conversation_id
             )
 
             # extract binary data from tool invoke message
-            binary_files = ToolEngine._extract_tool_response_binary(response)
+            binary_files = ToolEngine._extract_tool_response_binary(messages)
             # create message file
             message_files = ToolEngine._create_message_files(
                 tool_messages=binary_files,
@@ -82,7 +93,9 @@ class ToolEngine:
                 user_id=user_id
             )
 
-            plain_text = ToolEngine._convert_tool_response_to_str(response)
+            plain_text = ToolEngine._convert_tool_response_to_str(messages)
+
+            meta = invocation_meta_dict['meta']
 
             # hit the callback handler
             agent_tool_callback.on_tool_end(
@@ -127,7 +140,7 @@ class ToolEngine:
                         user_id: str, workflow_id: str, 
                         workflow_tool_callback: DifyWorkflowCallbackHandler,
                         workflow_call_depth: int,
-                        ) -> list[ToolInvokeMessage]:
+        ) -> Generator[ToolInvokeMessage, None, None]:
         """
         Workflow invokes the tool with the given arguments.
         """
@@ -154,10 +167,38 @@ class ToolEngine:
         except Exception as e:
             workflow_tool_callback.on_tool_error(e)
             raise e
-        
+    
+    @staticmethod
+    def plugin_invoke(tool: Tool, tool_parameters: dict, user_id: str,
+                      callback: DifyPluginCallbackHandler
+        ) -> Generator[ToolInvokeMessage, None, None]:
+        """
+        Plugin invokes the tool with the given arguments.
+        """
+        try:
+            # hit the callback handler
+            callback.on_tool_start(
+                tool_name=tool.identity.name, 
+                tool_inputs=tool_parameters
+            )
+
+            response = tool.invoke(user_id, tool_parameters)
+
+            # hit the callback handler
+            callback.on_tool_end(
+                tool_name=tool.identity.name,
+                tool_inputs=tool_parameters,
+                tool_outputs=response,
+            )
+
+            return response
+        except Exception as e:
+            callback.on_tool_error(e)
+            raise e
+    
     @staticmethod
     def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
-          -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]:
+          -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
         """
         Invoke the tool with the given arguments.
         """
@@ -170,16 +211,15 @@ class ToolEngine:
             'tool_icon': tool.identity.icon
         })
         try:
-            response = tool.invoke(user_id, tool_parameters)
+            yield from tool.invoke(user_id, tool_parameters)
         except Exception as e:
             meta.error = str(e)
             raise ToolEngineInvokeError(meta)
         finally:
             ended_at = datetime.now(timezone.utc)
             meta.time_cost = (ended_at - started_at).total_seconds()
+            yield meta
 
-        return meta, response
-    
     @staticmethod
     def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
         """

+ 13 - 37
api/core/tools/tool_manager.py

@@ -18,6 +18,7 @@ from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ToolInvokeFrom,
     ToolParameter,
+    ToolProviderType,
 )
 from core.tools.errors import ToolProviderNotFoundError
 from core.tools.provider.api_tool_provider import ApiToolProviderController
@@ -26,6 +27,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 from core.tools.tool.api_tool import ApiTool
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
+from core.tools.tool.workflow_tool import WorkflowTool
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.configuration import (
     ToolConfigurationManager,
@@ -78,37 +80,13 @@ class ToolManager:
         return tool
 
     @classmethod
-    def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
-            -> Union[BuiltinTool, ApiTool]:
-        """
-            get the tool
-
-            :param provider_type: the type of the provider
-            :param provider_name: the name of the provider
-            :param tool_name: the name of the tool
-
-            :return: the tool
-        """
-        if provider_type == 'builtin':
-            return cls.get_builtin_tool(provider_id, tool_name)
-        elif provider_type == 'api':
-            if tenant_id is None:
-                raise ValueError('tenant id is required for api provider')
-            api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
-            return api_provider.get_tool(tool_name)
-        elif provider_type == 'app':
-            raise NotImplementedError('app provider not implemented')
-        else:
-            raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
-
-    @classmethod
-    def get_tool_runtime(cls, provider_type: str,
+    def get_tool_runtime(cls, provider_type: ToolProviderType,
                          provider_id: str,
                          tool_name: str,
                          tenant_id: str,
                          invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
                          tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-        -> Union[BuiltinTool, ApiTool]:
+        -> Union[BuiltinTool, ApiTool, WorkflowTool]:
         """
             get the tool runtime
 
@@ -118,7 +96,7 @@ class ToolManager:
 
             :return: the tool
         """
-        if provider_type == 'builtin':
+        if provider_type == ToolProviderType.BUILT_IN:
             builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
 
             # check if the builtin tool need credentials
@@ -155,7 +133,7 @@ class ToolManager:
                 'tool_invoke_from': tool_invoke_from,
             })
 
-        elif provider_type == 'api':
+        elif provider_type == ToolProviderType.API:
             if tenant_id is None:
                 raise ValueError('tenant id is required for api provider')
 
@@ -171,7 +149,7 @@ class ToolManager:
                 'invoke_from': invoke_from,
                 'tool_invoke_from': tool_invoke_from,
             })
-        elif provider_type == 'workflow':
+        elif provider_type == ToolProviderType.WORKFLOW:
             workflow_provider = db.session.query(WorkflowToolProvider).filter(
                 WorkflowToolProvider.tenant_id == tenant_id,
                 WorkflowToolProvider.id == provider_id
@@ -190,10 +168,10 @@ class ToolManager:
                 'invoke_from': invoke_from,
                 'tool_invoke_from': tool_invoke_from,
             })
-        elif provider_type == 'app':
+        elif provider_type == ToolProviderType.APP:
             raise NotImplementedError('app provider not implemented')
         else:
-            raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
+            raise ToolProviderNotFoundError(f'provider type {provider_type.value} not found')
 
     @classmethod
     def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
@@ -554,7 +532,7 @@ class ToolManager:
         })
 
     @classmethod
-    def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
+    def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
         """
             get the tool icon
 
@@ -563,14 +541,12 @@ class ToolManager:
             :param provider_id: the id of the provider
             :return:
         """
-        provider_type = provider_type
-        provider_id = provider_id
-        if provider_type == 'builtin':
+        if provider_type == ToolProviderType.BUILT_IN:
             return (current_app.config.get("CONSOLE_API_URL")
                     + "/console/api/workspaces/current/tool-provider/builtin/"
                     + provider_id
                     + "/icon")
-        elif provider_type == 'api':
+        elif provider_type == ToolProviderType.API:
             try:
                 provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
                     ApiToolProvider.tenant_id == tenant_id,
@@ -582,7 +558,7 @@ class ToolManager:
                     "background": "#252525",
                     "content": "\ud83d\ude01"
                 }
-        elif provider_type == 'workflow':
+        elif provider_type == ToolProviderType.WORKFLOW:
             provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
                 WorkflowToolProvider.tenant_id == tenant_id,
                 WorkflowToolProvider.id == provider_id

+ 4 - 3
api/core/tools/utils/configuration.py

@@ -9,6 +9,7 @@ from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolPr
 from core.tools.entities.tool_entities import (
     ToolParameter,
     ToolProviderCredentials,
+    ToolProviderType,
 )
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.tool import Tool
@@ -108,7 +109,7 @@ class ToolParameterConfigurationManager(BaseModel):
     tenant_id: str
     tool_runtime: Tool
     provider_name: str
-    provider_type: str
+    provider_type: ToolProviderType
     identity_id: str
 
     def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
@@ -191,7 +192,7 @@ class ToolParameterConfigurationManager(BaseModel):
         """
         cache = ToolParameterCache(
             tenant_id=self.tenant_id,
-            provider=f'{self.provider_type}.{self.provider_name}',
+            provider=f'{self.provider_type.value}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
             cache_type=ToolParameterCacheType.PARAMETER,
             identity_id=self.identity_id
@@ -221,7 +222,7 @@ class ToolParameterConfigurationManager(BaseModel):
     def delete_tool_parameters_cache(self):
         cache = ToolParameterCache(
             tenant_id=self.tenant_id,
-            provider=f'{self.provider_type}.{self.provider_name}',
+            provider=f'{self.provider_type.value}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
             cache_type=ToolParameterCacheType.PARAMETER,
             identity_id=self.identity_id

+ 19 - 22
api/core/tools/utils/message_transformer.py

@@ -1,4 +1,5 @@
 import logging
+from collections.abc import Generator
 from mimetypes import guess_extension
 
 from core.file.file_obj import FileTransferMethod, FileType, FileVar
@@ -9,20 +10,18 @@ logger = logging.getLogger(__name__)
 
 class ToolFileMessageTransformer:
     @classmethod
-    def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
+    def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None],
                                        user_id: str,
                                        tenant_id: str,
-                                       conversation_id: str) -> list[ToolInvokeMessage]:
+                                       conversation_id: str) -> Generator[ToolInvokeMessage, None, None]:
         """
         Transform tool message and handle file download
         """
-        result = []
-
         for message in messages:
             if message.type == ToolInvokeMessage.MessageType.TEXT:
-                result.append(message)
+                yield message
             elif message.type == ToolInvokeMessage.MessageType.LINK:
-                result.append(message)
+                yield message
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                 # try to download image
                 try:
@@ -35,20 +34,20 @@ class ToolFileMessageTransformer:
                     
                     url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
 
-                    result.append(ToolInvokeMessage(
+                    yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                         message=url,
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
-                    ))
+                    )
                 except Exception as e:
                     logger.exception(e)
-                    result.append(ToolInvokeMessage(
+                    yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.TEXT,
                         message=f"Failed to download image: {message.message}, you can try to download it yourself.",
                         meta=message.meta.copy() if message.meta is not None else {},
                         save_as=message.save_as,
-                    ))
+                    )
             elif message.type == ToolInvokeMessage.MessageType.BLOB:
                 # get mime type and save blob to storage
                 mimetype = message.meta.get('mime_type', 'octet/stream')
@@ -67,43 +66,41 @@ class ToolFileMessageTransformer:
 
                 # check if file is image
                 if 'image' in mimetype:
-                    result.append(ToolInvokeMessage(
+                    yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                         message=url,
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
-                    ))
+                    )
                 else:
-                    result.append(ToolInvokeMessage(
+                    yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.LINK,
                         message=url,
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
-                    ))
+                    )
             elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
                 file_var: FileVar = message.meta.get('file_var')
                 if file_var:
                     if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
                         url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
                         if file_var.type == FileType.IMAGE:
-                            result.append(ToolInvokeMessage(
+                            yield ToolInvokeMessage(
                                 type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                                 message=url,
                                 save_as=message.save_as,
                                 meta=message.meta.copy() if message.meta is not None else {},
-                            ))
+                            )
                         else:
-                            result.append(ToolInvokeMessage(
+                            yield ToolInvokeMessage(
                                 type=ToolInvokeMessage.MessageType.LINK,
                                 message=url,
                                 save_as=message.save_as,
                                 meta=message.meta.copy() if message.meta is not None else {},
-                            ))
+                            )
             else:
-                result.append(message)
-
-        return result
+                yield message
     
     @classmethod
     def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
-        return f'/files/tools/{tool_file_id}{extension or ".bin"}'
+        return f'/files/tools/{tool_file_id}{extension or ".bin"}'

+ 2 - 1
api/core/workflow/nodes/tool/entities.py

@@ -3,12 +3,13 @@ from typing import Any, Literal, Union
 from pydantic import BaseModel, field_validator
 from pydantic_core.core_schema import ValidationInfo
 
+from core.tools.entities.tool_entities import ToolProviderType
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 
 
 class ToolEntity(BaseModel):
     provider_id: str
-    provider_type: Literal['builtin', 'api', 'workflow']
+    provider_type: ToolProviderType
     provider_name: str # redundancy
     tool_name: str
     tool_label: str # redundancy

+ 1 - 1
api/core/workflow/nodes/tool/tool_node.py

@@ -32,7 +32,7 @@ class ToolNode(BaseNode):
 
         # fetch tool icon
         tool_info = {
-            'provider_type': node_data.provider_type,
+            'provider_type': node_data.provider_type.value,
             'provider_id': node_data.provider_id
         }
 

+ 38 - 5
api/services/plugin/plugin_invoke_service.py

@@ -1,16 +1,49 @@
 from collections.abc import Generator
-from typing import Any
+from typing import Any, Union
 
-from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler
+from core.model_runtime.entities.model_entities import ModelType
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
+from core.tools.tool_engine import ToolEngine
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.workflow.entities.node_entities import NodeType
 from models.account import Tenant
+from services.tools.tools_transform_service import ToolTransformService
 
 
 class PluginInvokeService:
     @classmethod
-    def invoke_tool(cls, user_id: str, tenant: Tenant, 
-                    tool_provider: str, tool_name: str,
+    def invoke_tool(cls, user_id: str, invoke_from: InvokeFrom, tenant: Tenant, 
+                    tool_provider_type: ToolProviderType, tool_provider: str, tool_name: str,
                     tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
         """
         Invokes a tool with the given user ID and tool parameters.
         """
-        
+        tool_runtime = ToolManager.get_tool_runtime(tool_provider_type, provider_id=tool_provider, 
+                                                    tool_name=tool_name, tenant_id=tenant.id, 
+                                                    invoke_from=invoke_from)
+        
+        response = ToolEngine.plugin_invoke(tool_runtime, 
+                                            tool_parameters, 
+                                            user_id, 
+                                            callback=DifyPluginCallbackHandler())
+        response = ToolFileMessageTransformer.transform_tool_invoke_messages(response)
+        return ToolTransformService.transform_messages_to_dict(response)
+        
+    @classmethod
+    def invoke_model(cls, user_id: str, tenant: Tenant, 
+                     model_provider: str, model_name: str, model_type: ModelType,
+                     model_parameters: dict[str, Any]) -> Union[dict, Generator[ToolInvokeMessage]]:
+        """
+        Invokes a model with the given user ID and model parameters.
+        """
+
+    @classmethod
+    def invoke_workflow_node(cls, user_id: str, tenant: Tenant, 
+                              node_type: NodeType, node_data: dict[str, Any],
+                              inputs: dict[str, Any]) -> Generator[ToolInvokeMessage]:
+        """
+        Invokes a workflow node with the given user ID and node parameters.
+        """

+ 22 - 10
api/services/tools/tools_transform_service.py

@@ -1,5 +1,6 @@
 import json
 import logging
+from collections.abc import Generator
 from typing import Optional, Union
 
 from flask import current_app
@@ -9,6 +10,7 @@ from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
+    ToolInvokeMessage,
     ToolParameter,
     ToolProviderCredentials,
     ToolProviderType,
@@ -24,8 +26,8 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
 logger = logging.getLogger(__name__)
 
 class ToolTransformService:
-    @staticmethod
-    def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
+    @classmethod
+    def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
         """
             get tool provider icon url
         """
@@ -45,8 +47,8 @@ class ToolTransformService:
         
         return ''
         
-    @staticmethod
-    def repack_provider(provider: Union[dict, UserToolProvider]):
+    @classmethod
+    def repack_provider(cls, provider: Union[dict, UserToolProvider]):
         """
             repack provider
 
@@ -65,8 +67,9 @@ class ToolTransformService:
                 icon=provider.icon
             )
 
-    @staticmethod
+    @classmethod
     def builtin_provider_to_user_provider(
+        cls,
         provider_controller: BuiltinToolProviderController,
         db_provider: Optional[BuiltinToolProvider],
         decrypt_credentials: bool = True,
@@ -126,8 +129,9 @@ class ToolTransformService:
 
         return result
     
-    @staticmethod
+    @classmethod
     def api_provider_to_controller(
+        cls,
         db_provider: ApiToolProvider,
     ) -> ApiToolProviderController:
         """
@@ -142,8 +146,9 @@ class ToolTransformService:
 
         return controller
     
-    @staticmethod
+    @classmethod
     def workflow_provider_to_controller(
+        cls,
         db_provider: WorkflowToolProvider
     ) -> WorkflowToolProviderController:
         """
@@ -179,8 +184,9 @@ class ToolTransformService:
             labels=labels or []
         )
 
-    @staticmethod
+    @classmethod
     def api_provider_to_user_provider(
+        cls,
         provider_controller: ApiToolProviderController,
         db_provider: ApiToolProvider,
         decrypt_credentials: bool = True,
@@ -231,8 +237,9 @@ class ToolTransformService:
 
         return result
     
-    @staticmethod
+    @classmethod
     def tool_to_user_tool(
+        cls,
         tool: Union[ApiToolBundle, WorkflowTool, Tool], 
         credentials: dict = None, 
         tenant_id: str = None,
@@ -287,4 +294,9 @@ class ToolTransformService:
                 ),
                 parameters=tool.parameters,
                 labels=labels
-            )
+            )
+        
+    @classmethod
+    def transform_messages_to_dict(cls, responses: Generator[ToolInvokeMessage, None, None]):
+        for response in responses:
+            yield response.model_dump()