Yeuoly пре 10 месеци
родитељ
комит
cf4e9f317e

+ 2 - 2
api/core/callback_handler/agent_tool_callback_handler.py

@@ -1,5 +1,5 @@
 import os
-from collections.abc import Mapping, Sequence
+from collections.abc import Iterable, Mapping
 from typing import Any, Optional, TextIO, Union
 
 from pydantic import BaseModel
@@ -55,7 +55,7 @@ class DifyAgentCallbackHandler(BaseModel):
         self,
         tool_name: str,
         tool_inputs: Mapping[str, Any],
-        tool_outputs: Sequence[ToolInvokeMessage],
+        tool_outputs: Iterable[ToolInvokeMessage] | str,
         message_id: Optional[str] = None,
         timer: Optional[Any] = None,
         trace_manager: Optional[TraceQueueManager] = None

+ 32 - 27
api/core/tools/tool_engine.py

@@ -1,9 +1,9 @@
 import json
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Iterable
 from copy import deepcopy
 from datetime import datetime, timezone
 from mimetypes import guess_type
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, cast
 
 from yarl import URL
 
@@ -40,7 +40,7 @@ class ToolEngine:
         user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
         agent_tool_callback: DifyAgentCallbackHandler,
         trace_manager: Optional[TraceQueueManager] = None
-    ) -> tuple[str, list[tuple[MessageFile, bool]], ToolInvokeMeta]:
+    ) -> tuple[str, list[tuple[MessageFile, str]], ToolInvokeMeta]:
         """
         Agent invokes the tool with the given arguments.
         """
@@ -67,9 +67,9 @@ class ToolEngine:
             )
 
             messages = ToolEngine._invoke(tool, tool_parameters, user_id)
-            invocation_meta_dict = {'meta': None}
+            invocation_meta_dict: dict[str, ToolInvokeMeta] = {}
 
-            def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage, None, None]):
+            def message_callback(invocation_meta_dict: dict, messages: Generator[ToolInvokeMessage | ToolInvokeMeta, None, None]):
                 for message in messages:
                     if isinstance(message, ToolInvokeMeta):
                         invocation_meta_dict['meta'] = message
@@ -136,7 +136,7 @@ class ToolEngine:
         return error_response, [], ToolInvokeMeta.error_instance(error_response)
 
     @staticmethod
-    def workflow_invoke(tool: Tool, tool_parameters: Mapping[str, Any],
+    def workflow_invoke(tool: Tool, tool_parameters: dict[str, Any],
                         user_id: str,
                         workflow_tool_callback: DifyWorkflowCallbackHandler,
                         workflow_call_depth: int,
@@ -156,6 +156,7 @@ class ToolEngine:
 
             if tool.runtime and tool.runtime.runtime_parameters:
                 tool_parameters = {**tool.runtime.runtime_parameters, **tool_parameters}
+
             response = tool.invoke(user_id=user_id, tool_parameters=tool_parameters)
 
             # hit the callback handler
@@ -204,6 +205,9 @@ class ToolEngine:
         """
         Invoke the tool with the given arguments.
         """
+        if not tool.runtime:
+            raise ValueError("missing runtime in tool")
+
         started_at = datetime.now(timezone.utc)
         meta = ToolInvokeMeta(time_cost=0.0, error=None, tool_config={
             'tool_name': tool.identity.name,
@@ -223,42 +227,42 @@ class ToolEngine:
             yield meta
 
     @staticmethod
-    def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
+    def _convert_tool_response_to_str(tool_response: Generator[ToolInvokeMessage, None, None]) -> str:
         """
         Handle tool response
         """
         result = ''
         for response in tool_response:
             if response.type == ToolInvokeMessage.MessageType.TEXT:
-                result += response.message
+                result += cast(ToolInvokeMessage.TextMessage, response.message).text
             elif response.type == ToolInvokeMessage.MessageType.LINK:
-                result += f"result link: {response.message}. please tell user to check it."
+                result += f"result link: {cast(ToolInvokeMessage.TextMessage, response.message).text}. please tell user to check it."
             elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                  response.type == ToolInvokeMessage.MessageType.IMAGE:
                 result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
             elif response.type == ToolInvokeMessage.MessageType.JSON:
-                result += f"tool response: {json.dumps(response.message, ensure_ascii=False)}."
+                result += f"tool response: {json.dumps(cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False)}."
             else:
                 result += f"tool response: {response.message}."
 
         return result
     
     @staticmethod
-    def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
+    def _extract_tool_response_binary(tool_response: Generator[ToolInvokeMessage, None, None]) -> Generator[ToolInvokeMessageBinary, None, None]:
         """
         Extract tool response binary
         """
-        result = []
-
         for response in tool_response:
             if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
                 response.type == ToolInvokeMessage.MessageType.IMAGE:
                 mimetype = None
+                if not response.meta:
+                    raise ValueError("missing meta data")
                 if response.meta.get('mime_type'):
                     mimetype = response.meta.get('mime_type')
                 else:
                     try:
-                        url = URL(response.message)
+                        url = URL(cast(ToolInvokeMessage.TextMessage, response.message).text)
                         extension = url.suffix
                         guess_type_result, _ = guess_type(f'a{extension}')
                         if guess_type_result:
@@ -269,35 +273,36 @@ class ToolEngine:
                 if not mimetype:
                     mimetype = 'image/jpeg'
                     
-                result.append(ToolInvokeMessageBinary(
+                yield ToolInvokeMessageBinary(
                     mimetype=response.meta.get('mime_type', 'image/jpeg'),
-                    url=response.message,
+                    url=cast(ToolInvokeMessage.TextMessage, response.message).text,
                     save_as=response.save_as,
-                ))
+                )
             elif response.type == ToolInvokeMessage.MessageType.BLOB:
-                result.append(ToolInvokeMessageBinary(
+                if not response.meta:
+                    raise ValueError("missing meta data")
+                
+                yield ToolInvokeMessageBinary(
                     mimetype=response.meta.get('mime_type', 'octet/stream'),
-                    url=response.message,
+                    url=cast(ToolInvokeMessage.TextMessage, response.message).text,
                     save_as=response.save_as,
-                ))
+                )
             elif response.type == ToolInvokeMessage.MessageType.LINK:
                 # check if there is a mime type in meta
                 if response.meta and 'mime_type' in response.meta:
-                    result.append(ToolInvokeMessageBinary(
+                    yield ToolInvokeMessageBinary(
                         mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
-                        url=response.message,
+                        url=cast(ToolInvokeMessage.TextMessage, response.message).text,
                         save_as=response.save_as,
-                    ))
-
-        return result
+                    )
     
     @staticmethod
     def _create_message_files(
-        tool_messages: list[ToolInvokeMessageBinary],
+        tool_messages: Iterable[ToolInvokeMessageBinary],
         agent_message: Message,
         invoke_from: InvokeFrom,
         user_id: str
-    ) -> list[tuple[Any, str]]:
+    ) -> list[tuple[MessageFile, str]]:
         """
         Create message file
 

+ 3 - 3
api/core/workflow/nodes/tool/tool_node.py

@@ -1,4 +1,4 @@
-from collections.abc import Generator, Mapping, Sequence
+from collections.abc import Generator, Sequence
 from os import path
 from typing import Any, cast
 
@@ -100,7 +100,7 @@ class ToolNode(BaseNode):
         variable_pool: VariablePool,
         node_data: ToolNodeData,
         for_log: bool = False,
-    ) -> Mapping[str, Any]:
+    ) -> dict[str, Any]:
         """
         Generate parameters based on the given tool parameters, variable pool, and node data.
 
@@ -110,7 +110,7 @@ class ToolNode(BaseNode):
             node_data (ToolNodeData): The data associated with the tool node.
 
         Returns:
-            Mapping[str, Any]: A dictionary containing the generated parameters.
+            dict[str, Any]: A dictionary containing the generated parameters.
 
         """
         tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters}

+ 5 - 0
api/models/base.py

@@ -0,0 +1,5 @@
+from sqlalchemy.orm import DeclarativeBase
+
+
+class Base(DeclarativeBase):
+    pass

+ 16 - 12
api/models/model.py

@@ -14,6 +14,7 @@ from core.file.tool_file_parser import ToolFileParser
 from core.file.upload_file_parser import UploadFileParser
 from extensions.ext_database import db
 from libs.helper import generate_string
+from models.base import Base
 
 from .account import Account, Tenant
 from .types import StringUUID
@@ -211,7 +212,7 @@ class App(db.Model):
         return tags if tags else []
 
 
-class AppModelConfig(db.Model):
+class AppModelConfig(Base):
     __tablename__ = 'app_model_configs'
     __table_args__ = (
         db.PrimaryKeyConstraint('id', name='app_model_config_pkey'),
@@ -550,6 +551,9 @@ class Conversation(db.Model):
             else:
                 app_model_config = db.session.query(AppModelConfig).filter(
                     AppModelConfig.id == self.app_model_config_id).first()
+                
+                if not app_model_config:
+                    raise ValueError("app config not found")
 
                 model_config = app_model_config.to_dict()
 
@@ -640,7 +644,7 @@ class Conversation(db.Model):
         return self.override_model_configs is not None
 
 
-class Message(db.Model):
+class Message(Base):
     __tablename__ = 'messages'
     __table_args__ = (
         db.PrimaryKeyConstraint('id', name='message_pkey'),
@@ -932,7 +936,7 @@ class MessageFeedback(db.Model):
         return account
 
 
-class MessageFile(db.Model):
+class MessageFile(Base):
     __tablename__ = 'message_files'
     __table_args__ = (
         db.PrimaryKeyConstraint('id', name='message_file_pkey'),
@@ -940,15 +944,15 @@ class MessageFile(db.Model):
         db.Index('message_file_created_by_idx', 'created_by')
     )
 
-    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
-    message_id = db.Column(StringUUID, nullable=False)
-    type = db.Column(db.String(255), nullable=False)
-    transfer_method = db.Column(db.String(255), nullable=False)
-    url = db.Column(db.Text, nullable=True)
-    belongs_to = db.Column(db.String(255), nullable=True)
-    upload_file_id = db.Column(StringUUID, nullable=True)
-    created_by_role = db.Column(db.String(255), nullable=False)
-    created_by = db.Column(StringUUID, nullable=False)
+    id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
+    message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    type: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    url: Mapped[str] = mapped_column(db.Text, nullable=True)
+    belongs_to: Mapped[str] = mapped_column(db.String(255), nullable=True)
+    upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
+    created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
+    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 

+ 2 - 4
api/models/tools.py

@@ -1,12 +1,13 @@
 import json
 
 from sqlalchemy import ForeignKey
-from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
+from sqlalchemy.orm import Mapped, mapped_column
 
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
 from extensions.ext_database import db
+from models.base import Base
 
 from .model import Account, App, Tenant
 from .types import StringUUID
@@ -277,9 +278,6 @@ class ToolConversationVariables(db.Model):
     @property
     def variables(self) -> dict:
         return json.loads(self.variables_str)
-    
-class Base(DeclarativeBase):
-    pass
 
 class ToolFile(Base):
     """