Browse Source

refactor: tool response to generator

Yeuoly 9 months ago
parent
commit
563d81277b

+ 2 - 0
api/controllers/inner_api/plugin/plugin.py

@@ -21,6 +21,8 @@ class PluginInvokeModelApi(Resource):
         parser.add_argument('parameters', type=dict, required=True, location='json')
         parser.add_argument('parameters', type=dict, required=True, location='json')
 
 
         args = parser.parse_args()
         args = parser.parse_args()
+
+        
         
         
 
 
 class PluginInvokeToolApi(Resource):
 class PluginInvokeToolApi(Resource):

+ 4 - 2
api/core/agent/entities.py

@@ -1,14 +1,16 @@
 from enum import Enum
 from enum import Enum
-from typing import Any, Literal, Optional, Union
+from typing import Any, Optional, Union
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
 
 
+from core.tools.entities.tool_entities import ToolProviderType
+
 
 
 class AgentToolEntity(BaseModel):
 class AgentToolEntity(BaseModel):
     """
     """
     Agent Tool Entity.
     Agent Tool Entity.
     """
     """
-    provider_type: Literal["builtin", "api", "workflow"]
+    provider_type: ToolProviderType
     provider_id: str
     provider_id: str
     tool_name: str
     tool_name: str
     tool_parameters: dict[str, Any] = {}
     tool_parameters: dict[str, Any] = {}

+ 5 - 0
api/core/callback_handler/plugin_tool_callback_handler.py

@@ -0,0 +1,5 @@
+from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
+
+
+class DifyPluginCallbackHandler(DifyAgentCallbackHandler):
+    """Callback Handler that prints to std out."""

+ 3 - 2
api/core/tools/tool/api_tool.py

@@ -1,4 +1,5 @@
 import json
 import json
+from collections.abc import Generator
 from os import getenv
 from os import getenv
 from typing import Any
 from typing import Any
 from urllib.parse import urlencode
 from urllib.parse import urlencode
@@ -269,7 +270,7 @@ class ApiTool(Tool):
         except ValueError as e:
         except ValueError as e:
             return value
             return value
 
 
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
         """
         """
         invoke http request
         invoke http request
         """
         """
@@ -283,4 +284,4 @@ class ApiTool(Tool):
         response = self.validate_and_parse_response(response)
         response = self.validate_and_parse_response(response)
 
 
         # assemble invoke message
         # assemble invoke message
-        return self.create_text_message(response)
+        yield self.create_text_message(response)

+ 3 - 2
api/core/tools/tool/dataset_retriever_tool.py

@@ -1,3 +1,4 @@
+from collections.abc import Generator
 from typing import Any
 from typing import Any
 
 
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
@@ -86,7 +87,7 @@ class DatasetRetrieverTool(Tool):
     def tool_provider_type(self) -> ToolProviderType:
     def tool_provider_type(self) -> ToolProviderType:
         return ToolProviderType.DATASET_RETRIEVAL
         return ToolProviderType.DATASET_RETRIEVAL
 
 
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
         """
         """
         invoke dataset retriever tool
         invoke dataset retriever tool
         """
         """
@@ -97,7 +98,7 @@ class DatasetRetrieverTool(Tool):
         # invoke dataset retriever tool
         # invoke dataset retriever tool
         result = self.retrival_tool._run(query=query)
         result = self.retrival_tool._run(query=query)
 
 
-        return self.create_text_message(text=result)
+        yield self.create_text_message(text=result)
 
 
     def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
     def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
         """
         """

+ 3 - 5
api/core/tools/tool/tool.py

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 from copy import deepcopy
 from copy import deepcopy
 from enum import Enum
 from enum import Enum
 from typing import Any, Optional, Union
 from typing import Any, Optional, Union
@@ -190,7 +191,7 @@ class Tool(BaseModel, ABC):
 
 
         return result
         return result
 
 
-    def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
+    def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
         # update tool_parameters
         # update tool_parameters
         if self.runtime.runtime_parameters:
         if self.runtime.runtime_parameters:
             tool_parameters.update(self.runtime.runtime_parameters)
             tool_parameters.update(self.runtime.runtime_parameters)
@@ -203,9 +204,6 @@ class Tool(BaseModel, ABC):
             tool_parameters=tool_parameters,
             tool_parameters=tool_parameters,
         )
         )
 
 
-        if not isinstance(result, list):
-            result = [result]
-
         return result
         return result
 
 
     def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
     def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
@@ -221,7 +219,7 @@ class Tool(BaseModel, ABC):
         return result
         return result
 
 
     @abstractmethod
     @abstractmethod
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
         pass
         pass
     
     
     def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
     def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:

+ 5 - 7
api/core/tools/tool/workflow_tool.py

@@ -1,5 +1,6 @@
 import json
 import json
 import logging
 import logging
+from collections.abc import Generator
 from copy import deepcopy
 from copy import deepcopy
 from typing import Any, Union
 from typing import Any, Union
 
 
@@ -34,7 +35,7 @@ class WorkflowTool(Tool):
 
 
     def _invoke(
     def _invoke(
         self, user_id: str, tool_parameters: dict[str, Any]
         self, user_id: str, tool_parameters: dict[str, Any]
-    ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
+    ) -> Generator[ToolInvokeMessage, None, None]:
         """
         """
             invoke the tool
             invoke the tool
         """
         """
@@ -46,6 +47,7 @@ class WorkflowTool(Tool):
 
 
         from core.app.apps.workflow.app_generator import WorkflowAppGenerator
         from core.app.apps.workflow.app_generator import WorkflowAppGenerator
         generator = WorkflowAppGenerator()
         generator = WorkflowAppGenerator()
+
         result = generator.generate(
         result = generator.generate(
             app_model=app, 
             app_model=app, 
             workflow=workflow, 
             workflow=workflow, 
@@ -64,16 +66,12 @@ class WorkflowTool(Tool):
         if data.get('error'):
         if data.get('error'):
             raise Exception(data.get('error'))
             raise Exception(data.get('error'))
         
         
-        result = []
-
         outputs = data.get('outputs', {})
         outputs = data.get('outputs', {})
         outputs, files = self._extract_files(outputs)
         outputs, files = self._extract_files(outputs)
         for file in files:
         for file in files:
-            result.append(self.create_file_var_message(file))
+            yield self.create_file_var_message(file)
         
         
-        result.append(self.create_text_message(json.dumps(outputs, ensure_ascii=False)))
-
-        return result
+        yield self.create_text_message(json.dumps(outputs, ensure_ascii=False))
 
 
     def _get_user(self, user_id: str) -> Union[EndUser, Account]:
     def _get_user(self, user_id: str) -> Union[EndUser, Account]:
         """
         """

+ 53 - 13
api/core/tools/tool_engine.py

@@ -1,4 +1,5 @@
 import json
 import json
+from collections.abc import Generator
 from copy import deepcopy
 from copy import deepcopy
 from datetime import datetime, timezone
 from datetime import datetime, timezone
 from mimetypes import guess_type
 from mimetypes import guess_type
@@ -8,6 +9,7 @@ from yarl import URL
 
 
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
 from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
+from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
 from core.file.file_obj import FileTransferMethod
 from core.file.file_obj import FileTransferMethod
 from core.ops.ops_trace_manager import TraceQueueManager
 from core.ops.ops_trace_manager import TraceQueueManager
@@ -64,16 +66,25 @@ class ToolEngine:
                 tool_inputs=tool_parameters
                 tool_inputs=tool_parameters
             )
             )
 
 
-            meta, response = ToolEngine._invoke(tool, tool_parameters, user_id)
-            response = ToolFileMessageTransformer.transform_tool_invoke_messages(
-                messages=response, 
-                user_id=user_id, 
-                tenant_id=tenant_id, 
+            messages = ToolEngine._invoke(tool, tool_parameters, user_id)
+            invocation_meta_dict = {'meta': None}
+
+            def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]):
+                for message in messages:
+                    if isinstance(message, ToolInvokeMeta):
+                        invocation_meta_dict['meta'] = message
+                    else:
+                        yield message
+
+            messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
+                messages=message_callback(invocation_meta_dict, messages),
+                user_id=user_id,
+                tenant_id=tenant_id,
                 conversation_id=message.conversation_id
                 conversation_id=message.conversation_id
             )
             )
 
 
             # extract binary data from tool invoke message
             # extract binary data from tool invoke message
-            binary_files = ToolEngine._extract_tool_response_binary(response)
+            binary_files = ToolEngine._extract_tool_response_binary(messages)
             # create message file
             # create message file
             message_files = ToolEngine._create_message_files(
             message_files = ToolEngine._create_message_files(
                 tool_messages=binary_files,
                 tool_messages=binary_files,
@@ -82,7 +93,9 @@ class ToolEngine:
                 user_id=user_id
                 user_id=user_id
             )
             )
 
 
-            plain_text = ToolEngine._convert_tool_response_to_str(response)
+            plain_text = ToolEngine._convert_tool_response_to_str(messages)
+
+            meta = invocation_meta_dict['meta']
 
 
             # hit the callback handler
             # hit the callback handler
             agent_tool_callback.on_tool_end(
             agent_tool_callback.on_tool_end(
@@ -127,7 +140,7 @@ class ToolEngine:
                         user_id: str, workflow_id: str, 
                         user_id: str, workflow_id: str, 
                         workflow_tool_callback: DifyWorkflowCallbackHandler,
                         workflow_tool_callback: DifyWorkflowCallbackHandler,
                         workflow_call_depth: int,
                         workflow_call_depth: int,
-                        ) -> list[ToolInvokeMessage]:
+        ) -> Generator[ToolInvokeMessage, None, None]:
         """
         """
         Workflow invokes the tool with the given arguments.
         Workflow invokes the tool with the given arguments.
         """
         """
@@ -154,10 +167,38 @@ class ToolEngine:
         except Exception as e:
         except Exception as e:
             workflow_tool_callback.on_tool_error(e)
             workflow_tool_callback.on_tool_error(e)
             raise e
             raise e
-        
+    
+    @staticmethod
+    def plugin_invoke(tool: Tool, tool_parameters: dict, user_id: str,
+                      callback: DifyPluginCallbackHandler
+        ) -> Generator[ToolInvokeMessage, None, None]:
+        """
+        Plugin invokes the tool with the given arguments.
+        """
+        try:
+            # hit the callback handler
+            callback.on_tool_start(
+                tool_name=tool.identity.name, 
+                tool_inputs=tool_parameters
+            )
+
+            response = tool.invoke(user_id, tool_parameters)
+
+            # hit the callback handler
+            callback.on_tool_end(
+                tool_name=tool.identity.name,
+                tool_inputs=tool_parameters,
+                tool_outputs=response,
+            )
+
+            return response
+        except Exception as e:
+            callback.on_tool_error(e)
+            raise e
+    
     @staticmethod
     @staticmethod
     def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
     def _invoke(tool: Tool, tool_parameters: dict, user_id: str) \
-          -> tuple[ToolInvokeMeta, list[ToolInvokeMessage]]:
+          -> Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]:
         """
         """
         Invoke the tool with the given arguments.
         Invoke the tool with the given arguments.
         """
         """
@@ -170,16 +211,15 @@ class ToolEngine:
             'tool_icon': tool.identity.icon
             'tool_icon': tool.identity.icon
         })
         })
         try:
         try:
-            response = tool.invoke(user_id, tool_parameters)
+            yield from tool.invoke(user_id, tool_parameters)
         except Exception as e:
         except Exception as e:
             meta.error = str(e)
             meta.error = str(e)
             raise ToolEngineInvokeError(meta)
             raise ToolEngineInvokeError(meta)
         finally:
         finally:
             ended_at = datetime.now(timezone.utc)
             ended_at = datetime.now(timezone.utc)
             meta.time_cost = (ended_at - started_at).total_seconds()
             meta.time_cost = (ended_at - started_at).total_seconds()
+            yield meta
 
 
-        return meta, response
-    
     @staticmethod
     @staticmethod
     def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
     def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
         """
         """

+ 13 - 37
api/core/tools/tool_manager.py

@@ -18,6 +18,7 @@ from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ApiProviderAuthType,
     ToolInvokeFrom,
     ToolInvokeFrom,
     ToolParameter,
     ToolParameter,
+    ToolProviderType,
 )
 )
 from core.tools.errors import ToolProviderNotFoundError
 from core.tools.errors import ToolProviderNotFoundError
 from core.tools.provider.api_tool_provider import ApiToolProviderController
 from core.tools.provider.api_tool_provider import ApiToolProviderController
@@ -26,6 +27,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 from core.tools.tool.api_tool import ApiTool
 from core.tools.tool.api_tool import ApiTool
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
 from core.tools.tool.tool import Tool
+from core.tools.tool.workflow_tool import WorkflowTool
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.configuration import (
 from core.tools.utils.configuration import (
     ToolConfigurationManager,
     ToolConfigurationManager,
@@ -78,37 +80,13 @@ class ToolManager:
         return tool
         return tool
 
 
     @classmethod
     @classmethod
-    def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
-            -> Union[BuiltinTool, ApiTool]:
-        """
-            get the tool
-
-            :param provider_type: the type of the provider
-            :param provider_name: the name of the provider
-            :param tool_name: the name of the tool
-
-            :return: the tool
-        """
-        if provider_type == 'builtin':
-            return cls.get_builtin_tool(provider_id, tool_name)
-        elif provider_type == 'api':
-            if tenant_id is None:
-                raise ValueError('tenant id is required for api provider')
-            api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
-            return api_provider.get_tool(tool_name)
-        elif provider_type == 'app':
-            raise NotImplementedError('app provider not implemented')
-        else:
-            raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
-
-    @classmethod
-    def get_tool_runtime(cls, provider_type: str,
+    def get_tool_runtime(cls, provider_type: ToolProviderType,
                          provider_id: str,
                          provider_id: str,
                          tool_name: str,
                          tool_name: str,
                          tenant_id: str,
                          tenant_id: str,
                          invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
                          invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
                          tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
                          tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-        -> Union[BuiltinTool, ApiTool]:
+        -> Union[BuiltinTool, ApiTool, WorkflowTool]:
         """
         """
             get the tool runtime
             get the tool runtime
 
 
@@ -118,7 +96,7 @@ class ToolManager:
 
 
             :return: the tool
             :return: the tool
         """
         """
-        if provider_type == 'builtin':
+        if provider_type == ToolProviderType.BUILT_IN:
             builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
             builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
 
 
             # check if the builtin tool need credentials
             # check if the builtin tool need credentials
@@ -155,7 +133,7 @@ class ToolManager:
                 'tool_invoke_from': tool_invoke_from,
                 'tool_invoke_from': tool_invoke_from,
             })
             })
 
 
-        elif provider_type == 'api':
+        elif provider_type == ToolProviderType.API:
             if tenant_id is None:
             if tenant_id is None:
                 raise ValueError('tenant id is required for api provider')
                 raise ValueError('tenant id is required for api provider')
 
 
@@ -171,7 +149,7 @@ class ToolManager:
                 'invoke_from': invoke_from,
                 'invoke_from': invoke_from,
                 'tool_invoke_from': tool_invoke_from,
                 'tool_invoke_from': tool_invoke_from,
             })
             })
-        elif provider_type == 'workflow':
+        elif provider_type == ToolProviderType.WORKFLOW:
             workflow_provider = db.session.query(WorkflowToolProvider).filter(
             workflow_provider = db.session.query(WorkflowToolProvider).filter(
                 WorkflowToolProvider.tenant_id == tenant_id,
                 WorkflowToolProvider.tenant_id == tenant_id,
                 WorkflowToolProvider.id == provider_id
                 WorkflowToolProvider.id == provider_id
@@ -190,10 +168,10 @@ class ToolManager:
                 'invoke_from': invoke_from,
                 'invoke_from': invoke_from,
                 'tool_invoke_from': tool_invoke_from,
                 'tool_invoke_from': tool_invoke_from,
             })
             })
-        elif provider_type == 'app':
+        elif provider_type == ToolProviderType.APP:
             raise NotImplementedError('app provider not implemented')
             raise NotImplementedError('app provider not implemented')
         else:
         else:
-            raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
+            raise ToolProviderNotFoundError(f'provider type {provider_type.value} not found')
 
 
     @classmethod
     @classmethod
     def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
     def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
@@ -554,7 +532,7 @@ class ToolManager:
         })
         })
 
 
     @classmethod
     @classmethod
-    def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
+    def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
         """
         """
             get the tool icon
             get the tool icon
 
 
@@ -563,14 +541,12 @@ class ToolManager:
             :param provider_id: the id of the provider
             :param provider_id: the id of the provider
             :return:
             :return:
         """
         """
-        provider_type = provider_type
-        provider_id = provider_id
-        if provider_type == 'builtin':
+        if provider_type == ToolProviderType.BUILT_IN:
             return (current_app.config.get("CONSOLE_API_URL")
             return (current_app.config.get("CONSOLE_API_URL")
                     + "/console/api/workspaces/current/tool-provider/builtin/"
                     + "/console/api/workspaces/current/tool-provider/builtin/"
                     + provider_id
                     + provider_id
                     + "/icon")
                     + "/icon")
-        elif provider_type == 'api':
+        elif provider_type == ToolProviderType.API:
             try:
             try:
                 provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
                 provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
                     ApiToolProvider.tenant_id == tenant_id,
                     ApiToolProvider.tenant_id == tenant_id,
@@ -582,7 +558,7 @@ class ToolManager:
                     "background": "#252525",
                     "background": "#252525",
                     "content": "\ud83d\ude01"
                     "content": "\ud83d\ude01"
                 }
                 }
-        elif provider_type == 'workflow':
+        elif provider_type == ToolProviderType.WORKFLOW:
             provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
             provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
                 WorkflowToolProvider.tenant_id == tenant_id,
                 WorkflowToolProvider.tenant_id == tenant_id,
                 WorkflowToolProvider.id == provider_id
                 WorkflowToolProvider.id == provider_id

+ 4 - 3
api/core/tools/utils/configuration.py

@@ -9,6 +9,7 @@ from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolPr
 from core.tools.entities.tool_entities import (
 from core.tools.entities.tool_entities import (
     ToolParameter,
     ToolParameter,
     ToolProviderCredentials,
     ToolProviderCredentials,
+    ToolProviderType,
 )
 )
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.tool import Tool
 from core.tools.tool.tool import Tool
@@ -108,7 +109,7 @@ class ToolParameterConfigurationManager(BaseModel):
     tenant_id: str
     tenant_id: str
     tool_runtime: Tool
     tool_runtime: Tool
     provider_name: str
     provider_name: str
-    provider_type: str
+    provider_type: ToolProviderType
     identity_id: str
     identity_id: str
 
 
     def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
     def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
@@ -191,7 +192,7 @@ class ToolParameterConfigurationManager(BaseModel):
         """
         """
         cache = ToolParameterCache(
         cache = ToolParameterCache(
             tenant_id=self.tenant_id,
             tenant_id=self.tenant_id,
-            provider=f'{self.provider_type}.{self.provider_name}',
+            provider=f'{self.provider_type.value}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
             tool_name=self.tool_runtime.identity.name,
             cache_type=ToolParameterCacheType.PARAMETER,
             cache_type=ToolParameterCacheType.PARAMETER,
             identity_id=self.identity_id
             identity_id=self.identity_id
@@ -221,7 +222,7 @@ class ToolParameterConfigurationManager(BaseModel):
     def delete_tool_parameters_cache(self):
     def delete_tool_parameters_cache(self):
         cache = ToolParameterCache(
         cache = ToolParameterCache(
             tenant_id=self.tenant_id,
             tenant_id=self.tenant_id,
-            provider=f'{self.provider_type}.{self.provider_name}',
+            provider=f'{self.provider_type.value}.{self.provider_name}',
             tool_name=self.tool_runtime.identity.name,
             tool_name=self.tool_runtime.identity.name,
             cache_type=ToolParameterCacheType.PARAMETER,
             cache_type=ToolParameterCacheType.PARAMETER,
             identity_id=self.identity_id
             identity_id=self.identity_id

+ 19 - 22
api/core/tools/utils/message_transformer.py

@@ -1,4 +1,5 @@
 import logging
 import logging
+from collections.abc import Generator
 from mimetypes import guess_extension
 from mimetypes import guess_extension
 
 
 from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.file.file_obj import FileTransferMethod, FileType, FileVar
@@ -9,20 +10,18 @@ logger = logging.getLogger(__name__)
 
 
 class ToolFileMessageTransformer:
 class ToolFileMessageTransformer:
     @classmethod
     @classmethod
-    def transform_tool_invoke_messages(cls, messages: list[ToolInvokeMessage],
+    def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None],
                                        user_id: str,
                                        user_id: str,
                                        tenant_id: str,
                                        tenant_id: str,
-                                       conversation_id: str) -> list[ToolInvokeMessage]:
+                                       conversation_id: str) -> Generator[ToolInvokeMessage, None, None]:
         """
         """
         Transform tool message and handle file download
         Transform tool message and handle file download
         """
         """
-        result = []
-
         for message in messages:
         for message in messages:
             if message.type == ToolInvokeMessage.MessageType.TEXT:
             if message.type == ToolInvokeMessage.MessageType.TEXT:
-                result.append(message)
+                yield message
             elif message.type == ToolInvokeMessage.MessageType.LINK:
             elif message.type == ToolInvokeMessage.MessageType.LINK:
-                result.append(message)
+                yield message
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                 # try to download image
                 # try to download image
                 try:
                 try:
@@ -35,20 +34,20 @@ class ToolFileMessageTransformer:
                     
                     
                     url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
                     url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
 
 
-                    result.append(ToolInvokeMessage(
+                    yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                         message=url,
                         message=url,
                         save_as=message.save_as,
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                         meta=message.meta.copy() if message.meta is not None else {},
-                    ))
+                    )
                 except Exception as e:
                 except Exception as e:
                     logger.exception(e)
                     logger.exception(e)
-                    result.append(ToolInvokeMessage(
+                    yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.TEXT,
                         type=ToolInvokeMessage.MessageType.TEXT,
                         message=f"Failed to download image: {message.message}, you can try to download it yourself.",
                         message=f"Failed to download image: {message.message}, you can try to download it yourself.",
                         meta=message.meta.copy() if message.meta is not None else {},
                         meta=message.meta.copy() if message.meta is not None else {},
                         save_as=message.save_as,
                         save_as=message.save_as,
-                    ))
+                    )
             elif message.type == ToolInvokeMessage.MessageType.BLOB:
             elif message.type == ToolInvokeMessage.MessageType.BLOB:
                 # get mime type and save blob to storage
                 # get mime type and save blob to storage
                 mimetype = message.meta.get('mime_type', 'octet/stream')
                 mimetype = message.meta.get('mime_type', 'octet/stream')
@@ -67,43 +66,41 @@ class ToolFileMessageTransformer:
 
 
                 # check if file is image
                 # check if file is image
                 if 'image' in mimetype:
                 if 'image' in mimetype:
-                    result.append(ToolInvokeMessage(
+                    yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                         message=url,
                         message=url,
                         save_as=message.save_as,
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                         meta=message.meta.copy() if message.meta is not None else {},
-                    ))
+                    )
                 else:
                 else:
-                    result.append(ToolInvokeMessage(
+                    yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.LINK,
                         type=ToolInvokeMessage.MessageType.LINK,
                         message=url,
                         message=url,
                         save_as=message.save_as,
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                         meta=message.meta.copy() if message.meta is not None else {},
-                    ))
+                    )
             elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
             elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
                 file_var: FileVar = message.meta.get('file_var')
                 file_var: FileVar = message.meta.get('file_var')
                 if file_var:
                 if file_var:
                     if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
                     if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
                         url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
                         url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
                         if file_var.type == FileType.IMAGE:
                         if file_var.type == FileType.IMAGE:
-                            result.append(ToolInvokeMessage(
+                            yield ToolInvokeMessage(
                                 type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                                 type=ToolInvokeMessage.MessageType.IMAGE_LINK,
                                 message=url,
                                 message=url,
                                 save_as=message.save_as,
                                 save_as=message.save_as,
                                 meta=message.meta.copy() if message.meta is not None else {},
                                 meta=message.meta.copy() if message.meta is not None else {},
-                            ))
+                            )
                         else:
                         else:
-                            result.append(ToolInvokeMessage(
+                            yield ToolInvokeMessage(
                                 type=ToolInvokeMessage.MessageType.LINK,
                                 type=ToolInvokeMessage.MessageType.LINK,
                                 message=url,
                                 message=url,
                                 save_as=message.save_as,
                                 save_as=message.save_as,
                                 meta=message.meta.copy() if message.meta is not None else {},
                                 meta=message.meta.copy() if message.meta is not None else {},
-                            ))
+                            )
             else:
             else:
-                result.append(message)
-
-        return result
+                yield message
     
     
     @classmethod
     @classmethod
     def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
     def get_tool_file_url(cls, tool_file_id: str, extension: str) -> str:
-        return f'/files/tools/{tool_file_id}{extension or ".bin"}'
+        return f'/files/tools/{tool_file_id}{extension or ".bin"}'

+ 2 - 1
api/core/workflow/nodes/tool/entities.py

@@ -3,12 +3,13 @@ from typing import Any, Literal, Union
 from pydantic import BaseModel, field_validator
 from pydantic import BaseModel, field_validator
 from pydantic_core.core_schema import ValidationInfo
 from pydantic_core.core_schema import ValidationInfo
 
 
+from core.tools.entities.tool_entities import ToolProviderType
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 from core.workflow.entities.base_node_data_entities import BaseNodeData
 
 
 
 
 class ToolEntity(BaseModel):
 class ToolEntity(BaseModel):
     provider_id: str
     provider_id: str
-    provider_type: Literal['builtin', 'api', 'workflow']
+    provider_type: ToolProviderType
     provider_name: str # redundancy
     provider_name: str # redundancy
     tool_name: str
     tool_name: str
     tool_label: str # redundancy
     tool_label: str # redundancy

+ 1 - 1
api/core/workflow/nodes/tool/tool_node.py

@@ -32,7 +32,7 @@ class ToolNode(BaseNode):
 
 
         # fetch tool icon
         # fetch tool icon
         tool_info = {
         tool_info = {
-            'provider_type': node_data.provider_type,
+            'provider_type': node_data.provider_type.value,
             'provider_id': node_data.provider_id
             'provider_id': node_data.provider_id
         }
         }
 
 

+ 38 - 5
api/services/plugin/plugin_invoke_service.py

@@ -1,16 +1,49 @@
 from collections.abc import Generator
 from collections.abc import Generator
-from typing import Any
+from typing import Any, Union
 
 
-from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.callback_handler.plugin_tool_callback_handler import DifyPluginCallbackHandler
+from core.model_runtime.entities.model_entities import ModelType
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType
+from core.tools.tool_engine import ToolEngine
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.workflow.entities.node_entities import NodeType
 from models.account import Tenant
 from models.account import Tenant
+from services.tools.tools_transform_service import ToolTransformService
 
 
 
 
 class PluginInvokeService:
 class PluginInvokeService:
     @classmethod
     @classmethod
-    def invoke_tool(cls, user_id: str, tenant: Tenant, 
-                    tool_provider: str, tool_name: str,
+    def invoke_tool(cls, user_id: str, invoke_from: InvokeFrom, tenant: Tenant, 
+                    tool_provider_type: ToolProviderType, tool_provider: str, tool_name: str,
                     tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
                     tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
         """
         """
         Invokes a tool with the given user ID and tool parameters.
         Invokes a tool with the given user ID and tool parameters.
         """
         """
-        
+        tool_runtime = ToolManager.get_tool_runtime(tool_provider_type, provider_id=tool_provider, 
+                                                    tool_name=tool_name, tenant_id=tenant.id, 
+                                                    invoke_from=invoke_from)
+        
+        response = ToolEngine.plugin_invoke(tool_runtime, 
+                                            tool_parameters, 
+                                            user_id, 
+                                            callback=DifyPluginCallbackHandler())
+        response = ToolFileMessageTransformer.transform_tool_invoke_messages(response)
+        return ToolTransformService.transform_messages_to_dict(response)
+        
+    @classmethod
+    def invoke_model(cls, user_id: str, tenant: Tenant, 
+                     model_provider: str, model_name: str, model_type: ModelType,
+                     model_parameters: dict[str, Any]) -> Union[dict, Generator[ToolInvokeMessage]]:
+        """
+        Invokes a model with the given user ID and model parameters.
+        """
+
+    @classmethod
+    def invoke_workflow_node(cls, user_id: str, tenant: Tenant, 
+                              node_type: NodeType, node_data: dict[str, Any],
+                              inputs: dict[str, Any]) -> Generator[ToolInvokeMessage]:
+        """
+        Invokes a workflow node with the given user ID and node parameters.
+        """

+ 22 - 10
api/services/tools/tools_transform_service.py

@@ -1,5 +1,6 @@
 import json
 import json
 import logging
 import logging
+from collections.abc import Generator
 from typing import Optional, Union
 from typing import Optional, Union
 
 
 from flask import current_app
 from flask import current_app
@@ -9,6 +10,7 @@ from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ApiProviderAuthType,
+    ToolInvokeMessage,
     ToolParameter,
     ToolParameter,
     ToolProviderCredentials,
     ToolProviderCredentials,
     ToolProviderType,
     ToolProviderType,
@@ -24,8 +26,8 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 class ToolTransformService:
 class ToolTransformService:
-    @staticmethod
-    def get_tool_provider_icon_url(provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
+    @classmethod
+    def get_tool_provider_icon_url(cls, provider_type: str, provider_name: str, icon: str) -> Union[str, dict]:
         """
         """
             get tool provider icon url
             get tool provider icon url
         """
         """
@@ -45,8 +47,8 @@ class ToolTransformService:
         
         
         return ''
         return ''
         
         
-    @staticmethod
-    def repack_provider(provider: Union[dict, UserToolProvider]):
+    @classmethod
+    def repack_provider(cls, provider: Union[dict, UserToolProvider]):
         """
         """
             repack provider
             repack provider
 
 
@@ -65,8 +67,9 @@ class ToolTransformService:
                 icon=provider.icon
                 icon=provider.icon
             )
             )
 
 
-    @staticmethod
+    @classmethod
     def builtin_provider_to_user_provider(
     def builtin_provider_to_user_provider(
+        cls,
         provider_controller: BuiltinToolProviderController,
         provider_controller: BuiltinToolProviderController,
         db_provider: Optional[BuiltinToolProvider],
         db_provider: Optional[BuiltinToolProvider],
         decrypt_credentials: bool = True,
         decrypt_credentials: bool = True,
@@ -126,8 +129,9 @@ class ToolTransformService:
 
 
         return result
         return result
     
     
-    @staticmethod
+    @classmethod
     def api_provider_to_controller(
     def api_provider_to_controller(
+        cls,
         db_provider: ApiToolProvider,
         db_provider: ApiToolProvider,
     ) -> ApiToolProviderController:
     ) -> ApiToolProviderController:
         """
         """
@@ -142,8 +146,9 @@ class ToolTransformService:
 
 
         return controller
         return controller
     
     
-    @staticmethod
+    @classmethod
     def workflow_provider_to_controller(
     def workflow_provider_to_controller(
+        cls,
         db_provider: WorkflowToolProvider
         db_provider: WorkflowToolProvider
     ) -> WorkflowToolProviderController:
     ) -> WorkflowToolProviderController:
         """
         """
@@ -179,8 +184,9 @@ class ToolTransformService:
             labels=labels or []
             labels=labels or []
         )
         )
 
 
-    @staticmethod
+    @classmethod
     def api_provider_to_user_provider(
     def api_provider_to_user_provider(
+        cls,
         provider_controller: ApiToolProviderController,
         provider_controller: ApiToolProviderController,
         db_provider: ApiToolProvider,
         db_provider: ApiToolProvider,
         decrypt_credentials: bool = True,
         decrypt_credentials: bool = True,
@@ -231,8 +237,9 @@ class ToolTransformService:
 
 
         return result
         return result
     
     
-    @staticmethod
+    @classmethod
     def tool_to_user_tool(
     def tool_to_user_tool(
+        cls,
         tool: Union[ApiToolBundle, WorkflowTool, Tool], 
         tool: Union[ApiToolBundle, WorkflowTool, Tool], 
         credentials: dict = None, 
         credentials: dict = None, 
         tenant_id: str = None,
         tenant_id: str = None,
@@ -287,4 +294,9 @@ class ToolTransformService:
                 ),
                 ),
                 parameters=tool.parameters,
                 parameters=tool.parameters,
                 labels=labels
                 labels=labels
-            )
+            )
+        
+    @classmethod
+    def transform_messages_to_dict(cls, responses: Generator[ToolInvokeMessage, None, None]):
+        for response in responses:
+            yield response.model_dump()