tool.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. from abc import ABC, abstractmethod
  2. from collections.abc import Generator
  3. from copy import deepcopy
  4. from typing import TYPE_CHECKING, Any, Optional
  5. from core.tools.__base.tool_runtime import ToolRuntime
  6. from core.tools.entities.tool_entities import (
  7. ToolEntity,
  8. ToolInvokeMessage,
  9. ToolParameter,
  10. ToolProviderType,
  11. )
  12. from core.tools.utils.tool_parameter_converter import ToolParameterConverter
  13. if TYPE_CHECKING:
  14. from core.file.file_obj import FileVar
  class Tool(ABC):
      """
      The base class of a tool.

      A tool pairs its description (``entity``) with the runtime it executes in
      (``runtime``). Subclasses implement ``tool_provider_type`` and ``_invoke``;
      callers use ``invoke``.
      """

      entity: ToolEntity
      runtime: ToolRuntime

      def __init__(self, entity: ToolEntity, runtime: ToolRuntime) -> None:
          self.entity = entity
          self.runtime = runtime

      def fork_tool_runtime(self, runtime: ToolRuntime) -> "Tool":
          """
          fork a new tool of the same class, bound to another runtime

          :param runtime: the runtime of the tool call processing, tenant_id is required
          :return: the new tool
          """
          # NOTE(review): assuming ToolEntity is a pydantic model, model_copy()
          # without deep=True is shallow, so nested objects such as the
          # parameter list stay shared with this tool. Never mutate them in place.
          return self.__class__(
              entity=self.entity.model_copy(),
              runtime=runtime,
          )

      @abstractmethod
      def tool_provider_type(self) -> ToolProviderType:
          """
          get the tool provider type

          :return: the tool provider type
          """

      def invoke(
          self, user_id: str, tool_parameters: dict[str, Any]
      ) -> Generator[ToolInvokeMessage, None, None]:
          """
          invoke the tool and normalize its output into a generator of messages

          :param user_id: the id of the user invoking the tool
          :param tool_parameters: the call parameters. NOTE: runtime parameters
              are merged into this dict in place, overriding same-named keys.
          :return: a generator yielding the tool's messages
          """
          if self.runtime and self.runtime.runtime_parameters:
              tool_parameters.update(self.runtime.runtime_parameters)

          # try parse tool parameters into the correct type
          tool_parameters = self._transform_tool_parameters_type(tool_parameters)

          result = self._invoke(
              user_id=user_id,
              tool_parameters=tool_parameters,
          )

          # _invoke may return a single message, a list of messages, or a
          # generator; wrap the first two so callers always get a generator.
          if isinstance(result, ToolInvokeMessage):

              def single_generator():
                  yield result

              return single_generator()
          elif isinstance(result, list):

              def generator():
                  yield from result

              return generator()
          else:
              return result

      def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
          """
          cast the values of declared parameters to their declared types

          The input dict is left untouched; a deep copy is returned. Keys that
          do not match a declared parameter are copied unchanged.

          :param tool_parameters: the raw call parameters
          :return: a new dict with declared parameters cast by type
          """
          # Temp fix for the issue that the tool parameters will be converted to empty while validating the credentials
          result = deepcopy(tool_parameters)
          for parameter in self.entity.parameters:
              if parameter.name in tool_parameters:
                  result[parameter.name] = ToolParameterConverter.cast_parameter_by_type(
                      tool_parameters[parameter.name], parameter.type
                  )
          return result

      @abstractmethod
      def _invoke(
          self, user_id: str, tool_parameters: dict[str, Any]
      ) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
          """
          run the tool; implemented by subclasses

          :param user_id: the id of the user invoking the tool
          :param tool_parameters: the type-cast call parameters
          :return: a single message, a list of messages, or a generator of messages
          """
          pass

      def get_runtime_parameters(self) -> list[ToolParameter]:
          """
          get the runtime parameters
          interface for developer to dynamic change the parameters of a tool depends on the variables pool

          :return: the runtime parameters (by default, the entity's declared parameters)
          """
          return self.entity.parameters

      def get_merged_runtime_parameters(self) -> list[ToolParameter]:
          """
          get merged runtime parameters

          Starts from a copy of the entity's declared parameters and applies the
          result of ``get_runtime_parameters``. A runtime parameter whose name
          matches a declared one overrides that parameter's type, form,
          required, default, options and llm_description. Any other runtime
          parameter is appended.

          :return: merged runtime parameters
          """
          # Deep-copy so the overrides below do not mutate the ToolParameter
          # objects owned by self.entity. Those objects may also be shared with
          # forked tools, because fork_tool_runtime uses a shallow model_copy().
          parameters = deepcopy(self.entity.parameters)

          # Index by name, keeping the first occurrence (as a linear search would).
          parameters_by_name: dict[str, ToolParameter] = {}
          for parameter in parameters:
              parameters_by_name.setdefault(parameter.name, parameter)

          for runtime_parameter in self.get_runtime_parameters() or []:
              tool_parameter = parameters_by_name.get(runtime_parameter.name)
              if tool_parameter is None:
                  # add new parameter
                  parameters.append(runtime_parameter)
                  parameters_by_name[runtime_parameter.name] = runtime_parameter
                  continue

              # override parameter
              tool_parameter.type = runtime_parameter.type
              tool_parameter.form = runtime_parameter.form
              tool_parameter.required = runtime_parameter.required
              tool_parameter.default = runtime_parameter.default
              tool_parameter.options = runtime_parameter.options
              tool_parameter.llm_description = runtime_parameter.llm_description

          return parameters

      def create_image_message(self, image: str, save_as: str = "") -> ToolInvokeMessage:
          """
          create an image message

          :param image: the url of the image
          :param save_as: forwarded as the message's save_as field
          :return: the image message
          """
          return ToolInvokeMessage(
              type=ToolInvokeMessage.MessageType.IMAGE, message=ToolInvokeMessage.TextMessage(text=image), save_as=save_as
          )

      def create_file_var_message(self, file_var: "FileVar") -> ToolInvokeMessage:
          """
          create a file variable message

          The file variable is carried in ``meta["file_var"]``; the message body is None.

          :param file_var: the file variable
          :return: the file variable message
          """
          return ToolInvokeMessage(
              type=ToolInvokeMessage.MessageType.FILE_VAR, message=None, meta={"file_var": file_var}, save_as=""
          )

      def create_link_message(self, link: str, save_as: str = "") -> ToolInvokeMessage:
          """
          create a link message

          :param link: the url of the link
          :param save_as: forwarded as the message's save_as field
          :return: the link message
          """
          return ToolInvokeMessage(
              type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=link), save_as=save_as
          )

      def create_text_message(self, text: str, save_as: str = "") -> ToolInvokeMessage:
          """
          create a text message

          :param text: the text
          :param save_as: forwarded as the message's save_as field
          :return: the text message
          """
          return ToolInvokeMessage(
              type=ToolInvokeMessage.MessageType.TEXT, message=ToolInvokeMessage.TextMessage(text=text), save_as=save_as
          )

      def create_blob_message(self, blob: bytes, meta: Optional[dict] = None, save_as: str = "") -> ToolInvokeMessage:
          """
          create a blob message

          :param blob: the blob
          :param meta: optional metadata attached to the message
          :param save_as: forwarded as the message's save_as field
          :return: the blob message
          """
          return ToolInvokeMessage(
              type=ToolInvokeMessage.MessageType.BLOB,
              message=ToolInvokeMessage.BlobMessage(blob=blob),
              meta=meta,
              save_as=save_as,
          )

      def create_json_message(self, object: dict) -> ToolInvokeMessage:
          """
          create a json message

          :param object: the json object. The name shadows the builtin but is
              kept so keyword-argument callers keep working.
          :return: the json message
          """
          return ToolInvokeMessage(
              type=ToolInvokeMessage.MessageType.JSON, message=ToolInvokeMessage.JsonMessage(json_object=object)
          )