tool_manager.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import Any, Union
  8. from flask import current_app
  9. from core.agent.entities import AgentToolEntity
  10. from core.model_runtime.utils.encoders import jsonable_encoder
  11. from core.tools import *
  12. from core.tools.entities.common_entities import I18nObject
  13. from core.tools.entities.tool_entities import (
  14. ApiProviderAuthType,
  15. ToolParameter,
  16. )
  17. from core.tools.entities.user_entities import UserToolProvider
  18. from core.tools.errors import ToolProviderNotFoundError
  19. from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
  20. from core.tools.provider.builtin._positions import BuiltinToolProviderSort
  21. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  22. from core.tools.tool.api_tool import ApiTool
  23. from core.tools.tool.builtin_tool import BuiltinTool
  24. from core.tools.tool.tool import Tool
  25. from core.tools.utils.configuration import (
  26. ToolConfigurationManager,
  27. ToolParameterConfigurationManager,
  28. )
  29. from core.utils.module_import_helper import load_single_subclass_from_source
  30. from core.workflow.nodes.tool.entities import ToolEntity
  31. from extensions.ext_database import db
  32. from models.tools import ApiToolProvider, BuiltinToolProvider
  33. from services.tools_transform_service import ToolTransformService
  34. logger = logging.getLogger(__name__)
  35. class ToolManager:
  36. _builtin_provider_lock = Lock()
  37. _builtin_providers = {}
  38. _builtin_providers_loaded = False
  39. _builtin_tools_labels = {}
  40. @classmethod
  41. def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
  42. """
  43. get the builtin provider
  44. :param provider: the name of the provider
  45. :return: the provider
  46. """
  47. if len(cls._builtin_providers) == 0:
  48. # init the builtin providers
  49. cls.load_builtin_providers_cache()
  50. if provider not in cls._builtin_providers:
  51. raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
  52. return cls._builtin_providers[provider]
  53. @classmethod
  54. def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
  55. """
  56. get the builtin tool
  57. :param provider: the name of the provider
  58. :param tool_name: the name of the tool
  59. :return: the provider, the tool
  60. """
  61. provider_controller = cls.get_builtin_provider(provider)
  62. tool = provider_controller.get_tool(tool_name)
  63. return tool
  64. @classmethod
  65. def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
  66. -> Union[BuiltinTool, ApiTool]:
  67. """
  68. get the tool
  69. :param provider_type: the type of the provider
  70. :param provider_name: the name of the provider
  71. :param tool_name: the name of the tool
  72. :return: the tool
  73. """
  74. if provider_type == 'builtin':
  75. return cls.get_builtin_tool(provider_id, tool_name)
  76. elif provider_type == 'api':
  77. if tenant_id is None:
  78. raise ValueError('tenant id is required for api provider')
  79. api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
  80. return api_provider.get_tool(tool_name)
  81. elif provider_type == 'app':
  82. raise NotImplementedError('app provider not implemented')
  83. else:
  84. raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  85. @classmethod
  86. def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
  87. -> Union[BuiltinTool, ApiTool]:
  88. """
  89. get the tool runtime
  90. :param provider_type: the type of the provider
  91. :param provider_name: the name of the provider
  92. :param tool_name: the name of the tool
  93. :return: the tool
  94. """
  95. if provider_type == 'builtin':
  96. builtin_tool = cls.get_builtin_tool(provider_name, tool_name)
  97. # check if the builtin tool need credentials
  98. provider_controller = cls.get_builtin_provider(provider_name)
  99. if not provider_controller.need_credentials:
  100. return builtin_tool.fork_tool_runtime(meta={
  101. 'tenant_id': tenant_id,
  102. 'credentials': {},
  103. })
  104. # get credentials
  105. builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  106. BuiltinToolProvider.tenant_id == tenant_id,
  107. BuiltinToolProvider.provider == provider_name,
  108. ).first()
  109. if builtin_provider is None:
  110. raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
  111. # decrypt the credentials
  112. credentials = builtin_provider.credentials
  113. controller = cls.get_builtin_provider(provider_name)
  114. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
  115. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  116. return builtin_tool.fork_tool_runtime(meta={
  117. 'tenant_id': tenant_id,
  118. 'credentials': decrypted_credentials,
  119. 'runtime_parameters': {}
  120. })
  121. elif provider_type == 'api':
  122. if tenant_id is None:
  123. raise ValueError('tenant id is required for api provider')
  124. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name)
  125. # decrypt the credentials
  126. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
  127. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  128. return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
  129. 'tenant_id': tenant_id,
  130. 'credentials': decrypted_credentials,
  131. })
  132. elif provider_type == 'app':
  133. raise NotImplementedError('app provider not implemented')
  134. else:
  135. raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  136. @classmethod
  137. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
  138. """
  139. init runtime parameter
  140. """
  141. parameter_value = parameters.get(parameter_rule.name)
  142. if not parameter_value:
  143. # get default value
  144. parameter_value = parameter_rule.default
  145. if not parameter_value and parameter_rule.required:
  146. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  147. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  148. # check if tool_parameter_config in options
  149. options = list(map(lambda x: x.value, parameter_rule.options))
  150. if parameter_value not in options:
  151. raise ValueError(
  152. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
  153. # convert tool parameter config to correct type
  154. try:
  155. if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER:
  156. # check if tool parameter is integer
  157. if isinstance(parameter_value, int):
  158. parameter_value = parameter_value
  159. elif isinstance(parameter_value, float):
  160. parameter_value = parameter_value
  161. elif isinstance(parameter_value, str):
  162. if '.' in parameter_value:
  163. parameter_value = float(parameter_value)
  164. else:
  165. parameter_value = int(parameter_value)
  166. elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
  167. parameter_value = bool(parameter_value)
  168. elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT,
  169. ToolParameter.ToolParameterType.STRING]:
  170. parameter_value = str(parameter_value)
  171. elif parameter_rule.type == ToolParameter.ToolParameterType:
  172. parameter_value = str(parameter_value)
  173. except Exception as e:
  174. raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type")
  175. return parameter_value
  176. @classmethod
  177. def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
  178. """
  179. get the agent tool runtime
  180. """
  181. tool_entity = cls.get_tool_runtime(
  182. provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
  183. tool_name=agent_tool.tool_name,
  184. tenant_id=tenant_id,
  185. )
  186. runtime_parameters = {}
  187. parameters = tool_entity.get_all_runtime_parameters()
  188. for parameter in parameters:
  189. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  190. # save tool parameter to tool entity memory
  191. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  192. runtime_parameters[parameter.name] = value
  193. # decrypt runtime parameters
  194. encryption_manager = ToolParameterConfigurationManager(
  195. tenant_id=tenant_id,
  196. tool_runtime=tool_entity,
  197. provider_name=agent_tool.provider_id,
  198. provider_type=agent_tool.provider_type,
  199. identity_id=f'AGENT.{app_id}'
  200. )
  201. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  202. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  203. return tool_entity
  204. @classmethod
  205. def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
  206. """
  207. get the workflow tool runtime
  208. """
  209. tool_entity = cls.get_tool_runtime(
  210. provider_type=workflow_tool.provider_type,
  211. provider_name=workflow_tool.provider_id,
  212. tool_name=workflow_tool.tool_name,
  213. tenant_id=tenant_id,
  214. )
  215. runtime_parameters = {}
  216. parameters = tool_entity.get_all_runtime_parameters()
  217. for parameter in parameters:
  218. # save tool parameter to tool entity memory
  219. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  220. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  221. runtime_parameters[parameter.name] = value
  222. # decrypt runtime parameters
  223. encryption_manager = ToolParameterConfigurationManager(
  224. tenant_id=tenant_id,
  225. tool_runtime=tool_entity,
  226. provider_name=workflow_tool.provider_id,
  227. provider_type=workflow_tool.provider_type,
  228. identity_id=f'WORKFLOW.{app_id}.{node_id}'
  229. )
  230. if runtime_parameters:
  231. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  232. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  233. return tool_entity
  234. @classmethod
  235. def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
  236. """
  237. get the absolute path of the icon of the builtin provider
  238. :param provider: the name of the provider
  239. :return: the absolute path of the icon, the mime type of the icon
  240. """
  241. # get provider
  242. provider_controller = cls.get_builtin_provider(provider)
  243. absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets',
  244. provider_controller.identity.icon)
  245. # check if the icon exists
  246. if not path.exists(absolute_path):
  247. raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')
  248. # get the mime type
  249. mime_type, _ = mimetypes.guess_type(absolute_path)
  250. mime_type = mime_type or 'application/octet-stream'
  251. return absolute_path, mime_type
  252. @classmethod
  253. def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  254. # use cache first
  255. if cls._builtin_providers_loaded:
  256. yield from list(cls._builtin_providers.values())
  257. return
  258. with cls._builtin_provider_lock:
  259. if cls._builtin_providers_loaded:
  260. yield from list(cls._builtin_providers.values())
  261. return
  262. yield from cls._list_builtin_providers()
  263. @classmethod
  264. def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  265. """
  266. list all the builtin providers
  267. """
  268. for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
  269. if provider.startswith('__'):
  270. continue
  271. if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)):
  272. if provider.startswith('__'):
  273. continue
  274. # init provider
  275. try:
  276. provider_class = load_single_subclass_from_source(
  277. module_name=f'core.tools.provider.builtin.{provider}.{provider}',
  278. script_path=path.join(path.dirname(path.realpath(__file__)),
  279. 'provider', 'builtin', provider, f'{provider}.py'),
  280. parent_type=BuiltinToolProviderController)
  281. provider: BuiltinToolProviderController = provider_class()
  282. cls._builtin_providers[provider.identity.name] = provider
  283. for tool in provider.get_tools():
  284. cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
  285. yield provider
  286. except Exception as e:
  287. logger.error(f'load builtin provider {provider} error: {e}')
  288. continue
  289. # set builtin providers loaded
  290. cls._builtin_providers_loaded = True
  291. @classmethod
  292. def load_builtin_providers_cache(cls):
  293. for _ in cls.list_builtin_providers():
  294. pass
  295. @classmethod
  296. def clear_builtin_providers_cache(cls):
  297. cls._builtin_providers = {}
  298. cls._builtin_providers_loaded = False
  299. @classmethod
  300. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  301. """
  302. get the tool label
  303. :param tool_name: the name of the tool
  304. :return: the label of the tool
  305. """
  306. cls._builtin_tools_labels
  307. if len(cls._builtin_tools_labels) == 0:
  308. # init the builtin providers
  309. cls.load_builtin_providers_cache()
  310. if tool_name not in cls._builtin_tools_labels:
  311. return None
  312. return cls._builtin_tools_labels[tool_name]
  313. @classmethod
  314. def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
  315. result_providers: dict[str, UserToolProvider] = {}
  316. # get builtin providers
  317. builtin_providers = cls.list_builtin_providers()
  318. # get db builtin providers
  319. db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
  320. filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  321. find_db_builtin_provider = lambda provider: next(
  322. (x for x in db_builtin_providers if x.provider == provider),
  323. None
  324. )
  325. # append builtin providers
  326. for provider in builtin_providers:
  327. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  328. provider_controller=provider,
  329. db_provider=find_db_builtin_provider(provider.identity.name),
  330. decrypt_credentials=False
  331. )
  332. result_providers[provider.identity.name] = user_provider
  333. # get db api providers
  334. db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
  335. filter(ApiToolProvider.tenant_id == tenant_id).all()
  336. for db_api_provider in db_api_providers:
  337. provider_controller = ToolTransformService.api_provider_to_controller(
  338. db_provider=db_api_provider,
  339. )
  340. user_provider = ToolTransformService.api_provider_to_user_provider(
  341. provider_controller=provider_controller,
  342. db_provider=db_api_provider,
  343. decrypt_credentials=False
  344. )
  345. result_providers[db_api_provider.name] = user_provider
  346. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  347. @classmethod
  348. def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
  349. ApiBasedToolProviderController, dict[str, Any]]:
  350. """
  351. get the api provider
  352. :param provider_name: the name of the provider
  353. :return: the provider controller, the credentials
  354. """
  355. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  356. ApiToolProvider.id == provider_id,
  357. ApiToolProvider.tenant_id == tenant_id,
  358. ).first()
  359. if provider is None:
  360. raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
  361. controller = ApiBasedToolProviderController.from_db(
  362. provider,
  363. ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else
  364. ApiProviderAuthType.NONE
  365. )
  366. controller.load_bundled_tools(provider.tools)
  367. return controller, provider.credentials
  368. @classmethod
  369. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  370. """
  371. get api provider
  372. """
  373. """
  374. get tool provider
  375. """
  376. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  377. ApiToolProvider.tenant_id == tenant_id,
  378. ApiToolProvider.name == provider,
  379. ).first()
  380. if provider is None:
  381. raise ValueError(f'you have not added provider {provider}')
  382. try:
  383. credentials = json.loads(provider.credentials_str) or {}
  384. except:
  385. credentials = {}
  386. # package tool provider controller
  387. controller = ApiBasedToolProviderController.from_db(
  388. provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
  389. )
  390. # init tool configuration
  391. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
  392. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  393. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  394. try:
  395. icon = json.loads(provider.icon)
  396. except:
  397. icon = {
  398. "background": "#252525",
  399. "content": "\ud83d\ude01"
  400. }
  401. return jsonable_encoder({
  402. 'schema_type': provider.schema_type,
  403. 'schema': provider.schema,
  404. 'tools': provider.tools,
  405. 'icon': icon,
  406. 'description': provider.description,
  407. 'credentials': masked_credentials,
  408. 'privacy_policy': provider.privacy_policy
  409. })
  410. @classmethod
  411. def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
  412. """
  413. get the tool icon
  414. :param tenant_id: the id of the tenant
  415. :param provider_type: the type of the provider
  416. :param provider_id: the id of the provider
  417. :return:
  418. """
  419. provider_type = provider_type
  420. provider_id = provider_id
  421. if provider_type == 'builtin':
  422. return (current_app.config.get("CONSOLE_API_URL")
  423. + "/console/api/workspaces/current/tool-provider/builtin/"
  424. + provider_id
  425. + "/icon")
  426. elif provider_type == 'api':
  427. try:
  428. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  429. ApiToolProvider.tenant_id == tenant_id,
  430. ApiToolProvider.id == provider_id
  431. )
  432. return json.loads(provider.icon)
  433. except:
  434. return {
  435. "background": "#252525",
  436. "content": "\ud83d\ude01"
  437. }
  438. else:
  439. raise ValueError(f"provider type {provider_type} not found")
  440. ToolManager.load_builtin_providers_cache()