tool.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. import json
  2. from collections.abc import Generator
  3. from os import getenv
  4. from typing import Any
  5. from urllib.parse import urlencode
  6. import httpx
  7. from core.helper import ssrf_proxy
  8. from core.tools.__base.tool import Tool
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.entities.tool_bundle import ApiToolBundle
  11. from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
  12. from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError
  13. API_TOOL_DEFAULT_TIMEOUT = (
  14. int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")),
  15. int(getenv("API_TOOL_DEFAULT_READ_TIMEOUT", "60")),
  16. )
  17. class ApiTool(Tool):
  18. api_bundle: ApiToolBundle
  19. """
  20. Api tool
  21. """
  22. def __init__(self, entity: ToolEntity, api_bundle: ApiToolBundle, runtime: ToolRuntime):
  23. super().__init__(entity, runtime)
  24. self.api_bundle = api_bundle
  25. def fork_tool_runtime(self, runtime: ToolRuntime):
  26. """
  27. fork a new tool with meta data
  28. :param meta: the meta data of a tool call processing, tenant_id is required
  29. :return: the new tool
  30. """
  31. return self.__class__(
  32. entity=self.entity,
  33. api_bundle=self.api_bundle.model_copy(),
  34. runtime=runtime,
  35. )
  36. def validate_credentials(
  37. self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
  38. ) -> str:
  39. """
  40. validate the credentials for Api tool
  41. """
  42. # assemble validate request and request parameters
  43. headers = self.assembling_request(parameters)
  44. if format_only:
  45. return ""
  46. response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, parameters)
  47. # validate response
  48. return self.validate_and_parse_response(response)
  49. def tool_provider_type(self) -> ToolProviderType:
  50. return ToolProviderType.API
  51. def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
  52. if self.runtime == None:
  53. raise ToolProviderCredentialValidationError("runtime not initialized")
  54. headers = {}
  55. credentials = self.runtime.credentials or {}
  56. if "auth_type" not in credentials:
  57. raise ToolProviderCredentialValidationError("Missing auth_type")
  58. if credentials["auth_type"] == "api_key":
  59. api_key_header = "api_key"
  60. if "api_key_header" in credentials:
  61. api_key_header = credentials["api_key_header"]
  62. if "api_key_value" not in credentials:
  63. raise ToolProviderCredentialValidationError("Missing api_key_value")
  64. elif not isinstance(credentials["api_key_value"], str):
  65. raise ToolProviderCredentialValidationError("api_key_value must be a string")
  66. if "api_key_header_prefix" in credentials:
  67. api_key_header_prefix = credentials["api_key_header_prefix"]
  68. if api_key_header_prefix == "basic" and credentials["api_key_value"]:
  69. credentials["api_key_value"] = f'Basic {credentials["api_key_value"]}'
  70. elif api_key_header_prefix == "bearer" and credentials["api_key_value"]:
  71. credentials["api_key_value"] = f'Bearer {credentials["api_key_value"]}'
  72. elif api_key_header_prefix == "custom":
  73. pass
  74. headers[api_key_header] = credentials["api_key_value"]
  75. needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
  76. for parameter in needed_parameters:
  77. if parameter.required and parameter.name not in parameters:
  78. raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
  79. if parameter.default is not None and parameter.name not in parameters:
  80. parameters[parameter.name] = parameter.default
  81. return headers
  82. def validate_and_parse_response(self, response: httpx.Response) -> str:
  83. """
  84. validate the response
  85. """
  86. if isinstance(response, httpx.Response):
  87. if response.status_code >= 400:
  88. raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
  89. if not response.content:
  90. return "Empty response from the tool, please check your parameters and try again."
  91. try:
  92. response = response.json()
  93. try:
  94. return json.dumps(response, ensure_ascii=False)
  95. except Exception as e:
  96. return json.dumps(response)
  97. except Exception as e:
  98. return response.text
  99. else:
  100. raise ValueError(f"Invalid response type {type(response)}")
  101. @staticmethod
  102. def get_parameter_value(parameter, parameters):
  103. if parameter["name"] in parameters:
  104. return parameters[parameter["name"]]
  105. elif parameter.get("required", False):
  106. raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")
  107. else:
  108. return (parameter.get("schema", {}) or {}).get("default", "")
  109. def do_http_request(
  110. self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
  111. ) -> httpx.Response:
  112. """
  113. do http request depending on api bundle
  114. """
  115. method = method.lower()
  116. params = {}
  117. path_params = {}
  118. body = {}
  119. cookies = {}
  120. # check parameters
  121. for parameter in self.api_bundle.openapi.get("parameters", []):
  122. value = self.get_parameter_value(parameter, parameters)
  123. if parameter["in"] == "path":
  124. path_params[parameter["name"]] = value
  125. elif parameter["in"] == "query":
  126. if value != "":
  127. params[parameter["name"]] = value
  128. elif parameter["in"] == "cookie":
  129. cookies[parameter["name"]] = value
  130. elif parameter["in"] == "header":
  131. headers[parameter["name"]] = value
  132. # check if there is a request body and handle it
  133. if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None:
  134. # handle json request body
  135. if "content" in self.api_bundle.openapi["requestBody"]:
  136. for content_type in self.api_bundle.openapi["requestBody"]["content"]:
  137. headers["Content-Type"] = content_type
  138. body_schema = self.api_bundle.openapi["requestBody"]["content"][content_type]["schema"]
  139. required = body_schema.get("required", [])
  140. properties = body_schema.get("properties", {})
  141. for name, property in properties.items():
  142. if name in parameters:
  143. # convert type
  144. body[name] = self._convert_body_property_type(property, parameters[name])
  145. elif name in required:
  146. raise ToolParameterValidationError(
  147. f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
  148. )
  149. elif "default" in property:
  150. body[name] = property["default"]
  151. else:
  152. body[name] = None
  153. break
  154. # replace path parameters
  155. for name, value in path_params.items():
  156. url = url.replace(f"{{{name}}}", f"{value}")
  157. # parse http body data if needed, for GET/HEAD/OPTIONS/TRACE, the body is ignored
  158. if "Content-Type" in headers:
  159. if headers["Content-Type"] == "application/json":
  160. body = json.dumps(body)
  161. elif headers["Content-Type"] == "application/x-www-form-urlencoded":
  162. body = urlencode(body)
  163. else:
  164. body = body
  165. if method in {"get", "head", "post", "put", "delete", "patch"}:
  166. response = getattr(ssrf_proxy, method)(
  167. url,
  168. params=params,
  169. headers=headers,
  170. cookies=cookies,
  171. data=body,
  172. timeout=API_TOOL_DEFAULT_TIMEOUT,
  173. follow_redirects=True,
  174. )
  175. return response
  176. else:
  177. raise ValueError(f"Invalid http method {method}")
  178. def _convert_body_property_any_of(
  179. self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
  180. ) -> Any:
  181. if max_recursive <= 0:
  182. raise Exception("Max recursion depth reached")
  183. for option in any_of or []:
  184. try:
  185. if "type" in option:
  186. # Attempt to convert the value based on the type.
  187. if option["type"] == "integer" or option["type"] == "int":
  188. return int(value)
  189. elif option["type"] == "number":
  190. if "." in str(value):
  191. return float(value)
  192. else:
  193. return int(value)
  194. elif option["type"] == "string":
  195. return str(value)
  196. elif option["type"] == "boolean":
  197. if str(value).lower() in {"true", "1"}:
  198. return True
  199. elif str(value).lower() in {"false", "0"}:
  200. return False
  201. else:
  202. continue # Not a boolean, try next option
  203. elif option["type"] == "null" and not value:
  204. return None
  205. else:
  206. continue # Unsupported type, try next option
  207. elif "anyOf" in option and isinstance(option["anyOf"], list):
  208. # Recursive call to handle nested anyOf
  209. return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
  210. except ValueError:
  211. continue # Conversion failed, try next option
  212. # If no option succeeded, you might want to return the value as is or raise an error
  213. return value # or raise ValueError(f"Cannot convert value '{value}' to any specified type in anyOf")
  214. def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
  215. try:
  216. if "type" in property:
  217. if property["type"] == "integer" or property["type"] == "int":
  218. return int(value)
  219. elif property["type"] == "number":
  220. # check if it is a float
  221. if "." in str(value):
  222. return float(value)
  223. else:
  224. return int(value)
  225. elif property["type"] == "string":
  226. return str(value)
  227. elif property["type"] == "boolean":
  228. return bool(value)
  229. elif property["type"] == "null":
  230. if value is None:
  231. return None
  232. elif property["type"] == "object" or property["type"] == "array":
  233. if isinstance(value, str):
  234. try:
  235. # an array str like '[1,2]' also can convert to list [1,2] through json.loads
  236. # json not support single quote, but we can support it
  237. value = value.replace("'", '"')
  238. return json.loads(value)
  239. except ValueError:
  240. return value
  241. elif isinstance(value, dict):
  242. return value
  243. else:
  244. return value
  245. else:
  246. raise ValueError(f"Invalid type {property['type']} for property {property}")
  247. elif "anyOf" in property and isinstance(property["anyOf"], list):
  248. return self._convert_body_property_any_of(property, value, property["anyOf"])
  249. except ValueError as e:
  250. return value
  251. def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
  252. """
  253. invoke http request
  254. """
  255. # assemble request
  256. headers = self.assembling_request(tool_parameters)
  257. # do http request
  258. response = self.do_http_request(self.api_bundle.server_url, self.api_bundle.method, headers, tool_parameters)
  259. # validate response
  260. response = self.validate_and_parse_response(response)
  261. # assemble invoke message
  262. yield self.create_text_message(response)