tool_manager.py 30 KB


  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Union, cast
  8. from core.plugin.manager.tool import PluginToolManager
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.plugin_tool.provider import PluginToolProviderController
  11. from core.tools.plugin_tool.tool import PluginTool
  12. if TYPE_CHECKING:
  13. from core.workflow.nodes.tool.entities import ToolEntity
  14. from configs import dify_config
  15. from core.agent.entities import AgentToolEntity
  16. from core.app.entities.app_invoke_entities import InvokeFrom
  17. from core.helper.module_import_helper import load_single_subclass_from_source
  18. from core.helper.position_helper import is_filtered
  19. from core.model_runtime.utils.encoders import jsonable_encoder
  20. from core.tools.__base.tool import Tool
  21. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  22. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  23. from core.tools.builtin_tool.tool import BuiltinTool
  24. from core.tools.custom_tool.provider import ApiToolProviderController
  25. from core.tools.custom_tool.tool import ApiTool
  26. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  27. from core.tools.entities.common_entities import I18nObject
  28. from core.tools.entities.tool_entities import (
  29. ApiProviderAuthType,
  30. ToolInvokeFrom,
  31. ToolParameter,
  32. ToolProviderID,
  33. ToolProviderType,
  34. )
  35. from core.tools.errors import ToolProviderNotFoundError
  36. from core.tools.tool_label_manager import ToolLabelManager
  37. from core.tools.utils.configuration import (
  38. ProviderConfigEncrypter,
  39. ToolParameterConfigurationManager,
  40. )
  41. from core.tools.workflow_as_tool.tool import WorkflowTool
  42. from extensions.ext_database import db
  43. from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
  44. from services.tools.tools_transform_service import ToolTransformService
# module-level logger named after this module; used below for provider load failures
logger = logging.getLogger(__name__)
  46. class ToolManager:
  47. _builtin_provider_lock = Lock()
  48. _hardcoded_providers = {}
  49. _builtin_providers_loaded = False
  50. _builtin_tools_labels = {}
  51. @classmethod
  52. def get_builtin_provider(
  53. cls, provider: str, tenant_id: str
  54. ) -> BuiltinToolProviderController | PluginToolProviderController:
  55. """
  56. get the builtin provider
  57. :param provider: the name of the provider
  58. :param tenant_id: the id of the tenant
  59. :return: the provider
  60. """
  61. # split provider to
  62. if len(cls._hardcoded_providers) == 0:
  63. # init the builtin providers
  64. cls.load_hardcoded_providers_cache()
  65. if provider not in cls._hardcoded_providers:
  66. # get plugin provider
  67. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  68. if plugin_provider:
  69. return plugin_provider
  70. return cls._hardcoded_providers[provider]
  71. @classmethod
  72. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  73. """
  74. get the plugin provider
  75. """
  76. manager = PluginToolManager()
  77. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  78. if not provider_entity:
  79. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  80. return PluginToolProviderController(
  81. entity=provider_entity.declaration,
  82. plugin_id=provider_entity.plugin_id,
  83. tenant_id=tenant_id,
  84. )
  85. @classmethod
  86. def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
  87. """
  88. get the builtin tool
  89. :param provider: the name of the provider
  90. :param tool_name: the name of the tool
  91. :param tenant_id: the id of the tenant
  92. :return: the provider, the tool
  93. """
  94. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  95. tool = provider_controller.get_tool(tool_name)
  96. return tool
  97. @classmethod
  98. def get_tool_runtime(
  99. cls,
  100. provider_type: ToolProviderType,
  101. provider_id: str,
  102. tool_name: str,
  103. tenant_id: str,
  104. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  105. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  106. ) -> Union[BuiltinTool, ApiTool, WorkflowTool]:
  107. """
  108. get the tool runtime
  109. :param provider_type: the type of the provider
  110. :param provider_name: the name of the provider
  111. :param tool_name: the name of the tool
  112. :return: the tool
  113. """
  114. if provider_type == ToolProviderType.BUILT_IN:
  115. # check if the builtin tool need credentials
  116. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  117. builtin_tool = provider_controller.get_tool(tool_name)
  118. if not builtin_tool:
  119. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  120. if not provider_controller.need_credentials:
  121. return cast(
  122. BuiltinTool,
  123. builtin_tool.fork_tool_runtime(
  124. runtime=ToolRuntime(
  125. tenant_id=tenant_id,
  126. credentials={},
  127. invoke_from=invoke_from,
  128. tool_invoke_from=tool_invoke_from,
  129. )
  130. ),
  131. )
  132. if isinstance(provider_controller, PluginToolProviderController):
  133. provider_id_entity = ToolProviderID(provider_id)
  134. # get credentials
  135. builtin_provider: BuiltinToolProvider | None = (
  136. db.session.query(BuiltinToolProvider)
  137. .filter(
  138. BuiltinToolProvider.tenant_id == tenant_id,
  139. (BuiltinToolProvider.provider == provider_id)
  140. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  141. )
  142. .first()
  143. )
  144. if builtin_provider is None:
  145. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  146. else:
  147. builtin_provider: BuiltinToolProvider | None = (
  148. db.session.query(BuiltinToolProvider)
  149. .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  150. .first()
  151. )
  152. if builtin_provider is None:
  153. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  154. # decrypt the credentials
  155. credentials = builtin_provider.credentials
  156. tool_configuration = ProviderConfigEncrypter(
  157. tenant_id=tenant_id,
  158. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  159. provider_type=provider_controller.provider_type.value,
  160. provider_identity=provider_controller.entity.identity.name,
  161. )
  162. decrypted_credentials = tool_configuration.decrypt(credentials)
  163. return cast(
  164. BuiltinTool,
  165. builtin_tool.fork_tool_runtime(
  166. runtime=ToolRuntime(
  167. tenant_id=tenant_id,
  168. credentials=decrypted_credentials,
  169. runtime_parameters={},
  170. invoke_from=invoke_from,
  171. tool_invoke_from=tool_invoke_from,
  172. )
  173. ),
  174. )
  175. elif provider_type == ToolProviderType.API:
  176. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  177. # decrypt the credentials
  178. tool_configuration = ProviderConfigEncrypter(
  179. tenant_id=tenant_id,
  180. config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
  181. provider_type=api_provider.provider_type.value,
  182. provider_identity=api_provider.entity.identity.name,
  183. )
  184. decrypted_credentials = tool_configuration.decrypt(credentials)
  185. return cast(
  186. ApiTool,
  187. api_provider.get_tool(tool_name).fork_tool_runtime(
  188. runtime=ToolRuntime(
  189. tenant_id=tenant_id,
  190. credentials=decrypted_credentials,
  191. invoke_from=invoke_from,
  192. tool_invoke_from=tool_invoke_from,
  193. )
  194. ),
  195. )
  196. elif provider_type == ToolProviderType.WORKFLOW:
  197. workflow_provider = (
  198. db.session.query(WorkflowToolProvider)
  199. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  200. .first()
  201. )
  202. if workflow_provider is None:
  203. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  204. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  205. return cast(
  206. WorkflowTool,
  207. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  208. runtime=ToolRuntime(
  209. tenant_id=tenant_id,
  210. credentials={},
  211. invoke_from=invoke_from,
  212. tool_invoke_from=tool_invoke_from,
  213. )
  214. ),
  215. )
  216. elif provider_type == ToolProviderType.APP:
  217. raise NotImplementedError("app provider not implemented")
  218. else:
  219. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  220. @classmethod
  221. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
  222. """
  223. init runtime parameter
  224. """
  225. parameter_value = parameters.get(parameter_rule.name)
  226. if not parameter_value and parameter_value != 0:
  227. # get default value
  228. parameter_value = parameter_rule.default
  229. if not parameter_value and parameter_rule.required:
  230. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  231. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  232. # check if tool_parameter_config in options
  233. options = [x.value for x in parameter_rule.options]
  234. if parameter_value is not None and parameter_value not in options:
  235. raise ValueError(
  236. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
  237. )
  238. return parameter_rule.type.cast_value(parameter_value)
  239. @classmethod
  240. def get_agent_tool_runtime(
  241. cls,
  242. tenant_id: str,
  243. app_id: str,
  244. agent_tool: AgentToolEntity,
  245. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  246. ) -> Tool:
  247. """
  248. get the agent tool runtime
  249. """
  250. tool_entity = cls.get_tool_runtime(
  251. provider_type=agent_tool.provider_type,
  252. provider_id=agent_tool.provider_id,
  253. tool_name=agent_tool.tool_name,
  254. tenant_id=tenant_id,
  255. invoke_from=invoke_from,
  256. tool_invoke_from=ToolInvokeFrom.AGENT,
  257. )
  258. runtime_parameters = {}
  259. parameters = tool_entity.get_merged_runtime_parameters()
  260. for parameter in parameters:
  261. # check file types
  262. if parameter.type in {
  263. ToolParameter.ToolParameterType.SYSTEM_FILES,
  264. ToolParameter.ToolParameterType.FILE,
  265. ToolParameter.ToolParameterType.FILES,
  266. }:
  267. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  268. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  269. # save tool parameter to tool entity memory
  270. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  271. runtime_parameters[parameter.name] = value
  272. # decrypt runtime parameters
  273. encryption_manager = ToolParameterConfigurationManager(
  274. tenant_id=tenant_id,
  275. tool_runtime=tool_entity,
  276. provider_name=agent_tool.provider_id,
  277. provider_type=agent_tool.provider_type,
  278. identity_id=f"AGENT.{app_id}",
  279. )
  280. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  281. if not tool_entity.runtime:
  282. raise Exception("tool missing runtime")
  283. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  284. return tool_entity
  285. @classmethod
  286. def get_workflow_tool_runtime(
  287. cls,
  288. tenant_id: str,
  289. app_id: str,
  290. node_id: str,
  291. workflow_tool: "ToolEntity",
  292. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  293. ) -> Tool:
  294. """
  295. get the workflow tool runtime
  296. """
  297. tool_runtime = cls.get_tool_runtime(
  298. provider_type=workflow_tool.provider_type,
  299. provider_id=workflow_tool.provider_id,
  300. tool_name=workflow_tool.tool_name,
  301. tenant_id=tenant_id,
  302. invoke_from=invoke_from,
  303. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  304. )
  305. runtime_parameters = {}
  306. parameters = tool_runtime.get_merged_runtime_parameters()
  307. for parameter in parameters:
  308. # save tool parameter to tool entity memory
  309. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  310. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  311. runtime_parameters[parameter.name] = value
  312. # decrypt runtime parameters
  313. encryption_manager = ToolParameterConfigurationManager(
  314. tenant_id=tenant_id,
  315. tool_runtime=tool_runtime,
  316. provider_name=workflow_tool.provider_id,
  317. provider_type=workflow_tool.provider_type,
  318. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  319. )
  320. if runtime_parameters:
  321. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  322. if not tool_runtime.runtime:
  323. raise Exception("tool missing runtime")
  324. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  325. return tool_runtime
  326. @classmethod
  327. def get_tool_runtime_from_plugin(
  328. cls,
  329. tool_type: ToolProviderType,
  330. tenant_id: str,
  331. provider: str,
  332. tool_name: str,
  333. tool_parameters: dict[str, Any],
  334. ) -> Tool:
  335. """
  336. get tool runtime from plugin
  337. """
  338. tool_entity = cls.get_tool_runtime(
  339. provider_type=tool_type,
  340. provider_id=provider,
  341. tool_name=tool_name,
  342. tenant_id=tenant_id,
  343. invoke_from=InvokeFrom.SERVICE_API,
  344. tool_invoke_from=ToolInvokeFrom.PLUGIN,
  345. )
  346. runtime_parameters = {}
  347. parameters = tool_entity.get_merged_runtime_parameters()
  348. for parameter in parameters:
  349. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  350. # save tool parameter to tool entity memory
  351. value = cls._init_runtime_parameter(parameter, tool_parameters)
  352. runtime_parameters[parameter.name] = value
  353. if not tool_entity.runtime:
  354. raise Exception("tool missing runtime")
  355. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  356. return tool_entity
  357. @classmethod
  358. def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
  359. """
  360. get the absolute path of the icon of the builtin provider
  361. :param provider: the name of the provider
  362. :param tenant_id: the id of the tenant
  363. :return: the absolute path of the icon, the mime type of the icon
  364. """
  365. # get provider
  366. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  367. absolute_path = path.join(
  368. path.dirname(path.realpath(__file__)),
  369. "builtin_tool",
  370. "providers",
  371. provider,
  372. "_assets",
  373. provider_controller.entity.identity.icon,
  374. )
  375. # check if the icon exists
  376. if not path.exists(absolute_path):
  377. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  378. # get the mime type
  379. mime_type, _ = mimetypes.guess_type(absolute_path)
  380. mime_type = mime_type or "application/octet-stream"
  381. return absolute_path, mime_type
  382. @classmethod
  383. def list_hardcoded_providers(cls):
  384. # use cache first
  385. if cls._builtin_providers_loaded:
  386. yield from list(cls._hardcoded_providers.values())
  387. return
  388. with cls._builtin_provider_lock:
  389. if cls._builtin_providers_loaded:
  390. yield from list(cls._hardcoded_providers.values())
  391. return
  392. yield from cls._list_hardcoded_providers()
  393. @classmethod
  394. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  395. """
  396. list all the plugin providers
  397. """
  398. manager = PluginToolManager()
  399. provider_entities = manager.fetch_tool_providers(tenant_id)
  400. return [
  401. PluginToolProviderController(
  402. entity=provider.declaration,
  403. plugin_id=provider.plugin_id,
  404. tenant_id=tenant_id,
  405. )
  406. for provider in provider_entities
  407. ]
  408. @classmethod
  409. def list_builtin_providers(
  410. cls, tenant_id: str
  411. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  412. """
  413. list all the builtin providers
  414. """
  415. yield from cls.list_hardcoded_providers()
  416. # get plugin providers
  417. yield from cls.list_plugin_providers(tenant_id)
  418. @classmethod
  419. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  420. """
  421. list all the builtin providers
  422. """
  423. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  424. if provider_path.startswith("__"):
  425. continue
  426. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  427. if provider_path.startswith("__"):
  428. continue
  429. # init provider
  430. try:
  431. provider_class = load_single_subclass_from_source(
  432. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  433. script_path=path.join(
  434. path.dirname(path.realpath(__file__)),
  435. "builtin_tool",
  436. "providers",
  437. provider_path,
  438. f"{provider_path}.py",
  439. ),
  440. parent_type=BuiltinToolProviderController,
  441. )
  442. provider: BuiltinToolProviderController = provider_class()
  443. cls._hardcoded_providers[provider.entity.identity.name] = provider
  444. for tool in provider.get_tools():
  445. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  446. yield provider
  447. except Exception as e:
  448. logger.error(f"load builtin provider error: {e}")
  449. continue
  450. # set builtin providers loaded
  451. cls._builtin_providers_loaded = True
  452. @classmethod
  453. def load_hardcoded_providers_cache(cls):
  454. for _ in cls.list_hardcoded_providers():
  455. pass
  456. @classmethod
  457. def clear_hardcoded_providers_cache(cls):
  458. cls._hardcoded_providers = {}
  459. cls._builtin_providers_loaded = False
  460. @classmethod
  461. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  462. """
  463. get the tool label
  464. :param tool_name: the name of the tool
  465. :return: the label of the tool
  466. """
  467. if len(cls._builtin_tools_labels) == 0:
  468. # init the builtin providers
  469. cls.load_hardcoded_providers_cache()
  470. if tool_name not in cls._builtin_tools_labels:
  471. return None
  472. return cls._builtin_tools_labels[tool_name]
  473. @classmethod
  474. def list_providers_from_api(
  475. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  476. ) -> list[ToolProviderApiEntity]:
  477. result_providers: dict[str, ToolProviderApiEntity] = {}
  478. filters = []
  479. if not typ:
  480. filters.extend(["builtin", "api", "workflow"])
  481. else:
  482. filters.append(typ)
  483. if "builtin" in filters:
  484. # get builtin providers
  485. builtin_providers = cls.list_builtin_providers(tenant_id)
  486. # get db builtin providers
  487. db_builtin_providers: list[BuiltinToolProvider] = (
  488. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  489. )
  490. # rewrite db_builtin_providers
  491. for db_provider in db_builtin_providers:
  492. try:
  493. ToolProviderID(db_provider.provider)
  494. except Exception:
  495. db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
  496. find_db_builtin_provider = lambda provider: next(
  497. (x for x in db_builtin_providers if x.provider == provider), None
  498. )
  499. # append builtin providers
  500. for provider in builtin_providers:
  501. # handle include, exclude
  502. if is_filtered(
  503. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  504. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  505. data=provider,
  506. name_func=lambda x: x.identity.name,
  507. ):
  508. continue
  509. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  510. provider_controller=provider,
  511. db_provider=find_db_builtin_provider(provider.entity.identity.name),
  512. decrypt_credentials=False,
  513. )
  514. if isinstance(provider, PluginToolProviderController):
  515. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  516. else:
  517. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  518. # get db api providers
  519. if "api" in filters:
  520. db_api_providers: list[ApiToolProvider] = (
  521. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  522. )
  523. api_provider_controllers = [
  524. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  525. for provider in db_api_providers
  526. ]
  527. # get labels
  528. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  529. for api_provider_controller in api_provider_controllers:
  530. user_provider = ToolTransformService.api_provider_to_user_provider(
  531. provider_controller=api_provider_controller["controller"],
  532. db_provider=api_provider_controller["provider"],
  533. decrypt_credentials=False,
  534. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  535. )
  536. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  537. if "workflow" in filters:
  538. # get workflow providers
  539. workflow_providers: list[WorkflowToolProvider] = (
  540. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  541. )
  542. workflow_provider_controllers = []
  543. for provider in workflow_providers:
  544. try:
  545. workflow_provider_controllers.append(
  546. ToolTransformService.workflow_provider_to_controller(db_provider=provider)
  547. )
  548. except Exception as e:
  549. # app has been deleted
  550. pass
  551. labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
  552. for provider_controller in workflow_provider_controllers:
  553. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  554. provider_controller=provider_controller,
  555. labels=labels.get(provider_controller.provider_id, []),
  556. )
  557. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  558. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  559. @classmethod
  560. def get_api_provider_controller(
  561. cls, tenant_id: str, provider_id: str
  562. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  563. """
  564. get the api provider
  565. :param provider_name: the name of the provider
  566. :return: the provider controller, the credentials
  567. """
  568. provider: ApiToolProvider | None = (
  569. db.session.query(ApiToolProvider)
  570. .filter(
  571. ApiToolProvider.id == provider_id,
  572. ApiToolProvider.tenant_id == tenant_id,
  573. )
  574. .first()
  575. )
  576. if provider is None:
  577. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  578. controller = ApiToolProviderController.from_db(
  579. provider,
  580. ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  581. )
  582. controller.load_bundled_tools(provider.tools)
  583. return controller, provider.credentials
  584. @classmethod
  585. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  586. """
  587. get api provider
  588. """
  589. provider_obj: ApiToolProvider | None = (
  590. db.session.query(ApiToolProvider)
  591. .filter(
  592. ApiToolProvider.tenant_id == tenant_id,
  593. ApiToolProvider.name == provider,
  594. )
  595. .first()
  596. )
  597. if provider_obj is None:
  598. raise ValueError(f"you have not added provider {provider}")
  599. try:
  600. credentials = json.loads(provider_obj.credentials_str) or {}
  601. except:
  602. credentials = {}
  603. # package tool provider controller
  604. controller = ApiToolProviderController.from_db(
  605. provider_obj,
  606. ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  607. )
  608. # init tool configuration
  609. tool_configuration = ProviderConfigEncrypter(
  610. tenant_id=tenant_id,
  611. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  612. provider_type=controller.provider_type.value,
  613. provider_identity=controller.entity.identity.name,
  614. )
  615. decrypted_credentials = tool_configuration.decrypt(credentials)
  616. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  617. try:
  618. icon = json.loads(provider_obj.icon)
  619. except:
  620. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  621. # add tool labels
  622. labels = ToolLabelManager.get_tool_labels(controller)
  623. return jsonable_encoder(
  624. {
  625. "schema_type": provider_obj.schema_type,
  626. "schema": provider_obj.schema,
  627. "tools": provider_obj.tools,
  628. "icon": icon,
  629. "description": provider_obj.description,
  630. "credentials": masked_credentials,
  631. "privacy_policy": provider_obj.privacy_policy,
  632. "custom_disclaimer": provider_obj.custom_disclaimer,
  633. "labels": labels,
  634. }
  635. )
  636. @classmethod
  637. def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
  638. """
  639. get the tool icon
  640. :param tenant_id: the id of the tenant
  641. :param provider_type: the type of the provider
  642. :param provider_id: the id of the provider
  643. :return:
  644. """
  645. provider_type = provider_type
  646. provider_id = provider_id
  647. if provider_type == ToolProviderType.BUILT_IN:
  648. return (
  649. dify_config.CONSOLE_API_URL
  650. + "/console/api/workspaces/current/tool-provider/builtin/"
  651. + provider_id
  652. + "/icon"
  653. )
  654. elif provider_type == ToolProviderType.API:
  655. try:
  656. api_provider: ApiToolProvider | None = (
  657. db.session.query(ApiToolProvider)
  658. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  659. .first()
  660. )
  661. if not api_provider:
  662. raise ValueError("api tool not found")
  663. return json.loads(api_provider.icon)
  664. except:
  665. return {"background": "#252525", "content": "\ud83d\ude01"}
  666. elif provider_type == ToolProviderType.WORKFLOW:
  667. workflow_provider: WorkflowToolProvider | None = (
  668. db.session.query(WorkflowToolProvider)
  669. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  670. .first()
  671. )
  672. if workflow_provider is None:
  673. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  674. return json.loads(workflow_provider.icon)
  675. else:
  676. raise ValueError(f"provider type {provider_type} not found")
# Populate the hardcoded builtin provider cache once, when this module is imported.
# NOTE(review): this scans and imports every builtin provider package as an import-time side
# effect; presumably done so the first request does not pay that cost — confirm.
ToolManager.load_hardcoded_providers_cache()