# tool_manager.py (33 KB)


  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Union, cast
  8. from yarl import URL
  9. import contexts
  10. from core.plugin.entities.plugin import GenericProviderID
  11. from core.plugin.manager.tool import PluginToolManager
  12. from core.tools.__base.tool_runtime import ToolRuntime
  13. from core.tools.plugin_tool.provider import PluginToolProviderController
  14. from core.tools.plugin_tool.tool import PluginTool
  15. if TYPE_CHECKING:
  16. from core.workflow.nodes.tool.entities import ToolEntity
  17. from configs import dify_config
  18. from core.agent.entities import AgentToolEntity
  19. from core.app.entities.app_invoke_entities import InvokeFrom
  20. from core.helper.module_import_helper import load_single_subclass_from_source
  21. from core.helper.position_helper import is_filtered
  22. from core.model_runtime.utils.encoders import jsonable_encoder
  23. from core.tools.__base.tool import Tool
  24. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  25. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  26. from core.tools.builtin_tool.tool import BuiltinTool
  27. from core.tools.custom_tool.provider import ApiToolProviderController
  28. from core.tools.custom_tool.tool import ApiTool
  29. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  30. from core.tools.entities.common_entities import I18nObject
  31. from core.tools.entities.tool_entities import (
  32. ApiProviderAuthType,
  33. ToolInvokeFrom,
  34. ToolParameter,
  35. ToolProviderType,
  36. )
  37. from core.tools.errors import ToolProviderNotFoundError
  38. from core.tools.tool_label_manager import ToolLabelManager
  39. from core.tools.utils.configuration import (
  40. ProviderConfigEncrypter,
  41. ToolParameterConfigurationManager,
  42. )
  43. from core.tools.workflow_as_tool.tool import WorkflowTool
  44. from extensions.ext_database import db
  45. from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
  46. from services.tools.tools_transform_service import ToolTransformService
  47. logger = logging.getLogger(__name__)
  48. class ToolManager:
  49. _builtin_provider_lock = Lock()
  50. _hardcoded_providers = {}
  51. _builtin_providers_loaded = False
  52. _builtin_tools_labels = {}
  53. @classmethod
  54. def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
  55. """
  56. get the hardcoded provider
  57. """
  58. if len(cls._hardcoded_providers) == 0:
  59. # init the builtin providers
  60. cls.load_hardcoded_providers_cache()
  61. return cls._hardcoded_providers[provider]
  62. @classmethod
  63. def get_builtin_provider(
  64. cls, provider: str, tenant_id: str
  65. ) -> BuiltinToolProviderController | PluginToolProviderController:
  66. """
  67. get the builtin provider
  68. :param provider: the name of the provider
  69. :param tenant_id: the id of the tenant
  70. :return: the provider
  71. """
  72. # split provider to
  73. if len(cls._hardcoded_providers) == 0:
  74. # init the builtin providers
  75. cls.load_hardcoded_providers_cache()
  76. if provider not in cls._hardcoded_providers:
  77. # get plugin provider
  78. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  79. if plugin_provider:
  80. return plugin_provider
  81. return cls._hardcoded_providers[provider]
  82. @classmethod
  83. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  84. """
  85. get the plugin provider
  86. """
  87. # check if context is set
  88. try:
  89. contexts.plugin_tool_providers.get()
  90. except LookupError:
  91. contexts.plugin_tool_providers.set({})
  92. contexts.plugin_tool_providers_lock.set(Lock())
  93. with contexts.plugin_tool_providers_lock.get():
  94. plugin_tool_providers = contexts.plugin_tool_providers.get()
  95. if provider in plugin_tool_providers:
  96. return plugin_tool_providers[provider]
  97. manager = PluginToolManager()
  98. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  99. if not provider_entity:
  100. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  101. controller = PluginToolProviderController(
  102. entity=provider_entity.declaration,
  103. plugin_id=provider_entity.plugin_id,
  104. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  105. tenant_id=tenant_id,
  106. )
  107. plugin_tool_providers[provider] = controller
  108. return controller
  109. @classmethod
  110. def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
  111. """
  112. get the builtin tool
  113. :param provider: the name of the provider
  114. :param tool_name: the name of the tool
  115. :param tenant_id: the id of the tenant
  116. :return: the provider, the tool
  117. """
  118. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  119. tool = provider_controller.get_tool(tool_name)
  120. return tool
  121. @classmethod
  122. def get_tool_runtime(
  123. cls,
  124. provider_type: ToolProviderType,
  125. provider_id: str,
  126. tool_name: str,
  127. tenant_id: str,
  128. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  129. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  130. ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
  131. """
  132. get the tool runtime
  133. :param provider_type: the type of the provider
  134. :param provider_name: the name of the provider
  135. :param tool_name: the name of the tool
  136. :return: the tool
  137. """
  138. if provider_type == ToolProviderType.BUILT_IN:
  139. # check if the builtin tool need credentials
  140. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  141. builtin_tool = provider_controller.get_tool(tool_name)
  142. if not builtin_tool:
  143. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  144. if not provider_controller.need_credentials:
  145. return cast(
  146. BuiltinTool,
  147. builtin_tool.fork_tool_runtime(
  148. runtime=ToolRuntime(
  149. tenant_id=tenant_id,
  150. credentials={},
  151. invoke_from=invoke_from,
  152. tool_invoke_from=tool_invoke_from,
  153. )
  154. ),
  155. )
  156. if isinstance(provider_controller, PluginToolProviderController):
  157. provider_id_entity = GenericProviderID(provider_id)
  158. # get credentials
  159. builtin_provider: BuiltinToolProvider | None = (
  160. db.session.query(BuiltinToolProvider)
  161. .filter(
  162. BuiltinToolProvider.tenant_id == tenant_id,
  163. (BuiltinToolProvider.provider == provider_id)
  164. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  165. )
  166. .first()
  167. )
  168. if builtin_provider is None:
  169. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  170. else:
  171. builtin_provider: BuiltinToolProvider | None = (
  172. db.session.query(BuiltinToolProvider)
  173. .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  174. .first()
  175. )
  176. if builtin_provider is None:
  177. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  178. # decrypt the credentials
  179. credentials = builtin_provider.credentials
  180. tool_configuration = ProviderConfigEncrypter(
  181. tenant_id=tenant_id,
  182. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  183. provider_type=provider_controller.provider_type.value,
  184. provider_identity=provider_controller.entity.identity.name,
  185. )
  186. decrypted_credentials = tool_configuration.decrypt(credentials)
  187. return cast(
  188. BuiltinTool,
  189. builtin_tool.fork_tool_runtime(
  190. runtime=ToolRuntime(
  191. tenant_id=tenant_id,
  192. credentials=decrypted_credentials,
  193. runtime_parameters={},
  194. invoke_from=invoke_from,
  195. tool_invoke_from=tool_invoke_from,
  196. )
  197. ),
  198. )
  199. elif provider_type == ToolProviderType.API:
  200. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  201. # decrypt the credentials
  202. tool_configuration = ProviderConfigEncrypter(
  203. tenant_id=tenant_id,
  204. config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
  205. provider_type=api_provider.provider_type.value,
  206. provider_identity=api_provider.entity.identity.name,
  207. )
  208. decrypted_credentials = tool_configuration.decrypt(credentials)
  209. return cast(
  210. ApiTool,
  211. api_provider.get_tool(tool_name).fork_tool_runtime(
  212. runtime=ToolRuntime(
  213. tenant_id=tenant_id,
  214. credentials=decrypted_credentials,
  215. invoke_from=invoke_from,
  216. tool_invoke_from=tool_invoke_from,
  217. )
  218. ),
  219. )
  220. elif provider_type == ToolProviderType.WORKFLOW:
  221. workflow_provider = (
  222. db.session.query(WorkflowToolProvider)
  223. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  224. .first()
  225. )
  226. if workflow_provider is None:
  227. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  228. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  229. return cast(
  230. WorkflowTool,
  231. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  232. runtime=ToolRuntime(
  233. tenant_id=tenant_id,
  234. credentials={},
  235. invoke_from=invoke_from,
  236. tool_invoke_from=tool_invoke_from,
  237. )
  238. ),
  239. )
  240. elif provider_type == ToolProviderType.APP:
  241. raise NotImplementedError("app provider not implemented")
  242. elif provider_type == ToolProviderType.PLUGIN:
  243. return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
  244. else:
  245. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  246. @classmethod
  247. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
  248. """
  249. init runtime parameter
  250. """
  251. parameter_value = parameters.get(parameter_rule.name)
  252. if not parameter_value and parameter_value != 0:
  253. # get default value
  254. parameter_value = parameter_rule.default
  255. if not parameter_value and parameter_rule.required:
  256. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  257. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  258. # check if tool_parameter_config in options
  259. options = [x.value for x in parameter_rule.options]
  260. if parameter_value is not None and parameter_value not in options:
  261. raise ValueError(
  262. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
  263. )
  264. return parameter_rule.type.cast_value(parameter_value)
  265. @classmethod
  266. def get_agent_tool_runtime(
  267. cls,
  268. tenant_id: str,
  269. app_id: str,
  270. agent_tool: AgentToolEntity,
  271. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  272. ) -> Tool:
  273. """
  274. get the agent tool runtime
  275. """
  276. tool_entity = cls.get_tool_runtime(
  277. provider_type=agent_tool.provider_type,
  278. provider_id=agent_tool.provider_id,
  279. tool_name=agent_tool.tool_name,
  280. tenant_id=tenant_id,
  281. invoke_from=invoke_from,
  282. tool_invoke_from=ToolInvokeFrom.AGENT,
  283. )
  284. runtime_parameters = {}
  285. parameters = tool_entity.get_merged_runtime_parameters()
  286. for parameter in parameters:
  287. # check file types
  288. if (
  289. parameter.type
  290. in {
  291. ToolParameter.ToolParameterType.SYSTEM_FILES,
  292. ToolParameter.ToolParameterType.FILE,
  293. ToolParameter.ToolParameterType.FILES,
  294. }
  295. and parameter.required
  296. ):
  297. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  298. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  299. # save tool parameter to tool entity memory
  300. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  301. runtime_parameters[parameter.name] = value
  302. # decrypt runtime parameters
  303. encryption_manager = ToolParameterConfigurationManager(
  304. tenant_id=tenant_id,
  305. tool_runtime=tool_entity,
  306. provider_name=agent_tool.provider_id,
  307. provider_type=agent_tool.provider_type,
  308. identity_id=f"AGENT.{app_id}",
  309. )
  310. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  311. if not tool_entity.runtime:
  312. raise Exception("tool missing runtime")
  313. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  314. return tool_entity
  315. @classmethod
  316. def get_workflow_tool_runtime(
  317. cls,
  318. tenant_id: str,
  319. app_id: str,
  320. node_id: str,
  321. workflow_tool: "ToolEntity",
  322. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  323. ) -> Tool:
  324. """
  325. get the workflow tool runtime
  326. """
  327. tool_runtime = cls.get_tool_runtime(
  328. provider_type=workflow_tool.provider_type,
  329. provider_id=workflow_tool.provider_id,
  330. tool_name=workflow_tool.tool_name,
  331. tenant_id=tenant_id,
  332. invoke_from=invoke_from,
  333. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  334. )
  335. runtime_parameters = {}
  336. parameters = tool_runtime.get_merged_runtime_parameters()
  337. for parameter in parameters:
  338. # save tool parameter to tool entity memory
  339. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  340. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  341. runtime_parameters[parameter.name] = value
  342. # decrypt runtime parameters
  343. encryption_manager = ToolParameterConfigurationManager(
  344. tenant_id=tenant_id,
  345. tool_runtime=tool_runtime,
  346. provider_name=workflow_tool.provider_id,
  347. provider_type=workflow_tool.provider_type,
  348. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  349. )
  350. if runtime_parameters:
  351. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  352. if not tool_runtime.runtime:
  353. raise Exception("tool missing runtime")
  354. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  355. return tool_runtime
  356. @classmethod
  357. def get_tool_runtime_from_plugin(
  358. cls,
  359. tool_type: ToolProviderType,
  360. tenant_id: str,
  361. provider: str,
  362. tool_name: str,
  363. tool_parameters: dict[str, Any],
  364. ) -> Tool:
  365. """
  366. get tool runtime from plugin
  367. """
  368. tool_entity = cls.get_tool_runtime(
  369. provider_type=tool_type,
  370. provider_id=provider,
  371. tool_name=tool_name,
  372. tenant_id=tenant_id,
  373. invoke_from=InvokeFrom.SERVICE_API,
  374. tool_invoke_from=ToolInvokeFrom.PLUGIN,
  375. )
  376. runtime_parameters = {}
  377. parameters = tool_entity.get_merged_runtime_parameters()
  378. for parameter in parameters:
  379. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  380. # save tool parameter to tool entity memory
  381. value = cls._init_runtime_parameter(parameter, tool_parameters)
  382. runtime_parameters[parameter.name] = value
  383. if not tool_entity.runtime:
  384. raise Exception("tool missing runtime")
  385. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  386. return tool_entity
  387. @classmethod
  388. def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
  389. """
  390. get the absolute path of the icon of the hardcoded provider
  391. :param provider: the name of the provider
  392. :param tenant_id: the id of the tenant
  393. :return: the absolute path of the icon, the mime type of the icon
  394. """
  395. # get provider
  396. provider_controller = cls.get_hardcoded_provider(provider)
  397. absolute_path = path.join(
  398. path.dirname(path.realpath(__file__)),
  399. "builtin_tool",
  400. "providers",
  401. provider,
  402. "_assets",
  403. provider_controller.entity.identity.icon,
  404. )
  405. # check if the icon exists
  406. if not path.exists(absolute_path):
  407. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  408. # get the mime type
  409. mime_type, _ = mimetypes.guess_type(absolute_path)
  410. mime_type = mime_type or "application/octet-stream"
  411. return absolute_path, mime_type
  412. @classmethod
  413. def list_hardcoded_providers(cls):
  414. # use cache first
  415. if cls._builtin_providers_loaded:
  416. yield from list(cls._hardcoded_providers.values())
  417. return
  418. with cls._builtin_provider_lock:
  419. if cls._builtin_providers_loaded:
  420. yield from list(cls._hardcoded_providers.values())
  421. return
  422. yield from cls._list_hardcoded_providers()
  423. @classmethod
  424. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  425. """
  426. list all the plugin providers
  427. """
  428. manager = PluginToolManager()
  429. provider_entities = manager.fetch_tool_providers(tenant_id)
  430. return [
  431. PluginToolProviderController(
  432. entity=provider.declaration,
  433. plugin_id=provider.plugin_id,
  434. plugin_unique_identifier=provider.plugin_unique_identifier,
  435. tenant_id=tenant_id,
  436. )
  437. for provider in provider_entities
  438. ]
  439. @classmethod
  440. def list_builtin_providers(
  441. cls, tenant_id: str
  442. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  443. """
  444. list all the builtin providers
  445. """
  446. yield from cls.list_hardcoded_providers()
  447. # get plugin providers
  448. yield from cls.list_plugin_providers(tenant_id)
  449. @classmethod
  450. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  451. """
  452. list all the builtin providers
  453. """
  454. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  455. if provider_path.startswith("__"):
  456. continue
  457. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  458. if provider_path.startswith("__"):
  459. continue
  460. # init provider
  461. try:
  462. provider_class = load_single_subclass_from_source(
  463. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  464. script_path=path.join(
  465. path.dirname(path.realpath(__file__)),
  466. "builtin_tool",
  467. "providers",
  468. provider_path,
  469. f"{provider_path}.py",
  470. ),
  471. parent_type=BuiltinToolProviderController,
  472. )
  473. provider: BuiltinToolProviderController = provider_class()
  474. cls._hardcoded_providers[provider.entity.identity.name] = provider
  475. for tool in provider.get_tools():
  476. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  477. yield provider
  478. except Exception as e:
  479. logger.exception(f"load builtin provider {provider}")
  480. continue
  481. # set builtin providers loaded
  482. cls._builtin_providers_loaded = True
  483. @classmethod
  484. def load_hardcoded_providers_cache(cls):
  485. for _ in cls.list_hardcoded_providers():
  486. pass
  487. @classmethod
  488. def clear_hardcoded_providers_cache(cls):
  489. cls._hardcoded_providers = {}
  490. cls._builtin_providers_loaded = False
  491. @classmethod
  492. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  493. """
  494. get the tool label
  495. :param tool_name: the name of the tool
  496. :return: the label of the tool
  497. """
  498. if len(cls._builtin_tools_labels) == 0:
  499. # init the builtin providers
  500. cls.load_hardcoded_providers_cache()
  501. if tool_name not in cls._builtin_tools_labels:
  502. return None
  503. return cls._builtin_tools_labels[tool_name]
  504. @classmethod
  505. def list_providers_from_api(
  506. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  507. ) -> list[ToolProviderApiEntity]:
  508. result_providers: dict[str, ToolProviderApiEntity] = {}
  509. filters = []
  510. if not typ:
  511. filters.extend(["builtin", "api", "workflow"])
  512. else:
  513. filters.append(typ)
  514. if "builtin" in filters:
  515. # get builtin providers
  516. builtin_providers = cls.list_builtin_providers(tenant_id)
  517. # get db builtin providers
  518. db_builtin_providers: list[BuiltinToolProvider] = (
  519. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  520. )
  521. # rewrite db_builtin_providers
  522. for db_provider in db_builtin_providers:
  523. tool_provider_id = GenericProviderID(db_provider.provider)
  524. db_provider.provider = tool_provider_id.to_string()
  525. find_db_builtin_provider = lambda provider: next(
  526. (x for x in db_builtin_providers if x.provider == provider), None
  527. )
  528. # append builtin providers
  529. for provider in builtin_providers:
  530. # handle include, exclude
  531. if is_filtered(
  532. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  533. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  534. data=provider,
  535. name_func=lambda x: x.identity.name,
  536. ):
  537. continue
  538. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  539. provider_controller=provider,
  540. db_provider=find_db_builtin_provider(provider.entity.identity.name),
  541. decrypt_credentials=False,
  542. )
  543. if isinstance(provider, PluginToolProviderController):
  544. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  545. else:
  546. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  547. # get db api providers
  548. if "api" in filters:
  549. db_api_providers: list[ApiToolProvider] = (
  550. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  551. )
  552. api_provider_controllers = [
  553. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  554. for provider in db_api_providers
  555. ]
  556. # get labels
  557. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  558. for api_provider_controller in api_provider_controllers:
  559. user_provider = ToolTransformService.api_provider_to_user_provider(
  560. provider_controller=api_provider_controller["controller"],
  561. db_provider=api_provider_controller["provider"],
  562. decrypt_credentials=False,
  563. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  564. )
  565. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  566. if "workflow" in filters:
  567. # get workflow providers
  568. workflow_providers: list[WorkflowToolProvider] = (
  569. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  570. )
  571. workflow_provider_controllers = []
  572. for provider in workflow_providers:
  573. try:
  574. workflow_provider_controllers.append(
  575. ToolTransformService.workflow_provider_to_controller(db_provider=provider)
  576. )
  577. except Exception as e:
  578. # app has been deleted
  579. pass
  580. labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
  581. for provider_controller in workflow_provider_controllers:
  582. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  583. provider_controller=provider_controller,
  584. labels=labels.get(provider_controller.provider_id, []),
  585. )
  586. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  587. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  588. @classmethod
  589. def get_api_provider_controller(
  590. cls, tenant_id: str, provider_id: str
  591. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  592. """
  593. get the api provider
  594. :param provider_name: the name of the provider
  595. :return: the provider controller, the credentials
  596. """
  597. provider: ApiToolProvider | None = (
  598. db.session.query(ApiToolProvider)
  599. .filter(
  600. ApiToolProvider.id == provider_id,
  601. ApiToolProvider.tenant_id == tenant_id,
  602. )
  603. .first()
  604. )
  605. if provider is None:
  606. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  607. controller = ApiToolProviderController.from_db(
  608. provider,
  609. ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  610. )
  611. controller.load_bundled_tools(provider.tools)
  612. return controller, provider.credentials
  613. @classmethod
  614. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  615. """
  616. get api provider
  617. """
  618. """
  619. get tool provider
  620. """
  621. provider_name = provider
  622. provider_obj: ApiToolProvider = (
  623. db.session.query(ApiToolProvider)
  624. .filter(
  625. ApiToolProvider.tenant_id == tenant_id,
  626. ApiToolProvider.name == provider,
  627. )
  628. .first()
  629. )
  630. if provider_obj is None:
  631. raise ValueError(f"you have not added provider {provider_name}")
  632. try:
  633. credentials = json.loads(provider_obj.credentials_str) or {}
  634. except:
  635. credentials = {}
  636. # package tool provider controller
  637. controller = ApiToolProviderController.from_db(
  638. provider_obj,
  639. ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  640. )
  641. # init tool configuration
  642. tool_configuration = ProviderConfigEncrypter(
  643. tenant_id=tenant_id,
  644. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  645. provider_type=controller.provider_type.value,
  646. provider_identity=controller.entity.identity.name,
  647. )
  648. decrypted_credentials = tool_configuration.decrypt(credentials)
  649. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  650. try:
  651. icon = json.loads(provider_obj.icon)
  652. except:
  653. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  654. # add tool labels
  655. labels = ToolLabelManager.get_tool_labels(controller)
  656. return jsonable_encoder(
  657. {
  658. "schema_type": provider_obj.schema_type,
  659. "schema": provider_obj.schema,
  660. "tools": provider_obj.tools,
  661. "icon": icon,
  662. "description": provider_obj.description,
  663. "credentials": masked_credentials,
  664. "privacy_policy": provider_obj.privacy_policy,
  665. "custom_disclaimer": provider_obj.custom_disclaimer,
  666. "labels": labels,
  667. }
  668. )
  669. @classmethod
  670. def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
  671. return (
  672. dify_config.CONSOLE_API_URL
  673. + "/console/api/workspaces/current/tool-provider/builtin/"
  674. + provider_id
  675. + "/icon"
  676. )
  677. @classmethod
  678. def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
  679. return str(
  680. URL(dify_config.CONSOLE_API_URL)
  681. / "console"
  682. / "api"
  683. / "workspaces"
  684. / "current"
  685. / "plugin"
  686. / "icon"
  687. % {"tenant_id": tenant_id, "filename": filename}
  688. )
  689. @classmethod
  690. def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
  691. try:
  692. workflow_provider: WorkflowToolProvider | None = (
  693. db.session.query(WorkflowToolProvider)
  694. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  695. .first()
  696. )
  697. if workflow_provider is None:
  698. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  699. return json.loads(workflow_provider.icon)
  700. except:
  701. return {"background": "#252525", "content": "\ud83d\ude01"}
  702. @classmethod
  703. def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
  704. try:
  705. api_provider: ApiToolProvider | None = (
  706. db.session.query(ApiToolProvider)
  707. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  708. .first()
  709. )
  710. if api_provider is None:
  711. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  712. return json.loads(api_provider.icon)
  713. except:
  714. return {"background": "#252525", "content": "\ud83d\ude01"}
  715. @classmethod
  716. def get_tool_icon(
  717. cls,
  718. tenant_id: str,
  719. provider_type: ToolProviderType,
  720. provider_id: str,
  721. ) -> Union[str, dict]:
  722. """
  723. get the tool icon
  724. :param tenant_id: the id of the tenant
  725. :param provider_type: the type of the provider
  726. :param provider_id: the id of the provider
  727. :return:
  728. """
  729. provider_type = provider_type
  730. provider_id = provider_id
  731. if provider_type == ToolProviderType.BUILT_IN:
  732. provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
  733. if isinstance(provider, PluginToolProviderController):
  734. try:
  735. return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
  736. except:
  737. return {"background": "#252525", "content": "\ud83d\ude01"}
  738. return cls.generate_builtin_tool_icon_url(provider_id)
  739. elif provider_type == ToolProviderType.API:
  740. return cls.generate_api_tool_icon_url(tenant_id, provider_id)
  741. elif provider_type == ToolProviderType.WORKFLOW:
  742. return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
  743. elif provider_type == ToolProviderType.PLUGIN:
  744. provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
  745. if isinstance(provider, PluginToolProviderController):
  746. try:
  747. return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
  748. except:
  749. return {"background": "#252525", "content": "\ud83d\ude01"}
  750. raise ValueError(f"plugin provider {provider_id} not found")
  751. else:
  752. raise ValueError(f"provider type {provider_type} not found")
# Eagerly scan the builtin providers directory and fill the class-level provider cache when
# this module is imported (presumably so the first request does not pay the scan cost).
ToolManager.load_hardcoded_providers_cache()