# tool_manager.py (30 KB)


  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Union, cast
  8. from core.plugin.manager.tool import PluginToolManager
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.plugin_tool.provider import PluginToolProviderController
  11. from core.tools.plugin_tool.tool import PluginTool
  12. if TYPE_CHECKING:
  13. from core.workflow.nodes.tool.entities import ToolEntity
  14. from configs import dify_config
  15. from core.agent.entities import AgentToolEntity
  16. from core.app.entities.app_invoke_entities import InvokeFrom
  17. from core.helper.module_import_helper import load_single_subclass_from_source
  18. from core.helper.position_helper import is_filtered
  19. from core.model_runtime.utils.encoders import jsonable_encoder
  20. from core.tools.__base.tool import Tool
  21. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  22. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  23. from core.tools.builtin_tool.tool import BuiltinTool
  24. from core.tools.custom_tool.provider import ApiToolProviderController
  25. from core.tools.custom_tool.tool import ApiTool
  26. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  27. from core.tools.entities.common_entities import I18nObject
  28. from core.tools.entities.tool_entities import (
  29. ApiProviderAuthType,
  30. ToolInvokeFrom,
  31. ToolParameter,
  32. ToolProviderID,
  33. ToolProviderType,
  34. )
  35. from core.tools.errors import ToolProviderNotFoundError
  36. from core.tools.tool_label_manager import ToolLabelManager
  37. from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
  38. from core.tools.utils.tool_parameter_converter import ToolParameterConverter
  39. from core.tools.workflow_as_tool.tool import WorkflowTool
  40. from extensions.ext_database import db
  41. from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
  42. from services.tools.tools_transform_service import ToolTransformService
  43. logger = logging.getLogger(__name__)
  44. class ToolManager:
  45. _builtin_provider_lock = Lock()
  46. _hardcoded_providers = {}
  47. _builtin_providers_loaded = False
  48. _builtin_tools_labels = {}
  49. @classmethod
  50. def get_builtin_provider(
  51. cls, provider: str, tenant_id: str
  52. ) -> BuiltinToolProviderController | PluginToolProviderController:
  53. """
  54. get the builtin provider
  55. :param provider: the name of the provider
  56. :param tenant_id: the id of the tenant
  57. :return: the provider
  58. """
  59. # split provider to
  60. if len(cls._hardcoded_providers) == 0:
  61. # init the builtin providers
  62. cls.load_hardcoded_providers_cache()
  63. if provider not in cls._hardcoded_providers:
  64. # get plugin provider
  65. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  66. if plugin_provider:
  67. return plugin_provider
  68. return cls._hardcoded_providers[provider]
  69. @classmethod
  70. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  71. """
  72. get the plugin provider
  73. """
  74. manager = PluginToolManager()
  75. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  76. if not provider_entity:
  77. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  78. return PluginToolProviderController(
  79. entity=provider_entity.declaration,
  80. tenant_id=tenant_id,
  81. )
  82. @classmethod
  83. def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
  84. """
  85. get the builtin tool
  86. :param provider: the name of the provider
  87. :param tool_name: the name of the tool
  88. :param tenant_id: the id of the tenant
  89. :return: the provider, the tool
  90. """
  91. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  92. tool = provider_controller.get_tool(tool_name)
  93. return tool
  94. @classmethod
  95. def get_tool_runtime(
  96. cls,
  97. provider_type: ToolProviderType,
  98. provider_id: str,
  99. tool_name: str,
  100. tenant_id: str,
  101. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  102. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  103. ) -> Union[BuiltinTool, ApiTool, WorkflowTool]:
  104. """
  105. get the tool runtime
  106. :param provider_type: the type of the provider
  107. :param provider_name: the name of the provider
  108. :param tool_name: the name of the tool
  109. :return: the tool
  110. """
  111. if provider_type == ToolProviderType.BUILT_IN:
  112. # check if the builtin tool need credentials
  113. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  114. builtin_tool = provider_controller.get_tool(tool_name)
  115. if not builtin_tool:
  116. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  117. if not provider_controller.need_credentials:
  118. return cast(
  119. BuiltinTool,
  120. builtin_tool.fork_tool_runtime(
  121. runtime=ToolRuntime(
  122. tenant_id=tenant_id,
  123. credentials={},
  124. invoke_from=invoke_from,
  125. tool_invoke_from=tool_invoke_from,
  126. )
  127. ),
  128. )
  129. if isinstance(provider_controller, PluginToolProviderController):
  130. provider_id_entity = ToolProviderID(provider_id)
  131. # get credentials
  132. builtin_provider: BuiltinToolProvider | None = (
  133. db.session.query(BuiltinToolProvider)
  134. .filter(
  135. BuiltinToolProvider.tenant_id == tenant_id,
  136. (BuiltinToolProvider.provider == provider_id)
  137. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  138. )
  139. .first()
  140. )
  141. if builtin_provider is None:
  142. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  143. else:
  144. builtin_provider: BuiltinToolProvider | None = (
  145. db.session.query(BuiltinToolProvider)
  146. .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  147. .first()
  148. )
  149. if builtin_provider is None:
  150. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  151. # decrypt the credentials
  152. credentials = builtin_provider.credentials
  153. tool_configuration = ProviderConfigEncrypter(
  154. tenant_id=tenant_id,
  155. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  156. provider_type=provider_controller.provider_type.value,
  157. provider_identity=provider_controller.entity.identity.name,
  158. )
  159. decrypted_credentials = tool_configuration.decrypt(credentials)
  160. return cast(
  161. BuiltinTool,
  162. builtin_tool.fork_tool_runtime(
  163. runtime=ToolRuntime(
  164. tenant_id=tenant_id,
  165. credentials=decrypted_credentials,
  166. runtime_parameters={},
  167. invoke_from=invoke_from,
  168. tool_invoke_from=tool_invoke_from,
  169. )
  170. ),
  171. )
  172. elif provider_type == ToolProviderType.API:
  173. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  174. # decrypt the credentials
  175. tool_configuration = ProviderConfigEncrypter(
  176. tenant_id=tenant_id,
  177. config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
  178. provider_type=api_provider.provider_type.value,
  179. provider_identity=api_provider.entity.identity.name,
  180. )
  181. decrypted_credentials = tool_configuration.decrypt(credentials)
  182. return cast(
  183. ApiTool,
  184. api_provider.get_tool(tool_name).fork_tool_runtime(
  185. runtime=ToolRuntime(
  186. tenant_id=tenant_id,
  187. credentials=decrypted_credentials,
  188. invoke_from=invoke_from,
  189. tool_invoke_from=tool_invoke_from,
  190. )
  191. ),
  192. )
  193. elif provider_type == ToolProviderType.WORKFLOW:
  194. workflow_provider = (
  195. db.session.query(WorkflowToolProvider)
  196. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  197. .first()
  198. )
  199. if workflow_provider is None:
  200. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  201. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  202. return cast(
  203. WorkflowTool,
  204. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  205. runtime=ToolRuntime(
  206. tenant_id=tenant_id,
  207. credentials={},
  208. invoke_from=invoke_from,
  209. tool_invoke_from=tool_invoke_from,
  210. )
  211. ),
  212. )
  213. elif provider_type == ToolProviderType.APP:
  214. raise NotImplementedError("app provider not implemented")
  215. else:
  216. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  217. @classmethod
  218. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
  219. """
  220. init runtime parameter
  221. """
  222. parameter_value = parameters.get(parameter_rule.name)
  223. if not parameter_value and parameter_value != 0:
  224. # get default value
  225. parameter_value = parameter_rule.default
  226. if not parameter_value and parameter_rule.required:
  227. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  228. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  229. # check if tool_parameter_config in options
  230. options = [x.value for x in parameter_rule.options]
  231. if parameter_value is not None and parameter_value not in options:
  232. raise ValueError(
  233. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
  234. )
  235. return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
  236. @classmethod
  237. def get_agent_tool_runtime(
  238. cls,
  239. tenant_id: str,
  240. app_id: str,
  241. agent_tool: AgentToolEntity,
  242. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  243. ) -> Tool:
  244. """
  245. get the agent tool runtime
  246. """
  247. tool_entity = cls.get_tool_runtime(
  248. provider_type=agent_tool.provider_type,
  249. provider_id=agent_tool.provider_id,
  250. tool_name=agent_tool.tool_name,
  251. tenant_id=tenant_id,
  252. invoke_from=invoke_from,
  253. tool_invoke_from=ToolInvokeFrom.AGENT,
  254. )
  255. runtime_parameters = {}
  256. parameters = tool_entity.get_merged_runtime_parameters()
  257. for parameter in parameters:
  258. # check file types
  259. if parameter.type == ToolParameter.ToolParameterType.FILE:
  260. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  261. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  262. # save tool parameter to tool entity memory
  263. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  264. runtime_parameters[parameter.name] = value
  265. # decrypt runtime parameters
  266. encryption_manager = ToolParameterConfigurationManager(
  267. tenant_id=tenant_id,
  268. tool_runtime=tool_entity,
  269. provider_name=agent_tool.provider_id,
  270. provider_type=agent_tool.provider_type,
  271. identity_id=f"AGENT.{app_id}",
  272. )
  273. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  274. if not tool_entity.runtime:
  275. raise Exception("tool missing runtime")
  276. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  277. return tool_entity
  278. @classmethod
  279. def get_workflow_tool_runtime(
  280. cls,
  281. tenant_id: str,
  282. app_id: str,
  283. node_id: str,
  284. workflow_tool: "ToolEntity",
  285. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  286. ) -> Tool:
  287. """
  288. get the workflow tool runtime
  289. """
  290. tool_runtime = cls.get_tool_runtime(
  291. provider_type=workflow_tool.provider_type,
  292. provider_id=workflow_tool.provider_id,
  293. tool_name=workflow_tool.tool_name,
  294. tenant_id=tenant_id,
  295. invoke_from=invoke_from,
  296. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  297. )
  298. runtime_parameters = {}
  299. parameters = tool_runtime.get_merged_runtime_parameters()
  300. for parameter in parameters:
  301. # save tool parameter to tool entity memory
  302. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  303. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  304. runtime_parameters[parameter.name] = value
  305. # decrypt runtime parameters
  306. encryption_manager = ToolParameterConfigurationManager(
  307. tenant_id=tenant_id,
  308. tool_runtime=tool_runtime,
  309. provider_name=workflow_tool.provider_id,
  310. provider_type=workflow_tool.provider_type,
  311. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  312. )
  313. if runtime_parameters:
  314. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  315. if not tool_runtime.runtime:
  316. raise Exception("tool missing runtime")
  317. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  318. return tool_runtime
  319. @classmethod
  320. def get_tool_runtime_from_plugin(
  321. cls,
  322. tool_type: ToolProviderType,
  323. tenant_id: str,
  324. provider: str,
  325. tool_name: str,
  326. tool_parameters: dict[str, Any],
  327. ) -> Tool:
  328. """
  329. get tool runtime from plugin
  330. """
  331. tool_entity = cls.get_tool_runtime(
  332. provider_type=tool_type,
  333. provider_id=provider,
  334. tool_name=tool_name,
  335. tenant_id=tenant_id,
  336. invoke_from=InvokeFrom.SERVICE_API,
  337. tool_invoke_from=ToolInvokeFrom.PLUGIN,
  338. )
  339. runtime_parameters = {}
  340. parameters = tool_entity.get_merged_runtime_parameters()
  341. for parameter in parameters:
  342. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  343. # save tool parameter to tool entity memory
  344. value = cls._init_runtime_parameter(parameter, tool_parameters)
  345. runtime_parameters[parameter.name] = value
  346. if not tool_entity.runtime:
  347. raise Exception("tool missing runtime")
  348. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  349. return tool_entity
  350. @classmethod
  351. def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
  352. """
  353. get the absolute path of the icon of the builtin provider
  354. :param provider: the name of the provider
  355. :param tenant_id: the id of the tenant
  356. :return: the absolute path of the icon, the mime type of the icon
  357. """
  358. # get provider
  359. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  360. absolute_path = path.join(
  361. path.dirname(path.realpath(__file__)),
  362. "builtin_tool",
  363. "providers",
  364. provider,
  365. "_assets",
  366. provider_controller.entity.identity.icon,
  367. )
  368. # check if the icon exists
  369. if not path.exists(absolute_path):
  370. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  371. # get the mime type
  372. mime_type, _ = mimetypes.guess_type(absolute_path)
  373. mime_type = mime_type or "application/octet-stream"
  374. return absolute_path, mime_type
  375. @classmethod
  376. def list_hardcoded_providers(cls):
  377. # use cache first
  378. if cls._builtin_providers_loaded:
  379. yield from list(cls._hardcoded_providers.values())
  380. return
  381. with cls._builtin_provider_lock:
  382. if cls._builtin_providers_loaded:
  383. yield from list(cls._hardcoded_providers.values())
  384. return
  385. yield from cls._list_hardcoded_providers()
  386. @classmethod
  387. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  388. """
  389. list all the plugin providers
  390. """
  391. manager = PluginToolManager()
  392. provider_entities = manager.fetch_tool_providers(tenant_id)
  393. return [
  394. PluginToolProviderController(
  395. entity=provider.declaration,
  396. tenant_id=tenant_id,
  397. )
  398. for provider in provider_entities
  399. ]
  400. @classmethod
  401. def list_builtin_providers(
  402. cls, tenant_id: str
  403. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  404. """
  405. list all the builtin providers
  406. """
  407. yield from cls.list_hardcoded_providers()
  408. # get plugin providers
  409. yield from cls.list_plugin_providers(tenant_id)
  410. @classmethod
  411. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  412. """
  413. list all the builtin providers
  414. """
  415. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  416. if provider_path.startswith("__"):
  417. continue
  418. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  419. if provider_path.startswith("__"):
  420. continue
  421. # init provider
  422. try:
  423. provider_class = load_single_subclass_from_source(
  424. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  425. script_path=path.join(
  426. path.dirname(path.realpath(__file__)),
  427. "builtin_tool",
  428. "providers",
  429. provider_path,
  430. f"{provider_path}.py",
  431. ),
  432. parent_type=BuiltinToolProviderController,
  433. )
  434. provider: BuiltinToolProviderController = provider_class()
  435. cls._hardcoded_providers[provider.entity.identity.name] = provider
  436. for tool in provider.get_tools():
  437. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  438. yield provider
  439. except Exception as e:
  440. logger.error(f"load builtin provider error: {e}")
  441. continue
  442. # set builtin providers loaded
  443. cls._builtin_providers_loaded = True
  444. @classmethod
  445. def load_hardcoded_providers_cache(cls):
  446. for _ in cls.list_hardcoded_providers():
  447. pass
  448. @classmethod
  449. def clear_hardcoded_providers_cache(cls):
  450. cls._hardcoded_providers = {}
  451. cls._builtin_providers_loaded = False
  452. @classmethod
  453. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  454. """
  455. get the tool label
  456. :param tool_name: the name of the tool
  457. :return: the label of the tool
  458. """
  459. if len(cls._builtin_tools_labels) == 0:
  460. # init the builtin providers
  461. cls.load_hardcoded_providers_cache()
  462. if tool_name not in cls._builtin_tools_labels:
  463. return None
  464. return cls._builtin_tools_labels[tool_name]
  465. @classmethod
  466. def list_providers_from_api(
  467. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  468. ) -> list[ToolProviderApiEntity]:
  469. result_providers: dict[str, ToolProviderApiEntity] = {}
  470. filters = []
  471. if not typ:
  472. filters.extend(["builtin", "api", "workflow"])
  473. else:
  474. filters.append(typ)
  475. if "builtin" in filters:
  476. # get builtin providers
  477. builtin_providers = cls.list_builtin_providers(tenant_id)
  478. # get db builtin providers
  479. db_builtin_providers: list[BuiltinToolProvider] = (
  480. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  481. )
  482. # rewrite db_builtin_providers
  483. for db_provider in db_builtin_providers:
  484. try:
  485. ToolProviderID(db_provider.provider)
  486. except Exception:
  487. db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
  488. find_db_builtin_provider = lambda provider: next(
  489. (x for x in db_builtin_providers if x.provider == provider), None
  490. )
  491. # append builtin providers
  492. for provider in builtin_providers:
  493. # handle include, exclude
  494. if is_filtered(
  495. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  496. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  497. data=provider,
  498. name_func=lambda x: x.identity.name,
  499. ):
  500. continue
  501. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  502. provider_controller=provider,
  503. db_provider=find_db_builtin_provider(provider.entity.identity.name),
  504. decrypt_credentials=False,
  505. )
  506. if isinstance(provider, PluginToolProviderController):
  507. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  508. else:
  509. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  510. # get db api providers
  511. if "api" in filters:
  512. db_api_providers: list[ApiToolProvider] = (
  513. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  514. )
  515. api_provider_controllers = [
  516. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  517. for provider in db_api_providers
  518. ]
  519. # get labels
  520. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  521. for api_provider_controller in api_provider_controllers:
  522. user_provider = ToolTransformService.api_provider_to_user_provider(
  523. provider_controller=api_provider_controller["controller"],
  524. db_provider=api_provider_controller["provider"],
  525. decrypt_credentials=False,
  526. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  527. )
  528. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  529. if "workflow" in filters:
  530. # get workflow providers
  531. workflow_providers: list[WorkflowToolProvider] = (
  532. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  533. )
  534. workflow_provider_controllers = []
  535. for provider in workflow_providers:
  536. try:
  537. workflow_provider_controllers.append(
  538. ToolTransformService.workflow_provider_to_controller(db_provider=provider)
  539. )
  540. except Exception as e:
  541. # app has been deleted
  542. pass
  543. labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
  544. for provider_controller in workflow_provider_controllers:
  545. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  546. provider_controller=provider_controller,
  547. labels=labels.get(provider_controller.provider_id, []),
  548. )
  549. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  550. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  551. @classmethod
  552. def get_api_provider_controller(
  553. cls, tenant_id: str, provider_id: str
  554. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  555. """
  556. get the api provider
  557. :param provider_name: the name of the provider
  558. :return: the provider controller, the credentials
  559. """
  560. provider: ApiToolProvider | None = (
  561. db.session.query(ApiToolProvider)
  562. .filter(
  563. ApiToolProvider.id == provider_id,
  564. ApiToolProvider.tenant_id == tenant_id,
  565. )
  566. .first()
  567. )
  568. if provider is None:
  569. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  570. controller = ApiToolProviderController.from_db(
  571. provider,
  572. ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  573. )
  574. controller.load_bundled_tools(provider.tools)
  575. return controller, provider.credentials
  576. @classmethod
  577. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  578. """
  579. get api provider
  580. """
  581. provider_obj: ApiToolProvider | None = (
  582. db.session.query(ApiToolProvider)
  583. .filter(
  584. ApiToolProvider.tenant_id == tenant_id,
  585. ApiToolProvider.name == provider,
  586. )
  587. .first()
  588. )
  589. if provider_obj is None:
  590. raise ValueError(f"you have not added provider {provider}")
  591. try:
  592. credentials = json.loads(provider_obj.credentials_str) or {}
  593. except:
  594. credentials = {}
  595. # package tool provider controller
  596. controller = ApiToolProviderController.from_db(
  597. provider_obj,
  598. ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  599. )
  600. # init tool configuration
  601. tool_configuration = ProviderConfigEncrypter(
  602. tenant_id=tenant_id,
  603. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  604. provider_type=controller.provider_type.value,
  605. provider_identity=controller.entity.identity.name,
  606. )
  607. decrypted_credentials = tool_configuration.decrypt(credentials)
  608. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  609. try:
  610. icon = json.loads(provider_obj.icon)
  611. except:
  612. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  613. # add tool labels
  614. labels = ToolLabelManager.get_tool_labels(controller)
  615. return jsonable_encoder(
  616. {
  617. "schema_type": provider_obj.schema_type,
  618. "schema": provider_obj.schema,
  619. "tools": provider_obj.tools,
  620. "icon": icon,
  621. "description": provider_obj.description,
  622. "credentials": masked_credentials,
  623. "privacy_policy": provider_obj.privacy_policy,
  624. "custom_disclaimer": provider_obj.custom_disclaimer,
  625. "labels": labels,
  626. }
  627. )
  628. @classmethod
  629. def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
  630. """
  631. get the tool icon
  632. :param tenant_id: the id of the tenant
  633. :param provider_type: the type of the provider
  634. :param provider_id: the id of the provider
  635. :return:
  636. """
  637. provider_type = provider_type
  638. provider_id = provider_id
  639. if provider_type == ToolProviderType.BUILT_IN:
  640. return (
  641. dify_config.CONSOLE_API_URL
  642. + "/console/api/workspaces/current/tool-provider/builtin/"
  643. + provider_id
  644. + "/icon"
  645. )
  646. elif provider_type == ToolProviderType.API:
  647. try:
  648. api_provider: ApiToolProvider | None = (
  649. db.session.query(ApiToolProvider)
  650. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  651. .first()
  652. )
  653. if not api_provider:
  654. raise ValueError("api tool not found")
  655. return json.loads(api_provider.icon)
  656. except:
  657. return {"background": "#252525", "content": "\ud83d\ude01"}
  658. elif provider_type == ToolProviderType.WORKFLOW:
  659. workflow_provider: WorkflowToolProvider | None = (
  660. db.session.query(WorkflowToolProvider)
  661. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  662. .first()
  663. )
  664. if workflow_provider is None:
  665. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  666. return json.loads(workflow_provider.icon)
  667. else:
  668. raise ValueError(f"provider type {provider_type} not found")
# Warm the hardcoded built-in provider cache at import time: importing this module scans
# builtin_tool/providers once, so later lookups normally find the cache already loaded.
ToolManager.load_hardcoded_providers_cache()