tool.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355
  1. from abc import ABC, abstractmethod
  2. from enum import Enum
  3. from typing import Any, Optional, Union
  4. from pydantic import BaseModel
  5. from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
  6. from core.tools.entities.tool_entities import (
  7. ToolDescription,
  8. ToolIdentity,
  9. ToolInvokeMessage,
  10. ToolParameter,
  11. ToolRuntimeImageVariable,
  12. ToolRuntimeVariable,
  13. ToolRuntimeVariablePool,
  14. )
  15. from core.tools.tool_file_manager import ToolFileManager
  16. class Tool(BaseModel, ABC):
  17. identity: ToolIdentity = None
  18. parameters: Optional[list[ToolParameter]] = None
  19. description: ToolDescription = None
  20. is_team_authorization: bool = False
  21. agent_callback: Optional[DifyAgentCallbackHandler] = None
  22. use_callback: bool = False
  23. class Runtime(BaseModel):
  24. """
  25. Meta data of a tool call processing
  26. """
  27. def __init__(self, **data: Any):
  28. super().__init__(**data)
  29. if not self.runtime_parameters:
  30. self.runtime_parameters = {}
  31. tenant_id: str = None
  32. tool_id: str = None
  33. credentials: dict[str, Any] = None
  34. runtime_parameters: dict[str, Any] = None
  35. runtime: Runtime = None
  36. variables: ToolRuntimeVariablePool = None
  37. def __init__(self, **data: Any):
  38. super().__init__(**data)
  39. if not self.agent_callback:
  40. self.use_callback = False
  41. else:
  42. self.use_callback = True
  43. class VARIABLE_KEY(Enum):
  44. IMAGE = 'image'
  45. def fork_tool_runtime(self, meta: dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
  46. """
  47. fork a new tool with meta data
  48. :param meta: the meta data of a tool call processing, tenant_id is required
  49. :return: the new tool
  50. """
  51. return self.__class__(
  52. identity=self.identity.copy() if self.identity else None,
  53. parameters=self.parameters.copy() if self.parameters else None,
  54. description=self.description.copy() if self.description else None,
  55. runtime=Tool.Runtime(**meta),
  56. agent_callback=agent_callback
  57. )
  58. def load_variables(self, variables: ToolRuntimeVariablePool):
  59. """
  60. load variables from database
  61. :param conversation_id: the conversation id
  62. """
  63. self.variables = variables
  64. def set_image_variable(self, variable_name: str, image_key: str) -> None:
  65. """
  66. set an image variable
  67. """
  68. if not self.variables:
  69. return
  70. self.variables.set_file(self.identity.name, variable_name, image_key)
  71. def set_text_variable(self, variable_name: str, text: str) -> None:
  72. """
  73. set a text variable
  74. """
  75. if not self.variables:
  76. return
  77. self.variables.set_text(self.identity.name, variable_name, text)
  78. def get_variable(self, name: Union[str, Enum]) -> Optional[ToolRuntimeVariable]:
  79. """
  80. get a variable
  81. :param name: the name of the variable
  82. :return: the variable
  83. """
  84. if not self.variables:
  85. return None
  86. if isinstance(name, Enum):
  87. name = name.value
  88. for variable in self.variables.pool:
  89. if variable.name == name:
  90. return variable
  91. return None
  92. def get_default_image_variable(self) -> Optional[ToolRuntimeVariable]:
  93. """
  94. get the default image variable
  95. :return: the image variable
  96. """
  97. if not self.variables:
  98. return None
  99. return self.get_variable(self.VARIABLE_KEY.IMAGE)
  100. def get_variable_file(self, name: Union[str, Enum]) -> Optional[bytes]:
  101. """
  102. get a variable file
  103. :param name: the name of the variable
  104. :return: the variable file
  105. """
  106. variable = self.get_variable(name)
  107. if not variable:
  108. return None
  109. if not isinstance(variable, ToolRuntimeImageVariable):
  110. return None
  111. message_file_id = variable.value
  112. # get file binary
  113. file_binary = ToolFileManager.get_file_binary_by_message_file_id(message_file_id)
  114. if not file_binary:
  115. return None
  116. return file_binary[0]
  117. def list_variables(self) -> list[ToolRuntimeVariable]:
  118. """
  119. list all variables
  120. :return: the variables
  121. """
  122. if not self.variables:
  123. return []
  124. return self.variables.pool
  125. def list_default_image_variables(self) -> list[ToolRuntimeVariable]:
  126. """
  127. list all image variables
  128. :return: the image variables
  129. """
  130. if not self.variables:
  131. return []
  132. result = []
  133. for variable in self.variables.pool:
  134. if variable.name.startswith(self.VARIABLE_KEY.IMAGE.value):
  135. result.append(variable)
  136. return result
  137. def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
  138. # check if tool_parameters is a string
  139. if isinstance(tool_parameters, str):
  140. # check if this tool has only one parameter
  141. parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
  142. if parameters and len(parameters) == 1:
  143. tool_parameters = {
  144. parameters[0].name: tool_parameters
  145. }
  146. else:
  147. raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
  148. # update tool_parameters
  149. if self.runtime.runtime_parameters:
  150. tool_parameters.update(self.runtime.runtime_parameters)
  151. # hit callback
  152. if self.use_callback:
  153. self.agent_callback.on_tool_start(
  154. tool_name=self.identity.name,
  155. tool_inputs=tool_parameters
  156. )
  157. try:
  158. result = self._invoke(
  159. user_id=user_id,
  160. tool_parameters=tool_parameters,
  161. )
  162. except Exception as e:
  163. if self.use_callback:
  164. self.agent_callback.on_tool_error(e)
  165. raise e
  166. if not isinstance(result, list):
  167. result = [result]
  168. # hit callback
  169. if self.use_callback:
  170. self.agent_callback.on_tool_end(
  171. tool_name=self.identity.name,
  172. tool_inputs=tool_parameters,
  173. tool_outputs=self._convert_tool_response_to_str(result)
  174. )
  175. return result
  176. def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
  177. """
  178. Handle tool response
  179. """
  180. result = ''
  181. for response in tool_response:
  182. if response.type == ToolInvokeMessage.MessageType.TEXT:
  183. result += response.message
  184. elif response.type == ToolInvokeMessage.MessageType.LINK:
  185. result += f"result link: {response.message}. please tell user to check it."
  186. elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
  187. response.type == ToolInvokeMessage.MessageType.IMAGE:
  188. result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
  189. elif response.type == ToolInvokeMessage.MessageType.BLOB:
  190. if len(response.message) > 114:
  191. result += str(response.message[:114]) + '...'
  192. else:
  193. result += str(response.message)
  194. else:
  195. result += f"tool response: {response.message}."
  196. return result
  197. @abstractmethod
  198. def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]:
  199. pass
  200. def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
  201. """
  202. validate the credentials
  203. :param credentials: the credentials
  204. :param parameters: the parameters
  205. """
  206. pass
  207. def get_runtime_parameters(self) -> list[ToolParameter]:
  208. """
  209. get the runtime parameters
  210. interface for developer to dynamic change the parameters of a tool depends on the variables pool
  211. :return: the runtime parameters
  212. """
  213. return self.parameters
  214. def get_all_runtime_parameters(self) -> list[ToolParameter]:
  215. """
  216. get all runtime parameters
  217. :return: all runtime parameters
  218. """
  219. parameters = self.parameters or []
  220. parameters = parameters.copy()
  221. user_parameters = self.get_runtime_parameters() or []
  222. user_parameters = user_parameters.copy()
  223. # override parameters
  224. for parameter in user_parameters:
  225. # check if parameter in tool parameters
  226. found = False
  227. for tool_parameter in parameters:
  228. if tool_parameter.name == parameter.name:
  229. found = True
  230. break
  231. if found:
  232. # override parameter
  233. tool_parameter.type = parameter.type
  234. tool_parameter.form = parameter.form
  235. tool_parameter.required = parameter.required
  236. tool_parameter.default = parameter.default
  237. tool_parameter.options = parameter.options
  238. tool_parameter.llm_description = parameter.llm_description
  239. else:
  240. # add new parameter
  241. parameters.append(parameter)
  242. return parameters
  243. def is_tool_available(self) -> bool:
  244. """
  245. check if the tool is available
  246. :return: if the tool is available
  247. """
  248. return True
  249. def create_image_message(self, image: str, save_as: str = '') -> ToolInvokeMessage:
  250. """
  251. create an image message
  252. :param image: the url of the image
  253. :return: the image message
  254. """
  255. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.IMAGE,
  256. message=image,
  257. save_as=save_as)
  258. def create_link_message(self, link: str, save_as: str = '') -> ToolInvokeMessage:
  259. """
  260. create a link message
  261. :param link: the url of the link
  262. :return: the link message
  263. """
  264. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.LINK,
  265. message=link,
  266. save_as=save_as)
  267. def create_text_message(self, text: str, save_as: str = '') -> ToolInvokeMessage:
  268. """
  269. create a text message
  270. :param text: the text
  271. :return: the text message
  272. """
  273. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.TEXT,
  274. message=text,
  275. save_as=save_as
  276. )
  277. def create_blob_message(self, blob: bytes, meta: dict = None, save_as: str = '') -> ToolInvokeMessage:
  278. """
  279. create a blob message
  280. :param blob: the blob
  281. :return: the blob message
  282. """
  283. return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB,
  284. message=blob, meta=meta,
  285. save_as=save_as
  286. )