tool_manager.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785
  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Union, cast
  8. from core.plugin.manager.tool import PluginToolManager
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.plugin_tool.provider import PluginToolProviderController
  11. from core.tools.plugin_tool.tool import PluginTool
  12. if TYPE_CHECKING:
  13. from core.workflow.nodes.tool.entities import ToolEntity
  14. from configs import dify_config
  15. from core.agent.entities import AgentToolEntity
  16. from core.app.entities.app_invoke_entities import InvokeFrom
  17. from core.helper.module_import_helper import load_single_subclass_from_source
  18. from core.helper.position_helper import is_filtered
  19. from core.model_runtime.utils.encoders import jsonable_encoder
  20. from core.tools.__base.tool import Tool
  21. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  22. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  23. from core.tools.builtin_tool.tool import BuiltinTool
  24. from core.tools.custom_tool.provider import ApiToolProviderController
  25. from core.tools.custom_tool.tool import ApiTool
  26. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  27. from core.tools.entities.common_entities import I18nObject
  28. from core.tools.entities.tool_entities import (
  29. ApiProviderAuthType,
  30. ToolInvokeFrom,
  31. ToolParameter,
  32. ToolProviderID,
  33. ToolProviderType,
  34. )
  35. from core.tools.errors import ToolProviderNotFoundError
  36. from core.tools.tool_label_manager import ToolLabelManager
  37. from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
  38. from core.tools.utils.tool_parameter_converter import ToolParameterConverter
  39. from core.tools.workflow_as_tool.tool import WorkflowTool
  40. from extensions.ext_database import db
  41. from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
  42. from services.tools.tools_transform_service import ToolTransformService
  43. logger = logging.getLogger(__name__)
  44. class ToolManager:
  45. _builtin_provider_lock = Lock()
  46. _hardcoded_providers = {}
  47. _builtin_providers_loaded = False
  48. _builtin_tools_labels = {}
  49. @classmethod
  50. def get_builtin_provider(
  51. cls, provider: str, tenant_id: str
  52. ) -> BuiltinToolProviderController | PluginToolProviderController:
  53. """
  54. get the builtin provider
  55. :param provider: the name of the provider
  56. :param tenant_id: the id of the tenant
  57. :return: the provider
  58. """
  59. # split provider to
  60. if len(cls._hardcoded_providers) == 0:
  61. # init the builtin providers
  62. cls.load_hardcoded_providers_cache()
  63. if provider not in cls._hardcoded_providers:
  64. # get plugin provider
  65. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  66. if plugin_provider:
  67. return plugin_provider
  68. return cls._hardcoded_providers[provider]
  69. @classmethod
  70. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  71. """
  72. get the plugin provider
  73. """
  74. manager = PluginToolManager()
  75. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  76. if not provider_entity:
  77. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  78. return PluginToolProviderController(
  79. entity=provider_entity.declaration,
  80. plugin_id=provider_entity.plugin_id,
  81. tenant_id=tenant_id,
  82. )
  83. @classmethod
  84. def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
  85. """
  86. get the builtin tool
  87. :param provider: the name of the provider
  88. :param tool_name: the name of the tool
  89. :param tenant_id: the id of the tenant
  90. :return: the provider, the tool
  91. """
  92. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  93. tool = provider_controller.get_tool(tool_name)
  94. return tool
  95. @classmethod
  96. def get_tool_runtime(
  97. cls,
  98. provider_type: ToolProviderType,
  99. provider_id: str,
  100. tool_name: str,
  101. tenant_id: str,
  102. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  103. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  104. ) -> Union[BuiltinTool, ApiTool, WorkflowTool]:
  105. """
  106. get the tool runtime
  107. :param provider_type: the type of the provider
  108. :param provider_name: the name of the provider
  109. :param tool_name: the name of the tool
  110. :return: the tool
  111. """
  112. if provider_type == ToolProviderType.BUILT_IN:
  113. # check if the builtin tool need credentials
  114. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  115. builtin_tool = provider_controller.get_tool(tool_name)
  116. if not builtin_tool:
  117. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  118. if not provider_controller.need_credentials:
  119. return cast(
  120. BuiltinTool,
  121. builtin_tool.fork_tool_runtime(
  122. runtime=ToolRuntime(
  123. tenant_id=tenant_id,
  124. credentials={},
  125. invoke_from=invoke_from,
  126. tool_invoke_from=tool_invoke_from,
  127. )
  128. ),
  129. )
  130. if isinstance(provider_controller, PluginToolProviderController):
  131. provider_id_entity = ToolProviderID(provider_id)
  132. # get credentials
  133. builtin_provider: BuiltinToolProvider | None = (
  134. db.session.query(BuiltinToolProvider)
  135. .filter(
  136. BuiltinToolProvider.tenant_id == tenant_id,
  137. (BuiltinToolProvider.provider == provider_id)
  138. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  139. )
  140. .first()
  141. )
  142. if builtin_provider is None:
  143. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  144. else:
  145. builtin_provider: BuiltinToolProvider | None = (
  146. db.session.query(BuiltinToolProvider)
  147. .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  148. .first()
  149. )
  150. if builtin_provider is None:
  151. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  152. # decrypt the credentials
  153. credentials = builtin_provider.credentials
  154. tool_configuration = ProviderConfigEncrypter(
  155. tenant_id=tenant_id,
  156. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  157. provider_type=provider_controller.provider_type.value,
  158. provider_identity=provider_controller.entity.identity.name,
  159. )
  160. decrypted_credentials = tool_configuration.decrypt(credentials)
  161. return cast(
  162. BuiltinTool,
  163. builtin_tool.fork_tool_runtime(
  164. runtime=ToolRuntime(
  165. tenant_id=tenant_id,
  166. credentials=decrypted_credentials,
  167. runtime_parameters={},
  168. invoke_from=invoke_from,
  169. tool_invoke_from=tool_invoke_from,
  170. )
  171. ),
  172. )
  173. elif provider_type == ToolProviderType.API:
  174. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  175. # decrypt the credentials
  176. tool_configuration = ProviderConfigEncrypter(
  177. tenant_id=tenant_id,
  178. config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
  179. provider_type=api_provider.provider_type.value,
  180. provider_identity=api_provider.entity.identity.name,
  181. )
  182. decrypted_credentials = tool_configuration.decrypt(credentials)
  183. return cast(
  184. ApiTool,
  185. api_provider.get_tool(tool_name).fork_tool_runtime(
  186. runtime=ToolRuntime(
  187. tenant_id=tenant_id,
  188. credentials=decrypted_credentials,
  189. invoke_from=invoke_from,
  190. tool_invoke_from=tool_invoke_from,
  191. )
  192. ),
  193. )
  194. elif provider_type == ToolProviderType.WORKFLOW:
  195. workflow_provider = (
  196. db.session.query(WorkflowToolProvider)
  197. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  198. .first()
  199. )
  200. if workflow_provider is None:
  201. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  202. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  203. return cast(
  204. WorkflowTool,
  205. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  206. runtime=ToolRuntime(
  207. tenant_id=tenant_id,
  208. credentials={},
  209. invoke_from=invoke_from,
  210. tool_invoke_from=tool_invoke_from,
  211. )
  212. ),
  213. )
  214. elif provider_type == ToolProviderType.APP:
  215. raise NotImplementedError("app provider not implemented")
  216. else:
  217. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  218. @classmethod
  219. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
  220. """
  221. init runtime parameter
  222. """
  223. parameter_value = parameters.get(parameter_rule.name)
  224. if not parameter_value and parameter_value != 0:
  225. # get default value
  226. parameter_value = parameter_rule.default
  227. if not parameter_value and parameter_rule.required:
  228. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  229. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  230. # check if tool_parameter_config in options
  231. options = [x.value for x in parameter_rule.options]
  232. if parameter_value is not None and parameter_value not in options:
  233. raise ValueError(
  234. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
  235. )
  236. return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
  237. @classmethod
  238. def get_agent_tool_runtime(
  239. cls,
  240. tenant_id: str,
  241. app_id: str,
  242. agent_tool: AgentToolEntity,
  243. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  244. ) -> Tool:
  245. """
  246. get the agent tool runtime
  247. """
  248. tool_entity = cls.get_tool_runtime(
  249. provider_type=agent_tool.provider_type,
  250. provider_id=agent_tool.provider_id,
  251. tool_name=agent_tool.tool_name,
  252. tenant_id=tenant_id,
  253. invoke_from=invoke_from,
  254. tool_invoke_from=ToolInvokeFrom.AGENT,
  255. )
  256. runtime_parameters = {}
  257. parameters = tool_entity.get_merged_runtime_parameters()
  258. for parameter in parameters:
  259. # check file types
  260. if parameter.type == ToolParameter.ToolParameterType.FILE:
  261. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  262. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  263. # save tool parameter to tool entity memory
  264. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  265. runtime_parameters[parameter.name] = value
  266. # decrypt runtime parameters
  267. encryption_manager = ToolParameterConfigurationManager(
  268. tenant_id=tenant_id,
  269. tool_runtime=tool_entity,
  270. provider_name=agent_tool.provider_id,
  271. provider_type=agent_tool.provider_type,
  272. identity_id=f"AGENT.{app_id}",
  273. )
  274. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  275. if not tool_entity.runtime:
  276. raise Exception("tool missing runtime")
  277. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  278. return tool_entity
  279. @classmethod
  280. def get_workflow_tool_runtime(
  281. cls,
  282. tenant_id: str,
  283. app_id: str,
  284. node_id: str,
  285. workflow_tool: "ToolEntity",
  286. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  287. ) -> Tool:
  288. """
  289. get the workflow tool runtime
  290. """
  291. tool_runtime = cls.get_tool_runtime(
  292. provider_type=workflow_tool.provider_type,
  293. provider_id=workflow_tool.provider_id,
  294. tool_name=workflow_tool.tool_name,
  295. tenant_id=tenant_id,
  296. invoke_from=invoke_from,
  297. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  298. )
  299. runtime_parameters = {}
  300. parameters = tool_runtime.get_merged_runtime_parameters()
  301. for parameter in parameters:
  302. # save tool parameter to tool entity memory
  303. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  304. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  305. runtime_parameters[parameter.name] = value
  306. # decrypt runtime parameters
  307. encryption_manager = ToolParameterConfigurationManager(
  308. tenant_id=tenant_id,
  309. tool_runtime=tool_runtime,
  310. provider_name=workflow_tool.provider_id,
  311. provider_type=workflow_tool.provider_type,
  312. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  313. )
  314. if runtime_parameters:
  315. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  316. if not tool_runtime.runtime:
  317. raise Exception("tool missing runtime")
  318. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  319. return tool_runtime
  320. @classmethod
  321. def get_tool_runtime_from_plugin(
  322. cls,
  323. tool_type: ToolProviderType,
  324. tenant_id: str,
  325. provider: str,
  326. tool_name: str,
  327. tool_parameters: dict[str, Any],
  328. ) -> Tool:
  329. """
  330. get tool runtime from plugin
  331. """
  332. tool_entity = cls.get_tool_runtime(
  333. provider_type=tool_type,
  334. provider_id=provider,
  335. tool_name=tool_name,
  336. tenant_id=tenant_id,
  337. invoke_from=InvokeFrom.SERVICE_API,
  338. tool_invoke_from=ToolInvokeFrom.PLUGIN,
  339. )
  340. runtime_parameters = {}
  341. parameters = tool_entity.get_merged_runtime_parameters()
  342. for parameter in parameters:
  343. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  344. # save tool parameter to tool entity memory
  345. value = cls._init_runtime_parameter(parameter, tool_parameters)
  346. runtime_parameters[parameter.name] = value
  347. if not tool_entity.runtime:
  348. raise Exception("tool missing runtime")
  349. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  350. return tool_entity
  351. @classmethod
  352. def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
  353. """
  354. get the absolute path of the icon of the builtin provider
  355. :param provider: the name of the provider
  356. :param tenant_id: the id of the tenant
  357. :return: the absolute path of the icon, the mime type of the icon
  358. """
  359. # get provider
  360. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  361. absolute_path = path.join(
  362. path.dirname(path.realpath(__file__)),
  363. "builtin_tool",
  364. "providers",
  365. provider,
  366. "_assets",
  367. provider_controller.entity.identity.icon,
  368. )
  369. # check if the icon exists
  370. if not path.exists(absolute_path):
  371. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  372. # get the mime type
  373. mime_type, _ = mimetypes.guess_type(absolute_path)
  374. mime_type = mime_type or "application/octet-stream"
  375. return absolute_path, mime_type
  376. @classmethod
  377. def list_hardcoded_providers(cls):
  378. # use cache first
  379. if cls._builtin_providers_loaded:
  380. yield from list(cls._hardcoded_providers.values())
  381. return
  382. with cls._builtin_provider_lock:
  383. if cls._builtin_providers_loaded:
  384. yield from list(cls._hardcoded_providers.values())
  385. return
  386. yield from cls._list_hardcoded_providers()
  387. @classmethod
  388. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  389. """
  390. list all the plugin providers
  391. """
  392. manager = PluginToolManager()
  393. provider_entities = manager.fetch_tool_providers(tenant_id)
  394. return [
  395. PluginToolProviderController(
  396. entity=provider.declaration,
  397. plugin_id=provider.plugin_id,
  398. tenant_id=tenant_id,
  399. )
  400. for provider in provider_entities
  401. ]
  402. @classmethod
  403. def list_builtin_providers(
  404. cls, tenant_id: str
  405. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  406. """
  407. list all the builtin providers
  408. """
  409. yield from cls.list_hardcoded_providers()
  410. # get plugin providers
  411. yield from cls.list_plugin_providers(tenant_id)
  412. @classmethod
  413. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  414. """
  415. list all the builtin providers
  416. """
  417. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  418. if provider_path.startswith("__"):
  419. continue
  420. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  421. if provider_path.startswith("__"):
  422. continue
  423. # init provider
  424. try:
  425. provider_class = load_single_subclass_from_source(
  426. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  427. script_path=path.join(
  428. path.dirname(path.realpath(__file__)),
  429. "builtin_tool",
  430. "providers",
  431. provider_path,
  432. f"{provider_path}.py",
  433. ),
  434. parent_type=BuiltinToolProviderController,
  435. )
  436. provider: BuiltinToolProviderController = provider_class()
  437. cls._hardcoded_providers[provider.entity.identity.name] = provider
  438. for tool in provider.get_tools():
  439. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  440. yield provider
  441. except Exception as e:
  442. logger.error(f"load builtin provider error: {e}")
  443. continue
  444. # set builtin providers loaded
  445. cls._builtin_providers_loaded = True
  446. @classmethod
  447. def load_hardcoded_providers_cache(cls):
  448. for _ in cls.list_hardcoded_providers():
  449. pass
  450. @classmethod
  451. def clear_hardcoded_providers_cache(cls):
  452. cls._hardcoded_providers = {}
  453. cls._builtin_providers_loaded = False
  454. @classmethod
  455. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  456. """
  457. get the tool label
  458. :param tool_name: the name of the tool
  459. :return: the label of the tool
  460. """
  461. if len(cls._builtin_tools_labels) == 0:
  462. # init the builtin providers
  463. cls.load_hardcoded_providers_cache()
  464. if tool_name not in cls._builtin_tools_labels:
  465. return None
  466. return cls._builtin_tools_labels[tool_name]
  467. @classmethod
  468. def list_providers_from_api(
  469. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  470. ) -> list[ToolProviderApiEntity]:
  471. result_providers: dict[str, ToolProviderApiEntity] = {}
  472. filters = []
  473. if not typ:
  474. filters.extend(["builtin", "api", "workflow"])
  475. else:
  476. filters.append(typ)
  477. if "builtin" in filters:
  478. # get builtin providers
  479. builtin_providers = cls.list_builtin_providers(tenant_id)
  480. # get db builtin providers
  481. db_builtin_providers: list[BuiltinToolProvider] = (
  482. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  483. )
  484. # rewrite db_builtin_providers
  485. for db_provider in db_builtin_providers:
  486. try:
  487. ToolProviderID(db_provider.provider)
  488. except Exception:
  489. db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
  490. find_db_builtin_provider = lambda provider: next(
  491. (x for x in db_builtin_providers if x.provider == provider), None
  492. )
  493. # append builtin providers
  494. for provider in builtin_providers:
  495. # handle include, exclude
  496. if is_filtered(
  497. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  498. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  499. data=provider,
  500. name_func=lambda x: x.identity.name,
  501. ):
  502. continue
  503. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  504. provider_controller=provider,
  505. db_provider=find_db_builtin_provider(provider.entity.identity.name),
  506. decrypt_credentials=False,
  507. )
  508. if isinstance(provider, PluginToolProviderController):
  509. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  510. else:
  511. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  512. # get db api providers
  513. if "api" in filters:
  514. db_api_providers: list[ApiToolProvider] = (
  515. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  516. )
  517. api_provider_controllers = [
  518. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  519. for provider in db_api_providers
  520. ]
  521. # get labels
  522. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  523. for api_provider_controller in api_provider_controllers:
  524. user_provider = ToolTransformService.api_provider_to_user_provider(
  525. provider_controller=api_provider_controller["controller"],
  526. db_provider=api_provider_controller["provider"],
  527. decrypt_credentials=False,
  528. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  529. )
  530. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  531. if "workflow" in filters:
  532. # get workflow providers
  533. workflow_providers: list[WorkflowToolProvider] = (
  534. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  535. )
  536. workflow_provider_controllers = []
  537. for provider in workflow_providers:
  538. try:
  539. workflow_provider_controllers.append(
  540. ToolTransformService.workflow_provider_to_controller(db_provider=provider)
  541. )
  542. except Exception as e:
  543. # app has been deleted
  544. pass
  545. labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
  546. for provider_controller in workflow_provider_controllers:
  547. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  548. provider_controller=provider_controller,
  549. labels=labels.get(provider_controller.provider_id, []),
  550. )
  551. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  552. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  553. @classmethod
  554. def get_api_provider_controller(
  555. cls, tenant_id: str, provider_id: str
  556. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  557. """
  558. get the api provider
  559. :param provider_name: the name of the provider
  560. :return: the provider controller, the credentials
  561. """
  562. provider: ApiToolProvider | None = (
  563. db.session.query(ApiToolProvider)
  564. .filter(
  565. ApiToolProvider.id == provider_id,
  566. ApiToolProvider.tenant_id == tenant_id,
  567. )
  568. .first()
  569. )
  570. if provider is None:
  571. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  572. controller = ApiToolProviderController.from_db(
  573. provider,
  574. ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  575. )
  576. controller.load_bundled_tools(provider.tools)
  577. return controller, provider.credentials
  578. @classmethod
  579. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  580. """
  581. get api provider
  582. """
  583. provider_obj: ApiToolProvider | None = (
  584. db.session.query(ApiToolProvider)
  585. .filter(
  586. ApiToolProvider.tenant_id == tenant_id,
  587. ApiToolProvider.name == provider,
  588. )
  589. .first()
  590. )
  591. if provider_obj is None:
  592. raise ValueError(f"you have not added provider {provider}")
  593. try:
  594. credentials = json.loads(provider_obj.credentials_str) or {}
  595. except:
  596. credentials = {}
  597. # package tool provider controller
  598. controller = ApiToolProviderController.from_db(
  599. provider_obj,
  600. ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  601. )
  602. # init tool configuration
  603. tool_configuration = ProviderConfigEncrypter(
  604. tenant_id=tenant_id,
  605. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  606. provider_type=controller.provider_type.value,
  607. provider_identity=controller.entity.identity.name,
  608. )
  609. decrypted_credentials = tool_configuration.decrypt(credentials)
  610. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  611. try:
  612. icon = json.loads(provider_obj.icon)
  613. except:
  614. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  615. # add tool labels
  616. labels = ToolLabelManager.get_tool_labels(controller)
  617. return jsonable_encoder(
  618. {
  619. "schema_type": provider_obj.schema_type,
  620. "schema": provider_obj.schema,
  621. "tools": provider_obj.tools,
  622. "icon": icon,
  623. "description": provider_obj.description,
  624. "credentials": masked_credentials,
  625. "privacy_policy": provider_obj.privacy_policy,
  626. "custom_disclaimer": provider_obj.custom_disclaimer,
  627. "labels": labels,
  628. }
  629. )
  630. @classmethod
  631. def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
  632. """
  633. get the tool icon
  634. :param tenant_id: the id of the tenant
  635. :param provider_type: the type of the provider
  636. :param provider_id: the id of the provider
  637. :return:
  638. """
  639. provider_type = provider_type
  640. provider_id = provider_id
  641. if provider_type == ToolProviderType.BUILT_IN:
  642. return (
  643. dify_config.CONSOLE_API_URL
  644. + "/console/api/workspaces/current/tool-provider/builtin/"
  645. + provider_id
  646. + "/icon"
  647. )
  648. elif provider_type == ToolProviderType.API:
  649. try:
  650. api_provider: ApiToolProvider | None = (
  651. db.session.query(ApiToolProvider)
  652. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  653. .first()
  654. )
  655. if not api_provider:
  656. raise ValueError("api tool not found")
  657. return json.loads(api_provider.icon)
  658. except:
  659. return {"background": "#252525", "content": "\ud83d\ude01"}
  660. elif provider_type == ToolProviderType.WORKFLOW:
  661. workflow_provider: WorkflowToolProvider | None = (
  662. db.session.query(WorkflowToolProvider)
  663. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  664. .first()
  665. )
  666. if workflow_provider is None:
  667. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  668. return json.loads(workflow_provider.icon)
  669. else:
  670. raise ValueError(f"provider type {provider_type} not found")
  671. ToolManager.load_hardcoded_providers_cache()