tool_manager.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import Any, Union
  8. from flask import current_app
  9. from core.agent.entities import AgentToolEntity
  10. from core.model_runtime.utils.encoders import jsonable_encoder
  11. from core.provider_manager import ProviderManager
  12. from core.tools import *
  13. from core.tools.entities.common_entities import I18nObject
  14. from core.tools.entities.tool_entities import (
  15. ApiProviderAuthType,
  16. ToolParameter,
  17. )
  18. from core.tools.entities.user_entities import UserToolProvider
  19. from core.tools.errors import ToolProviderNotFoundError
  20. from core.tools.provider.api_tool_provider import ApiBasedToolProviderController
  21. from core.tools.provider.builtin._positions import BuiltinToolProviderSort
  22. from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
  23. from core.tools.provider.model_tool_provider import ModelToolProviderController
  24. from core.tools.tool.api_tool import ApiTool
  25. from core.tools.tool.builtin_tool import BuiltinTool
  26. from core.tools.tool.tool import Tool
  27. from core.tools.utils.configuration import (
  28. ToolConfigurationManager,
  29. ToolParameterConfigurationManager,
  30. )
  31. from core.utils.module_import_helper import load_single_subclass_from_source
  32. from core.workflow.nodes.tool.entities import ToolEntity
  33. from extensions.ext_database import db
  34. from models.tools import ApiToolProvider, BuiltinToolProvider
  35. from services.tools_transform_service import ToolTransformService
  36. logger = logging.getLogger(__name__)
  37. class ToolManager:
  38. _builtin_provider_lock = Lock()
  39. _builtin_providers = {}
  40. _builtin_providers_loaded = False
  41. _builtin_tools_labels = {}
  42. @classmethod
  43. def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
  44. """
  45. get the builtin provider
  46. :param provider: the name of the provider
  47. :return: the provider
  48. """
  49. if len(cls._builtin_providers) == 0:
  50. # init the builtin providers
  51. cls.load_builtin_providers_cache()
  52. if provider not in cls._builtin_providers:
  53. raise ToolProviderNotFoundError(f'builtin provider {provider} not found')
  54. return cls._builtin_providers[provider]
  55. @classmethod
  56. def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
  57. """
  58. get the builtin tool
  59. :param provider: the name of the provider
  60. :param tool_name: the name of the tool
  61. :return: the provider, the tool
  62. """
  63. provider_controller = cls.get_builtin_provider(provider)
  64. tool = provider_controller.get_tool(tool_name)
  65. return tool
  66. @classmethod
  67. def get_tool(cls, provider_type: str, provider_id: str, tool_name: str, tenant_id: str = None) \
  68. -> Union[BuiltinTool, ApiTool]:
  69. """
  70. get the tool
  71. :param provider_type: the type of the provider
  72. :param provider_name: the name of the provider
  73. :param tool_name: the name of the tool
  74. :return: the tool
  75. """
  76. if provider_type == 'builtin':
  77. return cls.get_builtin_tool(provider_id, tool_name)
  78. elif provider_type == 'api':
  79. if tenant_id is None:
  80. raise ValueError('tenant id is required for api provider')
  81. api_provider, _ = cls.get_api_provider_controller(tenant_id, provider_id)
  82. return api_provider.get_tool(tool_name)
  83. elif provider_type == 'app':
  84. raise NotImplementedError('app provider not implemented')
  85. else:
  86. raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  87. @classmethod
  88. def get_tool_runtime(cls, provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
  89. -> Union[BuiltinTool, ApiTool]:
  90. """
  91. get the tool runtime
  92. :param provider_type: the type of the provider
  93. :param provider_name: the name of the provider
  94. :param tool_name: the name of the tool
  95. :return: the tool
  96. """
  97. if provider_type == 'builtin':
  98. builtin_tool = cls.get_builtin_tool(provider_name, tool_name)
  99. # check if the builtin tool need credentials
  100. provider_controller = cls.get_builtin_provider(provider_name)
  101. if not provider_controller.need_credentials:
  102. return builtin_tool.fork_tool_runtime(meta={
  103. 'tenant_id': tenant_id,
  104. 'credentials': {},
  105. })
  106. # get credentials
  107. builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
  108. BuiltinToolProvider.tenant_id == tenant_id,
  109. BuiltinToolProvider.provider == provider_name,
  110. ).first()
  111. if builtin_provider is None:
  112. raise ToolProviderNotFoundError(f'builtin provider {provider_name} not found')
  113. # decrypt the credentials
  114. credentials = builtin_provider.credentials
  115. controller = cls.get_builtin_provider(provider_name)
  116. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
  117. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  118. return builtin_tool.fork_tool_runtime(meta={
  119. 'tenant_id': tenant_id,
  120. 'credentials': decrypted_credentials,
  121. 'runtime_parameters': {}
  122. })
  123. elif provider_type == 'api':
  124. if tenant_id is None:
  125. raise ValueError('tenant id is required for api provider')
  126. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_name)
  127. # decrypt the credentials
  128. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
  129. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  130. return api_provider.get_tool(tool_name).fork_tool_runtime(meta={
  131. 'tenant_id': tenant_id,
  132. 'credentials': decrypted_credentials,
  133. })
  134. elif provider_type == 'model':
  135. if tenant_id is None:
  136. raise ValueError('tenant id is required for model provider')
  137. # get model provider
  138. model_provider = cls.get_model_provider(tenant_id, provider_name)
  139. # get tool
  140. model_tool = model_provider.get_tool(tool_name)
  141. return model_tool.fork_tool_runtime(meta={
  142. 'tenant_id': tenant_id,
  143. 'credentials': model_tool.model_configuration['model_instance'].credentials
  144. })
  145. elif provider_type == 'app':
  146. raise NotImplementedError('app provider not implemented')
  147. else:
  148. raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
  149. @classmethod
  150. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
  151. """
  152. init runtime parameter
  153. """
  154. parameter_value = parameters.get(parameter_rule.name)
  155. if not parameter_value:
  156. # get default value
  157. parameter_value = parameter_rule.default
  158. if not parameter_value and parameter_rule.required:
  159. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  160. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  161. # check if tool_parameter_config in options
  162. options = list(map(lambda x: x.value, parameter_rule.options))
  163. if parameter_value not in options:
  164. raise ValueError(
  165. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}")
  166. # convert tool parameter config to correct type
  167. try:
  168. if parameter_rule.type == ToolParameter.ToolParameterType.NUMBER:
  169. # check if tool parameter is integer
  170. if isinstance(parameter_value, int):
  171. parameter_value = parameter_value
  172. elif isinstance(parameter_value, float):
  173. parameter_value = parameter_value
  174. elif isinstance(parameter_value, str):
  175. if '.' in parameter_value:
  176. parameter_value = float(parameter_value)
  177. else:
  178. parameter_value = int(parameter_value)
  179. elif parameter_rule.type == ToolParameter.ToolParameterType.BOOLEAN:
  180. parameter_value = bool(parameter_value)
  181. elif parameter_rule.type not in [ToolParameter.ToolParameterType.SELECT,
  182. ToolParameter.ToolParameterType.STRING]:
  183. parameter_value = str(parameter_value)
  184. elif parameter_rule.type == ToolParameter.ToolParameterType:
  185. parameter_value = str(parameter_value)
  186. except Exception as e:
  187. raise ValueError(f"tool parameter {parameter_rule.name} value {parameter_value} is not correct type")
  188. return parameter_value
  189. @classmethod
  190. def get_agent_tool_runtime(cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity) -> Tool:
  191. """
  192. get the agent tool runtime
  193. """
  194. tool_entity = cls.get_tool_runtime(
  195. provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id,
  196. tool_name=agent_tool.tool_name,
  197. tenant_id=tenant_id,
  198. )
  199. runtime_parameters = {}
  200. parameters = tool_entity.get_all_runtime_parameters()
  201. for parameter in parameters:
  202. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  203. # save tool parameter to tool entity memory
  204. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  205. runtime_parameters[parameter.name] = value
  206. # decrypt runtime parameters
  207. encryption_manager = ToolParameterConfigurationManager(
  208. tenant_id=tenant_id,
  209. tool_runtime=tool_entity,
  210. provider_name=agent_tool.provider_id,
  211. provider_type=agent_tool.provider_type,
  212. identity_id=f'AGENT.{app_id}'
  213. )
  214. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  215. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  216. return tool_entity
  217. @classmethod
  218. def get_workflow_tool_runtime(cls, tenant_id: str, app_id: str, node_id: str, workflow_tool: ToolEntity):
  219. """
  220. get the workflow tool runtime
  221. """
  222. tool_entity = cls.get_tool_runtime(
  223. provider_type=workflow_tool.provider_type,
  224. provider_name=workflow_tool.provider_id,
  225. tool_name=workflow_tool.tool_name,
  226. tenant_id=tenant_id,
  227. )
  228. runtime_parameters = {}
  229. parameters = tool_entity.get_all_runtime_parameters()
  230. for parameter in parameters:
  231. # save tool parameter to tool entity memory
  232. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  233. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  234. runtime_parameters[parameter.name] = value
  235. # decrypt runtime parameters
  236. encryption_manager = ToolParameterConfigurationManager(
  237. tenant_id=tenant_id,
  238. tool_runtime=tool_entity,
  239. provider_name=workflow_tool.provider_id,
  240. provider_type=workflow_tool.provider_type,
  241. identity_id=f'WORKFLOW.{app_id}.{node_id}'
  242. )
  243. if runtime_parameters:
  244. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  245. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  246. return tool_entity
  247. @classmethod
  248. def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
  249. """
  250. get the absolute path of the icon of the builtin provider
  251. :param provider: the name of the provider
  252. :return: the absolute path of the icon, the mime type of the icon
  253. """
  254. # get provider
  255. provider_controller = cls.get_builtin_provider(provider)
  256. absolute_path = path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider, '_assets',
  257. provider_controller.identity.icon)
  258. # check if the icon exists
  259. if not path.exists(absolute_path):
  260. raise ToolProviderNotFoundError(f'builtin provider {provider} icon not found')
  261. # get the mime type
  262. mime_type, _ = mimetypes.guess_type(absolute_path)
  263. mime_type = mime_type or 'application/octet-stream'
  264. return absolute_path, mime_type
  265. @classmethod
  266. def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  267. # use cache first
  268. if cls._builtin_providers_loaded:
  269. yield from list(cls._builtin_providers.values())
  270. return
  271. with cls._builtin_provider_lock:
  272. if cls._builtin_providers_loaded:
  273. yield from list(cls._builtin_providers.values())
  274. return
  275. yield from cls._list_builtin_providers()
  276. @classmethod
  277. def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  278. """
  279. list all the builtin providers
  280. """
  281. for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
  282. if provider.startswith('__'):
  283. continue
  284. if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)):
  285. if provider.startswith('__'):
  286. continue
  287. # init provider
  288. try:
  289. provider_class = load_single_subclass_from_source(
  290. module_name=f'core.tools.provider.builtin.{provider}.{provider}',
  291. script_path=path.join(path.dirname(path.realpath(__file__)),
  292. 'provider', 'builtin', provider, f'{provider}.py'),
  293. parent_type=BuiltinToolProviderController)
  294. provider: BuiltinToolProviderController = provider_class()
  295. cls._builtin_providers[provider.identity.name] = provider
  296. for tool in provider.get_tools():
  297. cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
  298. yield provider
  299. except Exception as e:
  300. logger.error(f'load builtin provider {provider} error: {e}')
  301. continue
  302. # set builtin providers loaded
  303. cls._builtin_providers_loaded = True
  304. @classmethod
  305. def load_builtin_providers_cache(cls):
  306. for _ in cls.list_builtin_providers():
  307. pass
  308. @classmethod
  309. def clear_builtin_providers_cache(cls):
  310. cls._builtin_providers = {}
  311. cls._builtin_providers_loaded = False
  312. # @classmethod
  313. # def list_model_providers(cls, tenant_id: str = None) -> list[ModelToolProviderController]:
  314. # """
  315. # list all the model providers
  316. # :return: the list of the model providers
  317. # """
  318. # tenant_id = tenant_id or 'ffffffff-ffff-ffff-ffff-ffffffffffff'
  319. # # get configurations
  320. # model_configurations = ModelToolConfigurationManager.get_all_configuration()
  321. # # get all providers
  322. # provider_manager = ProviderManager()
  323. # configurations = provider_manager.get_configurations(tenant_id).values()
  324. # # get model providers
  325. # model_providers: list[ModelToolProviderController] = []
  326. # for configuration in configurations:
  327. # # all the model tool should be configurated
  328. # if configuration.provider.provider not in model_configurations:
  329. # continue
  330. # if not ModelToolProviderController.is_configuration_valid(configuration):
  331. # continue
  332. # model_providers.append(ModelToolProviderController.from_db(configuration))
  333. # return model_providers
  334. @classmethod
  335. def get_model_provider(cls, tenant_id: str, provider_name: str) -> ModelToolProviderController:
  336. """
  337. get the model provider
  338. :param provider_name: the name of the provider
  339. :return: the provider
  340. """
  341. # get configurations
  342. provider_manager = ProviderManager()
  343. configurations = provider_manager.get_configurations(tenant_id)
  344. configuration = configurations.get(provider_name)
  345. if configuration is None:
  346. raise ToolProviderNotFoundError(f'model provider {provider_name} not found')
  347. return ModelToolProviderController.from_db(configuration)
  348. @classmethod
  349. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  350. """
  351. get the tool label
  352. :param tool_name: the name of the tool
  353. :return: the label of the tool
  354. """
  355. cls._builtin_tools_labels
  356. if len(cls._builtin_tools_labels) == 0:
  357. # init the builtin providers
  358. cls.load_builtin_providers_cache()
  359. if tool_name not in cls._builtin_tools_labels:
  360. return None
  361. return cls._builtin_tools_labels[tool_name]
  362. @classmethod
  363. def user_list_providers(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
  364. result_providers: dict[str, UserToolProvider] = {}
  365. # get builtin providers
  366. builtin_providers = cls.list_builtin_providers()
  367. # get db builtin providers
  368. db_builtin_providers: list[BuiltinToolProvider] = db.session.query(BuiltinToolProvider). \
  369. filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  370. find_db_builtin_provider = lambda provider: next(
  371. (x for x in db_builtin_providers if x.provider == provider),
  372. None
  373. )
  374. # append builtin providers
  375. for provider in builtin_providers:
  376. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  377. provider_controller=provider,
  378. db_provider=find_db_builtin_provider(provider.identity.name),
  379. decrypt_credentials=False
  380. )
  381. result_providers[provider.identity.name] = user_provider
  382. # # get model tool providers
  383. # model_providers = cls.list_model_providers(tenant_id=tenant_id)
  384. # # append model providers
  385. # for provider in model_providers:
  386. # user_provider = ToolTransformService.model_provider_to_user_provider(
  387. # db_provider=provider,
  388. # )
  389. # result_providers[f'model_provider.{provider.identity.name}'] = user_provider
  390. # get db api providers
  391. db_api_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider). \
  392. filter(ApiToolProvider.tenant_id == tenant_id).all()
  393. for db_api_provider in db_api_providers:
  394. provider_controller = ToolTransformService.api_provider_to_controller(
  395. db_provider=db_api_provider,
  396. )
  397. user_provider = ToolTransformService.api_provider_to_user_provider(
  398. provider_controller=provider_controller,
  399. db_provider=db_api_provider,
  400. decrypt_credentials=False
  401. )
  402. result_providers[db_api_provider.name] = user_provider
  403. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  404. @classmethod
  405. def get_api_provider_controller(cls, tenant_id: str, provider_id: str) -> tuple[
  406. ApiBasedToolProviderController, dict[str, Any]]:
  407. """
  408. get the api provider
  409. :param provider_name: the name of the provider
  410. :return: the provider controller, the credentials
  411. """
  412. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  413. ApiToolProvider.id == provider_id,
  414. ApiToolProvider.tenant_id == tenant_id,
  415. ).first()
  416. if provider is None:
  417. raise ToolProviderNotFoundError(f'api provider {provider_id} not found')
  418. controller = ApiBasedToolProviderController.from_db(
  419. provider,
  420. ApiProviderAuthType.API_KEY if provider.credentials['auth_type'] == 'api_key' else
  421. ApiProviderAuthType.NONE
  422. )
  423. controller.load_bundled_tools(provider.tools)
  424. return controller, provider.credentials
  425. @classmethod
  426. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  427. """
  428. get api provider
  429. """
  430. """
  431. get tool provider
  432. """
  433. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  434. ApiToolProvider.tenant_id == tenant_id,
  435. ApiToolProvider.name == provider,
  436. ).first()
  437. if provider is None:
  438. raise ValueError(f'you have not added provider {provider}')
  439. try:
  440. credentials = json.loads(provider.credentials_str) or {}
  441. except:
  442. credentials = {}
  443. # package tool provider controller
  444. controller = ApiBasedToolProviderController.from_db(
  445. provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
  446. )
  447. # init tool configuration
  448. tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
  449. decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
  450. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  451. try:
  452. icon = json.loads(provider.icon)
  453. except:
  454. icon = {
  455. "background": "#252525",
  456. "content": "\ud83d\ude01"
  457. }
  458. return jsonable_encoder({
  459. 'schema_type': provider.schema_type,
  460. 'schema': provider.schema,
  461. 'tools': provider.tools,
  462. 'icon': icon,
  463. 'description': provider.description,
  464. 'credentials': masked_credentials,
  465. 'privacy_policy': provider.privacy_policy
  466. })
  467. @classmethod
  468. def get_tool_icon(cls, tenant_id: str, provider_type: str, provider_id: str) -> Union[str, dict]:
  469. """
  470. get the tool icon
  471. :param tenant_id: the id of the tenant
  472. :param provider_type: the type of the provider
  473. :param provider_id: the id of the provider
  474. :return:
  475. """
  476. provider_type = provider_type
  477. provider_id = provider_id
  478. if provider_type == 'builtin':
  479. return (current_app.config.get("CONSOLE_API_URL")
  480. + "/console/api/workspaces/current/tool-provider/builtin/"
  481. + provider_id
  482. + "/icon")
  483. elif provider_type == 'api':
  484. try:
  485. provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
  486. ApiToolProvider.tenant_id == tenant_id,
  487. ApiToolProvider.id == provider_id
  488. )
  489. return json.loads(provider.icon)
  490. except:
  491. return {
  492. "background": "#252525",
  493. "content": "\ud83d\ude01"
  494. }
  495. else:
  496. raise ValueError(f"provider type {provider_type} not found")
# Warm the builtin provider cache when this module is imported. The scan of the builtin
# provider directory (loading each provider via load_single_subclass_from_source) then
# runs once up front, instead of lazily on the first call that needs a builtin provider.
ToolManager.load_builtin_providers_cache()