瀏覽代碼

fix: tool type

Yeuoly 10 月之前
父節點
當前提交
c8b0160ea9
共有 1 個文件被更改,包括 19 次插入13 次删除
  1. 19 13
      api/core/tools/tool/tool.py

+ 19 - 13
api/core/tools/tool/tool.py

@@ -4,7 +4,7 @@ from copy import deepcopy
 from enum import Enum
 from typing import TYPE_CHECKING, Any, Optional, Union
 
-from pydantic import BaseModel, ConfigDict, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator
 from pydantic_core.core_schema import ValidationInfo
 
 from core.app.entities.app_invoke_entities import InvokeFrom
@@ -27,8 +27,8 @@ if TYPE_CHECKING:
 
 
 class Tool(BaseModel, ABC):
-    identity: Optional[ToolIdentity] = None
-    parameters: Optional[list[ToolParameter]] = None
+    identity: ToolIdentity
+    parameters: list[ToolParameter] = Field(default_factory=list)
     description: Optional[ToolDescription] = None
     is_team_authorization: bool = False
 
@@ -194,10 +194,8 @@ class Tool(BaseModel, ABC):
 
         return result
 
-    def invoke(self, user_id: str, tool_parameters: Mapping[str, Any]) -> Generator[ToolInvokeMessage]:
-        # update tool_parameters
-        # TODO: Fix type error.
-        if self.runtime.runtime_parameters:
+    def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage]:
+        if self.runtime and self.runtime.runtime_parameters:
             tool_parameters.update(self.runtime.runtime_parameters)
 
         # try parse tool parameters into the correct type
@@ -210,7 +208,7 @@ class Tool(BaseModel, ABC):
 
         return result
 
-    def _transform_tool_parameters_type(self, tool_parameters: Mapping[str, Any]) -> dict[str, Any]:
+    def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
         """
         Transform tool parameters type
         """
@@ -289,7 +287,7 @@ class Tool(BaseModel, ABC):
         :return: the image message
         """
         return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
-                                 message=image,
+                                 message=ToolInvokeMessage.TextMessage(text=image),
                                  save_as=save_as)
 
     def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
@@ -308,7 +306,7 @@ class Tool(BaseModel, ABC):
         :return: the link message
         """
         return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
-                                 message=link,
+                                 message=ToolInvokeMessage.TextMessage(text=link),
                                  save_as=save_as)
 
     def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
@@ -320,7 +318,7 @@ class Tool(BaseModel, ABC):
         """
         return ToolInvokeMessage(
             type=ToolInvokeMessage.MessageType.TEXT,
-            message=text,
+            message=ToolInvokeMessage.TextMessage(text=text),
             save_as=save_as
         )
 
@@ -331,10 +329,18 @@ class Tool(BaseModel, ABC):
         :param blob: the blob
         :return: the blob message
         """
-        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, message=blob, meta=meta, save_as=save_as)
+        return ToolInvokeMessage(
+            type=ToolInvokeMessage.MessageType.BLOB, 
+            message=ToolInvokeMessage.BlobMessage(blob=blob), 
+            meta=meta, 
+            save_as=save_as
+        )
 
     def create_json_message(self, object: dict) -> ToolInvokeMessage:
         """
         create a json message
         """
-        return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.JSON, message=object)
+        return ToolInvokeMessage(
+            type=ToolInvokeMessage.MessageType.JSON, 
+            message=ToolInvokeMessage.JsonMessage(json_object=object)
+        )