|
@@ -1,9 +1,9 @@
|
|
|
import json
|
|
|
-from collections.abc import Generator, Mapping
|
|
|
+from collections.abc import Generator, Iterable
|
|
|
from copy import deepcopy
|
|
|
from datetime import datetime, timezone
|
|
|
from mimetypes import guess_type
|
|
|
-from typing import Any, Optional, Union
|
|
|
+from typing import Any, Optional, Union, cast
|
|
|
|
|
|
from yarl import URL
|
|
|
|
|
@@ -40,7 +40,7 @@ class ToolEngine:
|
|
|
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
|
|
|
agent_tool_callback: DifyAgentCallbackHandler,
|
|
|
trace_manager: Optional[TraceQueueManager] = None
|
|
|
- ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
|
|
|
+ ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
|
|
|
"""
|
|
|
Agent invokes the tool with the given arguments.
|
|
|
"""
|
|
@@ -67,9 +67,9 @@ class ToolEngine:
|
|
|
)
|
|
|
|
|
|
messages = ToolEngine._invoke(tool, tool_parameters, user_id)
|
|
|
- invocation_meta_dict = {'meta': None}
|
|
|
+ invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
|
|
|
|
|
|
- def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]):
|
|
|
+ def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]):
|
|
|
for message in messages:
|
|
|
if isinstance(message, ToolInvokeMeta):
|
|
|
invocation_meta_dict['meta'] = message
|
|
@@ -136,7 +136,7 @@ class ToolEngine:
|
|
|
return error_response, [], ToolInvokeMeta.error_instance(error_response)
|
|
|
|
|
|
@staticmethod
|
|
|
- def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any],
|
|
|
+ def workflow_invoke(tool: Tool, tool_parameters: dict[str, Any],
|
|
|
user_id: str,
|
|
|
workflow_tool_callback: DifyWorkflowCallbackHandler,
|
|
|
workflow_call_depth: int,
|
|
@@ -156,6 +156,7 @@ class ToolEngine:
|
|
|
|
|
|
if tool.runtime and tool.runtime.runtime_parameters:
|
|
|
tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
|
|
|
+
|
|
|
response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters)
|
|
|
|
|
|
# hit the callback handler
|
|
@@ -204,6 +205,9 @@ class ToolEngine:
|
|
|
"""
|
|
|
Invoke the tool with the given arguments.
|
|
|
"""
|
|
|
+ if not tool.runtime:
|
|
|
+ raise ValueError("missing runtime in tool")
|
|
|
+
|
|
|
started_at = datetime.now(timezone.utc)
|
|
|
meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={
|
|
|
'tool_name': tool.identity.name,
|
|
@@ -223,42 +227,42 @@ class ToolEngine:
|
|
|
yield meta
|
|
|
|
|
|
@staticmethod
|
|
|
- def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
|
|
|
+ def _convert_tool_response_to_str(tool_response: Generator[ToolInvokeMessage, None, None]) -> str:
|
|
|
"""
|
|
|
Handle tool response
|
|
|
"""
|
|
|
result = ''
|
|
|
for response in tool_response:
|
|
|
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
|
|
- result += response.message
|
|
|
+ result += cast(ToolInvokeMessage.TextMessage, response.message).text
|
|
|
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
|
|
- result += f"result link: {response.message}. please tell user to check it."
|
|
|
+ result += f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}. please tell user to check it."
|
|
|
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
|
|
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
|
|
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
|
|
|
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
|
|
- result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}."
|
|
|
+ result += f"tool response: {json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)}."
|
|
|
else:
|
|
|
result += f"tool response: {response.message}."
|
|
|
|
|
|
return result
|
|
|
|
|
|
@staticmethod
|
|
|
- def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
|
|
|
+ def _extract_tool_response_binary(tool_response: Generator[ToolInvokeMessage, None, None]) -> Generator[ToolInvokeMessageBinary, None, None]:
|
|
|
"""
|
|
|
Extract tool response binary
|
|
|
"""
|
|
|
- result = []
|
|
|
-
|
|
|
for response in tool_response:
|
|
|
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
|
|
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
|
|
mimetype = None
|
|
|
+ if not response.meta:
|
|
|
+ raise ValueError("missing meta data")
|
|
|
if response.meta.get('mime_type'):
|
|
|
mimetype = response.meta.get('mime_type')
|
|
|
else:
|
|
|
try:
|
|
|
- url = URL(response.message)
|
|
|
+ url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
|
|
|
extension = url.suffix
|
|
|
guess_type_result, _ = guess_type(f'a{extension}')
|
|
|
if guess_type_result:
|
|
@@ -269,35 +273,36 @@ class ToolEngine:
|
|
|
if not mimetype:
|
|
|
mimetype = 'image/jpeg'
|
|
|
|
|
|
- result.append(ToolInvokeMessageBinary(
|
|
|
+ yield ToolInvokeMessageBinary(
|
|
|
mimetype=response.meta.get('mime_type', 'image/jpeg'),
|
|
|
- url=response.message,
|
|
|
+ url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
|
|
save_as=response.save_as,
|
|
|
- ))
|
|
|
+ )
|
|
|
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
|
|
- result.append(ToolInvokeMessageBinary(
|
|
|
+ if not response.meta:
|
|
|
+ raise ValueError("missing meta data")
|
|
|
+
|
|
|
+ yield ToolInvokeMessageBinary(
|
|
|
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
|
|
- url=response.message,
|
|
|
+ url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
|
|
save_as=response.save_as,
|
|
|
- ))
|
|
|
+ )
|
|
|
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
|
|
# check if there is a mime type in meta
|
|
|
if response.meta and 'mime_type' in response.meta:
|
|
|
- result.append(ToolInvokeMessageBinary(
|
|
|
+ yield ToolInvokeMessageBinary(
|
|
|
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
|
|
|
- url=response.message,
|
|
|
+ url=cast(ToolInvokeMessage.TextMessage, response.message).text,
|
|
|
save_as=response.save_as,
|
|
|
- ))
|
|
|
-
|
|
|
- return result
|
|
|
+ )
|
|
|
|
|
|
@staticmethod
|
|
|
def _create_message_files(
|
|
|
- tool_messages: list[ToolInvokeMessageBinary],
|
|
|
+ tool_messages: Iterable[ToolInvokeMessageBinary],
|
|
|
agent_message: Message,
|
|
|
invoke_from: InvokeFrom,
|
|
|
user_id: str
|
|
|
- ) -> list[tuple[Any, str]]:
|
|
|
+ ) -> list[tuple[MessageFile, str]]:
|
|
|
"""
|
|
|
Create message file
|
|
|
|