tool_manager.py 25 KB


  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Union, cast
  8. if TYPE_CHECKING:
  9. from core.workflow.nodes.tool.entities import ToolEntity
  10. from configs import dify_config
  11. from core.agent.entities import AgentToolEntity
  12. from core.app.entities.app_invoke_entities import InvokeFrom
  13. from core.helper.module_import_helper import load_single_subclass_from_source
  14. from core.helper.position_helper import is_filtered
  15. from core.model_runtime.utils.encoders import jsonable_encoder
  16. from core.tools.__base.tool import Tool
  17. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  18. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  19. from core.tools.builtin_tool.tool import BuiltinTool
  20. from core.tools.custom_tool.provider import ApiToolProviderController
  21. from core.tools.custom_tool.tool import ApiTool
  22. from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
  23. from core.tools.entities.common_entities import I18nObject
  24. from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
  25. from core.tools.errors import ToolProviderNotFoundError
  26. from core.tools.tool_label_manager import ToolLabelManager
  27. from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
  28. from core.tools.utils.tool_parameter_converter import ToolParameterConverter
  29. from core.tools.workflow_as_tool.tool import WorkflowTool
  30. from extensions.ext_database import db
  31. from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
  32. from services.tools.tools_transform_service import ToolTransformService
  33. logger = logging.getLogger(__name__)
  34. class ToolManager:
  35. _builtin_provider_lock = Lock()
  36. _builtin_providers = {}
  37. _builtin_providers_loaded = False
  38. _builtin_tools_labels = {}
  39. @classmethod
  40. def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
  41. """
  42. get the builtin provider
  43. :param provider: the name of the provider
  44. :return: the provider
  45. """
  46. if len(cls._builtin_providers) == 0:
  47. # init the builtin providers
  48. cls.load_builtin_providers_cache()
  49. if provider not in cls._builtin_providers:
  50. raise ToolProviderNotFoundError(f"builtin provider {provider} not found")
  51. return cls._builtin_providers[provider]
  52. @classmethod
  53. def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None:
  54. """
  55. get the builtin tool
  56. :param provider: the name of the provider
  57. :param tool_name: the name of the tool
  58. :return: the provider, the tool
  59. """
  60. provider_controller = cls.get_builtin_provider(provider)
  61. tool = provider_controller.get_tool(tool_name)
  62. return tool
  63. @classmethod
  64. def get_tool_runtime(
  65. cls,
  66. provider_type: ToolProviderType,
  67. provider_id: str,
  68. tool_name: str,
  69. tenant_id: str,
  70. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  71. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  72. ) -> Union[BuiltinTool, ApiTool, WorkflowTool]:
  73. """
  74. get the tool runtime
  75. :param provider_type: the type of the provider
  76. :param provider_name: the name of the provider
  77. :param tool_name: the name of the tool
  78. :return: the tool
  79. """
  80. if provider_type == ToolProviderType.BUILT_IN:
  81. builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
  82. if not builtin_tool:
  83. raise ValueError(f"tool {tool_name} not found")
  84. # check if the builtin tool need credentials
  85. provider_controller = cls.get_builtin_provider(provider_id)
  86. if not provider_controller.need_credentials:
  87. return cast(
  88. BuiltinTool,
  89. builtin_tool.fork_tool_runtime(
  90. runtime={
  91. "tenant_id": tenant_id,
  92. "credentials": {},
  93. "invoke_from": invoke_from,
  94. "tool_invoke_from": tool_invoke_from,
  95. }
  96. ),
  97. )
  98. # get credentials
  99. builtin_provider: BuiltinToolProvider | None = (
  100. db.session.query(BuiltinToolProvider)
  101. .filter(
  102. BuiltinToolProvider.tenant_id == tenant_id,
  103. BuiltinToolProvider.provider == provider_id,
  104. )
  105. .first()
  106. )
  107. if builtin_provider is None:
  108. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  109. # decrypt the credentials
  110. credentials = builtin_provider.credentials
  111. controller = cls.get_builtin_provider(provider_id)
  112. tool_configuration = ProviderConfigEncrypter(
  113. tenant_id=tenant_id,
  114. config=controller.get_credentials_schema(),
  115. provider_type=controller.provider_type.value,
  116. provider_identity=controller.identity.name,
  117. )
  118. decrypted_credentials = tool_configuration.decrypt(credentials)
  119. return cast(
  120. BuiltinTool,
  121. builtin_tool.fork_tool_runtime(
  122. runtime={
  123. "tenant_id": tenant_id,
  124. "credentials": decrypted_credentials,
  125. "runtime_parameters": {},
  126. "invoke_from": invoke_from,
  127. "tool_invoke_from": tool_invoke_from,
  128. }
  129. ),
  130. )
  131. elif provider_type == ToolProviderType.API:
  132. if tenant_id is None:
  133. raise ValueError("tenant id is required for api provider")
  134. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  135. # decrypt the credentials
  136. tool_configuration = ProviderConfigEncrypter(
  137. tenant_id=tenant_id,
  138. config=api_provider.get_credentials_schema(),
  139. provider_type=api_provider.provider_type.value,
  140. provider_identity=api_provider.identity.name,
  141. )
  142. decrypted_credentials = tool_configuration.decrypt(credentials)
  143. return cast(
  144. ApiTool,
  145. api_provider.get_tool(tool_name).fork_tool_runtime(
  146. runtime={
  147. "tenant_id": tenant_id,
  148. "credentials": decrypted_credentials,
  149. "invoke_from": invoke_from,
  150. "tool_invoke_from": tool_invoke_from,
  151. }
  152. ),
  153. )
  154. elif provider_type == ToolProviderType.WORKFLOW:
  155. workflow_provider = (
  156. db.session.query(WorkflowToolProvider)
  157. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  158. .first()
  159. )
  160. if workflow_provider is None:
  161. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  162. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  163. return cast(
  164. WorkflowTool,
  165. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  166. runtime={
  167. "tenant_id": tenant_id,
  168. "credentials": {},
  169. "invoke_from": invoke_from,
  170. "tool_invoke_from": tool_invoke_from,
  171. }
  172. ),
  173. )
  174. elif provider_type == ToolProviderType.APP:
  175. raise NotImplementedError("app provider not implemented")
  176. else:
  177. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  178. @classmethod
  179. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
  180. """
  181. init runtime parameter
  182. """
  183. parameter_value = parameters.get(parameter_rule.name)
  184. if not parameter_value and parameter_value != 0:
  185. # get default value
  186. parameter_value = parameter_rule.default
  187. if not parameter_value and parameter_rule.required:
  188. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  189. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  190. # check if tool_parameter_config in options
  191. options = [x.value for x in parameter_rule.options]
  192. if parameter_value is not None and parameter_value not in options:
  193. raise ValueError(
  194. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
  195. )
  196. return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
  197. @classmethod
  198. def get_agent_tool_runtime(
  199. cls, tenant_id: str, app_id: str, agent_tool: AgentToolEntity, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER
  200. ) -> Tool:
  201. """
  202. get the agent tool runtime
  203. """
  204. tool_entity = cls.get_tool_runtime(
  205. provider_type=agent_tool.provider_type,
  206. provider_id=agent_tool.provider_id,
  207. tool_name=agent_tool.tool_name,
  208. tenant_id=tenant_id,
  209. invoke_from=invoke_from,
  210. tool_invoke_from=ToolInvokeFrom.AGENT,
  211. )
  212. runtime_parameters = {}
  213. parameters = tool_entity.get_all_runtime_parameters()
  214. for parameter in parameters:
  215. # check file types
  216. if parameter.type == ToolParameter.ToolParameterType.FILE:
  217. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  218. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  219. # save tool parameter to tool entity memory
  220. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  221. runtime_parameters[parameter.name] = value
  222. # decrypt runtime parameters
  223. encryption_manager = ToolParameterConfigurationManager(
  224. tenant_id=tenant_id,
  225. tool_runtime=tool_entity,
  226. provider_name=agent_tool.provider_id,
  227. provider_type=agent_tool.provider_type,
  228. identity_id=f"AGENT.{app_id}",
  229. )
  230. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  231. if not tool_entity.runtime:
  232. raise Exception("tool missing runtime")
  233. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  234. return tool_entity
  235. @classmethod
  236. def get_workflow_tool_runtime(
  237. cls,
  238. tenant_id: str,
  239. app_id: str,
  240. node_id: str,
  241. workflow_tool: "ToolEntity",
  242. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  243. ) -> Tool:
  244. """
  245. get the workflow tool runtime
  246. """
  247. tool_entity = cls.get_tool_runtime(
  248. provider_type=workflow_tool.provider_type,
  249. provider_id=workflow_tool.provider_id,
  250. tool_name=workflow_tool.tool_name,
  251. tenant_id=tenant_id,
  252. invoke_from=invoke_from,
  253. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  254. )
  255. runtime_parameters = {}
  256. parameters = tool_entity.get_all_runtime_parameters()
  257. for parameter in parameters:
  258. # save tool parameter to tool entity memory
  259. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  260. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  261. runtime_parameters[parameter.name] = value
  262. # decrypt runtime parameters
  263. encryption_manager = ToolParameterConfigurationManager(
  264. tenant_id=tenant_id,
  265. tool_runtime=tool_entity,
  266. provider_name=workflow_tool.provider_id,
  267. provider_type=workflow_tool.provider_type,
  268. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  269. )
  270. if runtime_parameters:
  271. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  272. if not tool_entity.runtime:
  273. raise Exception("tool missing runtime")
  274. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  275. return tool_entity
  276. @classmethod
  277. def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
  278. """
  279. get the absolute path of the icon of the builtin provider
  280. :param provider: the name of the provider
  281. :return: the absolute path of the icon, the mime type of the icon
  282. """
  283. # get provider
  284. provider_controller = cls.get_builtin_provider(provider)
  285. absolute_path = path.join(
  286. path.dirname(path.realpath(__file__)),
  287. "builtin_tool",
  288. "providers",
  289. provider,
  290. "_assets",
  291. provider_controller.identity.icon,
  292. )
  293. # check if the icon exists
  294. if not path.exists(absolute_path):
  295. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  296. # get the mime type
  297. mime_type, _ = mimetypes.guess_type(absolute_path)
  298. mime_type = mime_type or "application/octet-stream"
  299. return absolute_path, mime_type
  300. @classmethod
  301. def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  302. # use cache first
  303. if cls._builtin_providers_loaded:
  304. yield from list(cls._builtin_providers.values())
  305. return
  306. with cls._builtin_provider_lock:
  307. if cls._builtin_providers_loaded:
  308. yield from list(cls._builtin_providers.values())
  309. return
  310. yield from cls._list_builtin_providers()
  311. @classmethod
  312. def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  313. """
  314. list all the builtin providers
  315. """
  316. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  317. if provider_path.startswith("__"):
  318. continue
  319. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  320. if provider_path.startswith("__"):
  321. continue
  322. # init provider
  323. try:
  324. provider_class = load_single_subclass_from_source(
  325. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  326. script_path=path.join(
  327. path.dirname(path.realpath(__file__)),
  328. "builtin_tool",
  329. "providers",
  330. provider_path,
  331. f"{provider_path}.py",
  332. ),
  333. parent_type=BuiltinToolProviderController,
  334. )
  335. provider: BuiltinToolProviderController = provider_class()
  336. cls._builtin_providers[provider.identity.name] = provider
  337. for tool in provider.get_tools():
  338. cls._builtin_tools_labels[tool.identity.name] = tool.identity.label
  339. yield provider
  340. except Exception as e:
  341. logger.error(f"load builtin provider error: {e}")
  342. continue
  343. # set builtin providers loaded
  344. cls._builtin_providers_loaded = True
  345. @classmethod
  346. def load_builtin_providers_cache(cls):
  347. for _ in cls.list_builtin_providers():
  348. pass
  349. @classmethod
  350. def clear_builtin_providers_cache(cls):
  351. cls._builtin_providers = {}
  352. cls._builtin_providers_loaded = False
  353. @classmethod
  354. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  355. """
  356. get the tool label
  357. :param tool_name: the name of the tool
  358. :return: the label of the tool
  359. """
  360. if len(cls._builtin_tools_labels) == 0:
  361. # init the builtin providers
  362. cls.load_builtin_providers_cache()
  363. if tool_name not in cls._builtin_tools_labels:
  364. return None
  365. return cls._builtin_tools_labels[tool_name]
  366. @classmethod
  367. def user_list_providers(
  368. cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral
  369. ) -> list[UserToolProvider]:
  370. result_providers: dict[str, UserToolProvider] = {}
  371. filters = []
  372. if not typ:
  373. filters.extend(["builtin", "api", "workflow"])
  374. else:
  375. filters.append(typ)
  376. if "builtin" in filters:
  377. # get builtin providers
  378. builtin_providers = cls.list_builtin_providers()
  379. # get db builtin providers
  380. db_builtin_providers: list[BuiltinToolProvider] = (
  381. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  382. )
  383. find_db_builtin_provider = lambda provider: next(
  384. (x for x in db_builtin_providers if x.provider == provider), None
  385. )
  386. # append builtin providers
  387. for provider in builtin_providers:
  388. # handle include, exclude
  389. if is_filtered(
  390. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  391. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  392. data=provider,
  393. name_func=lambda x: x.identity.name,
  394. ):
  395. continue
  396. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  397. provider_controller=provider,
  398. db_provider=find_db_builtin_provider(provider.identity.name),
  399. decrypt_credentials=False,
  400. )
  401. result_providers[provider.identity.name] = user_provider
  402. # get db api providers
  403. if "api" in filters:
  404. db_api_providers: list[ApiToolProvider] = (
  405. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  406. )
  407. api_provider_controllers = [
  408. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  409. for provider in db_api_providers
  410. ]
  411. # get labels
  412. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  413. for api_provider_controller in api_provider_controllers:
  414. user_provider = ToolTransformService.api_provider_to_user_provider(
  415. provider_controller=api_provider_controller["controller"],
  416. db_provider=api_provider_controller["provider"],
  417. decrypt_credentials=False,
  418. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  419. )
  420. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  421. if "workflow" in filters:
  422. # get workflow providers
  423. workflow_providers: list[WorkflowToolProvider] = (
  424. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  425. )
  426. workflow_provider_controllers = []
  427. for provider in workflow_providers:
  428. try:
  429. workflow_provider_controllers.append(
  430. ToolTransformService.workflow_provider_to_controller(db_provider=provider)
  431. )
  432. except Exception as e:
  433. # app has been deleted
  434. pass
  435. labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
  436. for provider_controller in workflow_provider_controllers:
  437. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  438. provider_controller=provider_controller,
  439. labels=labels.get(provider_controller.provider_id, []),
  440. )
  441. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  442. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  443. @classmethod
  444. def get_api_provider_controller(
  445. cls, tenant_id: str, provider_id: str
  446. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  447. """
  448. get the api provider
  449. :param provider_name: the name of the provider
  450. :return: the provider controller, the credentials
  451. """
  452. provider: ApiToolProvider | None = (
  453. db.session.query(ApiToolProvider)
  454. .filter(
  455. ApiToolProvider.id == provider_id,
  456. ApiToolProvider.tenant_id == tenant_id,
  457. )
  458. .first()
  459. )
  460. if provider is None:
  461. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  462. controller = ApiToolProviderController.from_db(
  463. provider,
  464. ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  465. )
  466. controller.load_bundled_tools(provider.tools)
  467. return controller, provider.credentials
  468. @classmethod
  469. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  470. """
  471. get api provider
  472. """
  473. """
  474. get tool provider
  475. """
  476. provider_obj: ApiToolProvider | None = (
  477. db.session.query(ApiToolProvider)
  478. .filter(
  479. ApiToolProvider.tenant_id == tenant_id,
  480. ApiToolProvider.name == provider,
  481. )
  482. .first()
  483. )
  484. if provider_obj is None:
  485. raise ValueError(f"you have not added provider {provider}")
  486. try:
  487. credentials = json.loads(provider_obj.credentials_str) or {}
  488. except:
  489. credentials = {}
  490. # package tool provider controller
  491. controller = ApiToolProviderController.from_db(
  492. provider_obj,
  493. ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  494. )
  495. # init tool configuration
  496. tool_configuration = ProviderConfigEncrypter(
  497. tenant_id=tenant_id,
  498. config=controller.get_credentials_schema(),
  499. provider_type=controller.provider_type.value,
  500. provider_identity=controller.identity.name,
  501. )
  502. decrypted_credentials = tool_configuration.decrypt(credentials)
  503. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  504. try:
  505. icon = json.loads(provider_obj.icon)
  506. except:
  507. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  508. # add tool labels
  509. labels = ToolLabelManager.get_tool_labels(controller)
  510. return jsonable_encoder(
  511. {
  512. "schema_type": provider_obj.schema_type,
  513. "schema": provider_obj.schema,
  514. "tools": provider_obj.tools,
  515. "icon": icon,
  516. "description": provider_obj.description,
  517. "credentials": masked_credentials,
  518. "privacy_policy": provider_obj.privacy_policy,
  519. "custom_disclaimer": provider_obj.custom_disclaimer,
  520. "labels": labels,
  521. }
  522. )
  523. @classmethod
  524. def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
  525. """
  526. get the tool icon
  527. :param tenant_id: the id of the tenant
  528. :param provider_type: the type of the provider
  529. :param provider_id: the id of the provider
  530. :return:
  531. """
  532. provider_type = provider_type
  533. provider_id = provider_id
  534. if provider_type == ToolProviderType.BUILT_IN:
  535. return (
  536. dify_config.CONSOLE_API_URL
  537. + "/console/api/workspaces/current/tool-provider/builtin/"
  538. + provider_id
  539. + "/icon"
  540. )
  541. elif provider_type == ToolProviderType.API:
  542. try:
  543. api_provider: ApiToolProvider | None = (
  544. db.session.query(ApiToolProvider)
  545. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  546. .first()
  547. )
  548. if not api_provider:
  549. raise ValueError("api tool not found")
  550. return json.loads(api_provider.icon)
  551. except:
  552. return {"background": "#252525", "content": "\ud83d\ude01"}
  553. elif provider_type == ToolProviderType.WORKFLOW:
  554. workflow_provider: WorkflowToolProvider | None = (
  555. db.session.query(WorkflowToolProvider)
  556. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  557. .first()
  558. )
  559. if workflow_provider is None:
  560. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  561. return json.loads(workflow_provider.icon)
  562. else:
  563. raise ValueError(f"provider type {provider_type} not found")
# Import-time side effect: populate the builtin provider cache (providers and tool labels)
# as soon as this module is imported, before any lookup is made.
ToolManager.load_builtin_providers_cache()