Browse Source

refactor: tool message transformer

Yeuoly 10 months ago
parent
commit
c28998a6f0

+ 5 - 2
api/core/tools/entities/tool_entities.py

@@ -96,6 +96,9 @@ class ToolInvokeMessage(BaseModel):
     class JsonMessage(BaseModel):
         json_object: dict
 
+    class BlobMessage(BaseModel):
+        blob: bytes
+
     class MessageType(Enum):
         TEXT = "text"
         IMAGE = "image"
@@ -109,7 +112,7 @@ class ToolInvokeMessage(BaseModel):
     """
         plain text, image url or link url
     """
-    message: JsonMessage | TextMessage | None
+    message: JsonMessage | TextMessage | BlobMessage | None
     meta: dict[str, Any] | None = None
     save_as: str = ''
 
@@ -321,7 +324,7 @@ class ToolRuntimeVariablePool(BaseModel):
 
         self.pool.append(variable)
 
-    def set_file(self, tool_name: str, value: str, name: str = None) -> None:
+    def set_file(self, tool_name: str, value: str, name: Optional[str] = None) -> None:
         """
             set an image variable
 

+ 5 - 5
api/core/tools/tool_file_manager.py

@@ -80,8 +80,8 @@ class ToolFileManager:
     def create_file_by_url(
         user_id: str,
         tenant_id: str,
-        conversation_id: str,
         file_url: str,
+        conversation_id: Optional[str] = None,
     ) -> ToolFile:
         """
         create file
@@ -131,7 +131,7 @@ class ToolFileManager:
 
         :return: the binary of the file, mime type
         """
-        tool_file: ToolFile = (
+        tool_file: ToolFile | None = (
             db.session.query(ToolFile)
             .filter(
                 ToolFile.id == id,
@@ -155,7 +155,7 @@ class ToolFileManager:
 
         :return: the binary of the file, mime type
         """
-        message_file: MessageFile = (
+        message_file: MessageFile | None = (
             db.session.query(MessageFile)
             .filter(
                 MessageFile.id == id,
@@ -173,7 +173,7 @@ class ToolFileManager:
             tool_file_id = None
 
 
-        tool_file: ToolFile = (
+        tool_file: ToolFile | None = (
             db.session.query(ToolFile)
             .filter(
                 ToolFile.id == tool_file_id,
@@ -197,7 +197,7 @@ class ToolFileManager:
 
         :return: the binary of the file, mime type
         """
-        tool_file: ToolFile = (
+        tool_file: ToolFile | None = (
             db.session.query(ToolFile)
             .filter(
                 ToolFile.id == tool_file_id,

+ 30 - 14
api/core/tools/utils/message_transformer.py

@@ -1,8 +1,9 @@
 import logging
 from collections.abc import Generator
 from mimetypes import guess_extension
+from typing import Optional
 
-from core.file.file_obj import FileTransferMethod, FileType
+from core.file.file_obj import FileTransferMethod, FileType, FileVar
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.tool_file_manager import ToolFileManager
 
@@ -13,7 +14,7 @@ class ToolFileMessageTransformer:
     def transform_tool_invoke_messages(cls, messages: Generator[ToolInvokeMessage, None, None],
                                        user_id: str,
                                        tenant_id: str,
-                                       conversation_id: str) -> Generator[ToolInvokeMessage, None, None]:
+                                       conversation_id: Optional[str] = None) -> Generator[ToolInvokeMessage, None, None]:
         """
         Transform tool message and handle file download
         """
@@ -25,18 +26,23 @@ class ToolFileMessageTransformer:
             elif message.type == ToolInvokeMessage.MessageType.IMAGE:
                 # try to download image
                 try:
+                    if not conversation_id:
+                        raise 
+
+                    assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+
                     file = ToolFileManager.create_file_by_url(
                         user_id=user_id,
                         tenant_id=tenant_id,
+                        file_url=message.message.text,
                         conversation_id=conversation_id,
-                        file_url=message.message
                     )
 
                     url = f'/files/tools/{file.id}{guess_extension(file.mimetype) or ".png"}'
 
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=url,
+                        message=ToolInvokeMessage.TextMessage(text=url),
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
@@ -44,57 +50,67 @@ class ToolFileMessageTransformer:
                     logger.exception(e)
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.TEXT,
-                        message=f"Failed to download image: {message.message}, you can try to download it yourself.",
+                        message=ToolInvokeMessage.TextMessage(
+                            text=f"Failed to download image: {message.message}, you can try to download it yourself."
+                        ),
                         meta=message.meta.copy() if message.meta is not None else {},
                         save_as=message.save_as,
                     )
             elif message.type == ToolInvokeMessage.MessageType.BLOB:
                 # get mime type and save blob to storage
+                assert message.meta
+
                 mimetype = message.meta.get('mime_type', 'octet/stream')
                 # if message is str, encode it to bytes
-                if isinstance(message.message, str):
-                    message.message = message.message.encode('utf-8')
+
+                if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
+                    raise ValueError("unexpected message type")
 
                 file = ToolFileManager.create_file_by_raw(
                     user_id=user_id, tenant_id=tenant_id,
                     conversation_id=conversation_id,
-                    file_binary=message.message,
+                    file_binary=message.message.blob,
                     mimetype=mimetype
                 )
 
-                url = cls.get_tool_file_url(file.id, guess_extension(file.mimetype))
+                extension = guess_extension(file.mimetype) or ".bin"
+                url = cls.get_tool_file_url(file.id, extension)
 
                 # check if file is image
                 if 'image' in mimetype:
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                        message=url,
+                        message=ToolInvokeMessage.TextMessage(text=url),
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
                 else:
                     yield ToolInvokeMessage(
                         type=ToolInvokeMessage.MessageType.LINK,
-                        message=url,
+                        message=ToolInvokeMessage.TextMessage(text=url),
                         save_as=message.save_as,
                         meta=message.meta.copy() if message.meta is not None else {},
                     )
             elif message.type == ToolInvokeMessage.MessageType.FILE_VAR:
-                file_var = message.meta.get('file_var')
+                assert message.meta
+
+                file_var: FileVar | None = message.meta.get('file_var')
                 if file_var:
                     if file_var.transfer_method == FileTransferMethod.TOOL_FILE:
+                        assert file_var.related_id and file_var.extension
+
                         url = cls.get_tool_file_url(file_var.related_id, file_var.extension)
                         if file_var.type == FileType.IMAGE:
                             yield ToolInvokeMessage(
                                 type=ToolInvokeMessage.MessageType.IMAGE_LINK,
-                                message=url,
+                                message=ToolInvokeMessage.TextMessage(text=url),
                                 save_as=message.save_as,
                                 meta=message.meta.copy() if message.meta is not None else {},
                             )
                         else:
                             yield ToolInvokeMessage(
                                 type=ToolInvokeMessage.MessageType.LINK,
-                                message=url,
+                                message=ToolInvokeMessage.TextMessage(text=url),
                                 save_as=message.save_as,
                                 meta=message.meta.copy() if message.meta is not None else {},
                             )

+ 2 - 2
api/core/workflow/nodes/tool/tool_node.py

@@ -1,4 +1,4 @@
-from collections.abc import Mapping, Sequence
+from collections.abc import Generator, Mapping, Sequence
 from os import path
 from typing import Any, cast
 
@@ -145,7 +145,7 @@ class ToolNode(BaseNode):
         assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
         return list(variable.value) if variable else []
 
-    def _convert_tool_messages(self, messages: list[ToolInvokeMessage]):
+    def _convert_tool_messages(self, messages: Generator[ToolInvokeMessage, None, None]):
         """
         Convert ToolInvokeMessages into tuple[plain_text, files]
         """

+ 10 - 8
api/models/tools.py

@@ -1,6 +1,7 @@
 import json
 
 from sqlalchemy import ForeignKey
+from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
 
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
@@ -277,7 +278,7 @@ class ToolConversationVariables(db.Model):
     def variables(self) -> dict:
         return json.loads(self.variables_str)
     
-class ToolFile(db.Model):
+class ToolFile(DeclarativeBase):
     """
     store the file created by agent
     """
@@ -288,16 +289,17 @@ class ToolFile(db.Model):
         db.Index('tool_file_conversation_id_idx', 'conversation_id'),
     )
 
-    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
     # conversation user id
-    user_id = db.Column(StringUUID, nullable=False)
+    user_id: Mapped[str] = mapped_column(StringUUID)
     # tenant id
-    tenant_id = db.Column(StringUUID, nullable=False)
+    tenant_id: Mapped[StringUUID] = mapped_column(StringUUID)
     # conversation id
-    conversation_id = db.Column(StringUUID, nullable=True)
+    conversation_id: Mapped[StringUUID] = mapped_column(nullable=True)
     # file key
-    file_key = db.Column(db.String(255), nullable=False)
+    file_key: Mapped[str] = mapped_column(db.String(255), nullable=False)
     # mime type
-    mimetype = db.Column(db.String(255), nullable=False)
+    mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
     # original url
-    original_url = db.Column(db.String(2048), nullable=True)
+    original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
+