tool_manager.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806
  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Union, cast
  8. from core.plugin.manager.tool import PluginToolManager
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.plugin_tool.provider import PluginToolProviderController
  11. from core.tools.plugin_tool.tool import PluginTool
  12. if TYPE_CHECKING:
  13. from core.workflow.nodes.tool.entities import ToolEntity
  14. from configs import dify_config
  15. from core.agent.entities import AgentToolEntity
  16. from core.app.entities.app_invoke_entities import InvokeFrom
  17. from core.helper.module_import_helper import load_single_subclass_from_source
  18. from core.helper.position_helper import is_filtered
  19. from core.model_runtime.utils.encoders import jsonable_encoder
  20. from core.tools.__base.tool import Tool
  21. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  22. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  23. from core.tools.builtin_tool.tool import BuiltinTool
  24. from core.tools.custom_tool.provider import ApiToolProviderController
  25. from core.tools.custom_tool.tool import ApiTool
  26. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  27. from core.tools.entities.common_entities import I18nObject
  28. from core.tools.entities.tool_entities import (
  29. ApiProviderAuthType,
  30. ToolInvokeFrom,
  31. ToolParameter,
  32. ToolProviderID,
  33. ToolProviderType,
  34. )
  35. from core.tools.errors import ToolProviderNotFoundError
  36. from core.tools.tool_label_manager import ToolLabelManager
  37. from core.tools.utils.configuration import (
  38. ProviderConfigEncrypter,
  39. ToolParameterConfigurationManager,
  40. )
  41. from core.tools.workflow_as_tool.tool import WorkflowTool
  42. from extensions.ext_database import db
  43. from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
  44. from services.tools.tools_transform_service import ToolTransformService
  45. logger = logging.getLogger(__name__)
  46. class ToolManager:
  47. _builtin_provider_lock = Lock()
  48. _hardcoded_providers = {}
  49. _builtin_providers_loaded = False
  50. _builtin_tools_labels = {}
  51. @classmethod
  52. def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
  53. """
  54. get the hardcoded provider
  55. """
  56. if len(cls._hardcoded_providers) == 0:
  57. # init the builtin providers
  58. cls.load_hardcoded_providers_cache()
  59. return cls._hardcoded_providers[provider]
  60. @classmethod
  61. def get_builtin_provider(
  62. cls, provider: str, tenant_id: str
  63. ) -> BuiltinToolProviderController | PluginToolProviderController:
  64. """
  65. get the builtin provider
  66. :param provider: the name of the provider
  67. :param tenant_id: the id of the tenant
  68. :return: the provider
  69. """
  70. # split provider to
  71. if len(cls._hardcoded_providers) == 0:
  72. # init the builtin providers
  73. cls.load_hardcoded_providers_cache()
  74. if provider not in cls._hardcoded_providers:
  75. # get plugin provider
  76. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  77. if plugin_provider:
  78. return plugin_provider
  79. return cls._hardcoded_providers[provider]
  80. @classmethod
  81. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  82. """
  83. get the plugin provider
  84. """
  85. manager = PluginToolManager()
  86. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  87. if not provider_entity:
  88. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  89. return PluginToolProviderController(
  90. entity=provider_entity.declaration,
  91. plugin_id=provider_entity.plugin_id,
  92. tenant_id=tenant_id,
  93. )
  94. @classmethod
  95. def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
  96. """
  97. get the builtin tool
  98. :param provider: the name of the provider
  99. :param tool_name: the name of the tool
  100. :param tenant_id: the id of the tenant
  101. :return: the provider, the tool
  102. """
  103. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  104. tool = provider_controller.get_tool(tool_name)
  105. return tool
  106. @classmethod
  107. def get_tool_runtime(
  108. cls,
  109. provider_type: ToolProviderType,
  110. provider_id: str,
  111. tool_name: str,
  112. tenant_id: str,
  113. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  114. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  115. ) -> Union[BuiltinTool, ApiTool, WorkflowTool]:
  116. """
  117. get the tool runtime
  118. :param provider_type: the type of the provider
  119. :param provider_name: the name of the provider
  120. :param tool_name: the name of the tool
  121. :return: the tool
  122. """
  123. if provider_type == ToolProviderType.BUILT_IN:
  124. # check if the builtin tool need credentials
  125. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  126. builtin_tool = provider_controller.get_tool(tool_name)
  127. if not builtin_tool:
  128. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  129. if not provider_controller.need_credentials:
  130. return cast(
  131. BuiltinTool,
  132. builtin_tool.fork_tool_runtime(
  133. runtime=ToolRuntime(
  134. tenant_id=tenant_id,
  135. credentials={},
  136. invoke_from=invoke_from,
  137. tool_invoke_from=tool_invoke_from,
  138. )
  139. ),
  140. )
  141. if isinstance(provider_controller, PluginToolProviderController):
  142. provider_id_entity = ToolProviderID(provider_id)
  143. # get credentials
  144. builtin_provider: BuiltinToolProvider | None = (
  145. db.session.query(BuiltinToolProvider)
  146. .filter(
  147. BuiltinToolProvider.tenant_id == tenant_id,
  148. (BuiltinToolProvider.provider == provider_id)
  149. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  150. )
  151. .first()
  152. )
  153. if builtin_provider is None:
  154. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  155. else:
  156. builtin_provider: BuiltinToolProvider | None = (
  157. db.session.query(BuiltinToolProvider)
  158. .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  159. .first()
  160. )
  161. if builtin_provider is None:
  162. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  163. # decrypt the credentials
  164. credentials = builtin_provider.credentials
  165. tool_configuration = ProviderConfigEncrypter(
  166. tenant_id=tenant_id,
  167. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  168. provider_type=provider_controller.provider_type.value,
  169. provider_identity=provider_controller.entity.identity.name,
  170. )
  171. decrypted_credentials = tool_configuration.decrypt(credentials)
  172. return cast(
  173. BuiltinTool,
  174. builtin_tool.fork_tool_runtime(
  175. runtime=ToolRuntime(
  176. tenant_id=tenant_id,
  177. credentials=decrypted_credentials,
  178. runtime_parameters={},
  179. invoke_from=invoke_from,
  180. tool_invoke_from=tool_invoke_from,
  181. )
  182. ),
  183. )
  184. elif provider_type == ToolProviderType.API:
  185. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  186. # decrypt the credentials
  187. tool_configuration = ProviderConfigEncrypter(
  188. tenant_id=tenant_id,
  189. config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
  190. provider_type=api_provider.provider_type.value,
  191. provider_identity=api_provider.entity.identity.name,
  192. )
  193. decrypted_credentials = tool_configuration.decrypt(credentials)
  194. return cast(
  195. ApiTool,
  196. api_provider.get_tool(tool_name).fork_tool_runtime(
  197. runtime=ToolRuntime(
  198. tenant_id=tenant_id,
  199. credentials=decrypted_credentials,
  200. invoke_from=invoke_from,
  201. tool_invoke_from=tool_invoke_from,
  202. )
  203. ),
  204. )
  205. elif provider_type == ToolProviderType.WORKFLOW:
  206. workflow_provider = (
  207. db.session.query(WorkflowToolProvider)
  208. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  209. .first()
  210. )
  211. if workflow_provider is None:
  212. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  213. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  214. return cast(
  215. WorkflowTool,
  216. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  217. runtime=ToolRuntime(
  218. tenant_id=tenant_id,
  219. credentials={},
  220. invoke_from=invoke_from,
  221. tool_invoke_from=tool_invoke_from,
  222. )
  223. ),
  224. )
  225. elif provider_type == ToolProviderType.APP:
  226. raise NotImplementedError("app provider not implemented")
  227. else:
  228. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  229. @classmethod
  230. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
  231. """
  232. init runtime parameter
  233. """
  234. parameter_value = parameters.get(parameter_rule.name)
  235. if not parameter_value and parameter_value != 0:
  236. # get default value
  237. parameter_value = parameter_rule.default
  238. if not parameter_value and parameter_rule.required:
  239. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  240. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  241. # check if tool_parameter_config in options
  242. options = [x.value for x in parameter_rule.options]
  243. if parameter_value is not None and parameter_value not in options:
  244. raise ValueError(
  245. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
  246. )
  247. return parameter_rule.type.cast_value(parameter_value)
  248. @classmethod
  249. def get_agent_tool_runtime(
  250. cls,
  251. tenant_id: str,
  252. app_id: str,
  253. agent_tool: AgentToolEntity,
  254. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  255. ) -> Tool:
  256. """
  257. get the agent tool runtime
  258. """
  259. tool_entity = cls.get_tool_runtime(
  260. provider_type=agent_tool.provider_type,
  261. provider_id=agent_tool.provider_id,
  262. tool_name=agent_tool.tool_name,
  263. tenant_id=tenant_id,
  264. invoke_from=invoke_from,
  265. tool_invoke_from=ToolInvokeFrom.AGENT,
  266. )
  267. runtime_parameters = {}
  268. parameters = tool_entity.get_merged_runtime_parameters()
  269. for parameter in parameters:
  270. # check file types
  271. if (
  272. parameter.type
  273. in {
  274. ToolParameter.ToolParameterType.SYSTEM_FILES,
  275. ToolParameter.ToolParameterType.FILE,
  276. ToolParameter.ToolParameterType.FILES,
  277. }
  278. and parameter.required
  279. ):
  280. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  281. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  282. # save tool parameter to tool entity memory
  283. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  284. runtime_parameters[parameter.name] = value
  285. # decrypt runtime parameters
  286. encryption_manager = ToolParameterConfigurationManager(
  287. tenant_id=tenant_id,
  288. tool_runtime=tool_entity,
  289. provider_name=agent_tool.provider_id,
  290. provider_type=agent_tool.provider_type,
  291. identity_id=f"AGENT.{app_id}",
  292. )
  293. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  294. if not tool_entity.runtime:
  295. raise Exception("tool missing runtime")
  296. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  297. return tool_entity
  298. @classmethod
  299. def get_workflow_tool_runtime(
  300. cls,
  301. tenant_id: str,
  302. app_id: str,
  303. node_id: str,
  304. workflow_tool: "ToolEntity",
  305. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  306. ) -> Tool:
  307. """
  308. get the workflow tool runtime
  309. """
  310. tool_runtime = cls.get_tool_runtime(
  311. provider_type=workflow_tool.provider_type,
  312. provider_id=workflow_tool.provider_id,
  313. tool_name=workflow_tool.tool_name,
  314. tenant_id=tenant_id,
  315. invoke_from=invoke_from,
  316. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  317. )
  318. runtime_parameters = {}
  319. parameters = tool_runtime.get_merged_runtime_parameters()
  320. for parameter in parameters:
  321. # save tool parameter to tool entity memory
  322. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  323. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  324. runtime_parameters[parameter.name] = value
  325. # decrypt runtime parameters
  326. encryption_manager = ToolParameterConfigurationManager(
  327. tenant_id=tenant_id,
  328. tool_runtime=tool_runtime,
  329. provider_name=workflow_tool.provider_id,
  330. provider_type=workflow_tool.provider_type,
  331. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  332. )
  333. if runtime_parameters:
  334. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  335. if not tool_runtime.runtime:
  336. raise Exception("tool missing runtime")
  337. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  338. return tool_runtime
  339. @classmethod
  340. def get_tool_runtime_from_plugin(
  341. cls,
  342. tool_type: ToolProviderType,
  343. tenant_id: str,
  344. provider: str,
  345. tool_name: str,
  346. tool_parameters: dict[str, Any],
  347. ) -> Tool:
  348. """
  349. get tool runtime from plugin
  350. """
  351. tool_entity = cls.get_tool_runtime(
  352. provider_type=tool_type,
  353. provider_id=provider,
  354. tool_name=tool_name,
  355. tenant_id=tenant_id,
  356. invoke_from=InvokeFrom.SERVICE_API,
  357. tool_invoke_from=ToolInvokeFrom.PLUGIN,
  358. )
  359. runtime_parameters = {}
  360. parameters = tool_entity.get_merged_runtime_parameters()
  361. for parameter in parameters:
  362. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  363. # save tool parameter to tool entity memory
  364. value = cls._init_runtime_parameter(parameter, tool_parameters)
  365. runtime_parameters[parameter.name] = value
  366. if not tool_entity.runtime:
  367. raise Exception("tool missing runtime")
  368. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  369. return tool_entity
  370. @classmethod
  371. def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
  372. """
  373. get the absolute path of the icon of the hardcoded provider
  374. :param provider: the name of the provider
  375. :param tenant_id: the id of the tenant
  376. :return: the absolute path of the icon, the mime type of the icon
  377. """
  378. # get provider
  379. provider_controller = cls.get_hardcoded_provider(provider)
  380. absolute_path = path.join(
  381. path.dirname(path.realpath(__file__)),
  382. "builtin_tool",
  383. "providers",
  384. provider,
  385. "_assets",
  386. provider_controller.entity.identity.icon,
  387. )
  388. # check if the icon exists
  389. if not path.exists(absolute_path):
  390. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  391. # get the mime type
  392. mime_type, _ = mimetypes.guess_type(absolute_path)
  393. mime_type = mime_type or "application/octet-stream"
  394. return absolute_path, mime_type
  395. @classmethod
  396. def list_hardcoded_providers(cls):
  397. # use cache first
  398. if cls._builtin_providers_loaded:
  399. yield from list(cls._hardcoded_providers.values())
  400. return
  401. with cls._builtin_provider_lock:
  402. if cls._builtin_providers_loaded:
  403. yield from list(cls._hardcoded_providers.values())
  404. return
  405. yield from cls._list_hardcoded_providers()
  406. @classmethod
  407. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  408. """
  409. list all the plugin providers
  410. """
  411. manager = PluginToolManager()
  412. provider_entities = manager.fetch_tool_providers(tenant_id)
  413. return [
  414. PluginToolProviderController(
  415. entity=provider.declaration,
  416. plugin_id=provider.plugin_id,
  417. tenant_id=tenant_id,
  418. )
  419. for provider in provider_entities
  420. ]
  421. @classmethod
  422. def list_builtin_providers(
  423. cls, tenant_id: str
  424. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  425. """
  426. list all the builtin providers
  427. """
  428. yield from cls.list_hardcoded_providers()
  429. # get plugin providers
  430. yield from cls.list_plugin_providers(tenant_id)
  431. @classmethod
  432. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  433. """
  434. list all the builtin providers
  435. """
  436. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  437. if provider_path.startswith("__"):
  438. continue
  439. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  440. if provider_path.startswith("__"):
  441. continue
  442. # init provider
  443. try:
  444. provider_class = load_single_subclass_from_source(
  445. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  446. script_path=path.join(
  447. path.dirname(path.realpath(__file__)),
  448. "builtin_tool",
  449. "providers",
  450. provider_path,
  451. f"{provider_path}.py",
  452. ),
  453. parent_type=BuiltinToolProviderController,
  454. )
  455. provider: BuiltinToolProviderController = provider_class()
  456. cls._hardcoded_providers[provider.entity.identity.name] = provider
  457. for tool in provider.get_tools():
  458. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  459. yield provider
  460. except Exception as e:
  461. logger.exception(f"load builtin provider {provider} error: {e}")
  462. continue
  463. # set builtin providers loaded
  464. cls._builtin_providers_loaded = True
  465. @classmethod
  466. def load_hardcoded_providers_cache(cls):
  467. for _ in cls.list_hardcoded_providers():
  468. pass
  469. @classmethod
  470. def clear_hardcoded_providers_cache(cls):
  471. cls._hardcoded_providers = {}
  472. cls._builtin_providers_loaded = False
  473. @classmethod
  474. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  475. """
  476. get the tool label
  477. :param tool_name: the name of the tool
  478. :return: the label of the tool
  479. """
  480. if len(cls._builtin_tools_labels) == 0:
  481. # init the builtin providers
  482. cls.load_hardcoded_providers_cache()
  483. if tool_name not in cls._builtin_tools_labels:
  484. return None
  485. return cls._builtin_tools_labels[tool_name]
  486. @classmethod
  487. def list_providers_from_api(
  488. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  489. ) -> list[ToolProviderApiEntity]:
  490. result_providers: dict[str, ToolProviderApiEntity] = {}
  491. filters = []
  492. if not typ:
  493. filters.extend(["builtin", "api", "workflow"])
  494. else:
  495. filters.append(typ)
  496. if "builtin" in filters:
  497. # get builtin providers
  498. builtin_providers = cls.list_builtin_providers(tenant_id)
  499. # get db builtin providers
  500. db_builtin_providers: list[BuiltinToolProvider] = (
  501. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  502. )
  503. # rewrite db_builtin_providers
  504. for db_provider in db_builtin_providers:
  505. try:
  506. ToolProviderID(db_provider.provider)
  507. except Exception:
  508. db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
  509. find_db_builtin_provider = lambda provider: next(
  510. (x for x in db_builtin_providers if x.provider == provider), None
  511. )
  512. # append builtin providers
  513. for provider in builtin_providers:
  514. # handle include, exclude
  515. if is_filtered(
  516. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  517. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  518. data=provider,
  519. name_func=lambda x: x.identity.name,
  520. ):
  521. continue
  522. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  523. provider_controller=provider,
  524. db_provider=find_db_builtin_provider(provider.entity.identity.name),
  525. decrypt_credentials=False,
  526. )
  527. if isinstance(provider, PluginToolProviderController):
  528. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  529. else:
  530. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  531. # get db api providers
  532. if "api" in filters:
  533. db_api_providers: list[ApiToolProvider] = (
  534. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  535. )
  536. api_provider_controllers = [
  537. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  538. for provider in db_api_providers
  539. ]
  540. # get labels
  541. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  542. for api_provider_controller in api_provider_controllers:
  543. user_provider = ToolTransformService.api_provider_to_user_provider(
  544. provider_controller=api_provider_controller["controller"],
  545. db_provider=api_provider_controller["provider"],
  546. decrypt_credentials=False,
  547. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  548. )
  549. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  550. if "workflow" in filters:
  551. # get workflow providers
  552. workflow_providers: list[WorkflowToolProvider] = (
  553. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  554. )
  555. workflow_provider_controllers = []
  556. for provider in workflow_providers:
  557. try:
  558. workflow_provider_controllers.append(
  559. ToolTransformService.workflow_provider_to_controller(db_provider=provider)
  560. )
  561. except Exception as e:
  562. # app has been deleted
  563. pass
  564. labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
  565. for provider_controller in workflow_provider_controllers:
  566. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  567. provider_controller=provider_controller,
  568. labels=labels.get(provider_controller.provider_id, []),
  569. )
  570. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  571. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  572. @classmethod
  573. def get_api_provider_controller(
  574. cls, tenant_id: str, provider_id: str
  575. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  576. """
  577. get the api provider
  578. :param provider_name: the name of the provider
  579. :return: the provider controller, the credentials
  580. """
  581. provider: ApiToolProvider | None = (
  582. db.session.query(ApiToolProvider)
  583. .filter(
  584. ApiToolProvider.id == provider_id,
  585. ApiToolProvider.tenant_id == tenant_id,
  586. )
  587. .first()
  588. )
  589. if provider is None:
  590. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  591. controller = ApiToolProviderController.from_db(
  592. provider,
  593. ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  594. )
  595. controller.load_bundled_tools(provider.tools)
  596. return controller, provider.credentials
  597. @classmethod
  598. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  599. """
  600. get api provider
  601. """
  602. provider_obj: ApiToolProvider | None = (
  603. db.session.query(ApiToolProvider)
  604. .filter(
  605. ApiToolProvider.tenant_id == tenant_id,
  606. ApiToolProvider.name == provider,
  607. )
  608. .first()
  609. )
  610. if provider_obj is None:
  611. raise ValueError(f"you have not added provider {provider}")
  612. try:
  613. credentials = json.loads(provider_obj.credentials_str) or {}
  614. except:
  615. credentials = {}
  616. # package tool provider controller
  617. controller = ApiToolProviderController.from_db(
  618. provider_obj,
  619. ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  620. )
  621. # init tool configuration
  622. tool_configuration = ProviderConfigEncrypter(
  623. tenant_id=tenant_id,
  624. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  625. provider_type=controller.provider_type.value,
  626. provider_identity=controller.entity.identity.name,
  627. )
  628. decrypted_credentials = tool_configuration.decrypt(credentials)
  629. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  630. try:
  631. icon = json.loads(provider_obj.icon)
  632. except:
  633. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  634. # add tool labels
  635. labels = ToolLabelManager.get_tool_labels(controller)
  636. return jsonable_encoder(
  637. {
  638. "schema_type": provider_obj.schema_type,
  639. "schema": provider_obj.schema,
  640. "tools": provider_obj.tools,
  641. "icon": icon,
  642. "description": provider_obj.description,
  643. "credentials": masked_credentials,
  644. "privacy_policy": provider_obj.privacy_policy,
  645. "custom_disclaimer": provider_obj.custom_disclaimer,
  646. "labels": labels,
  647. }
  648. )
  649. @classmethod
  650. def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
  651. """
  652. get the tool icon
  653. :param tenant_id: the id of the tenant
  654. :param provider_type: the type of the provider
  655. :param provider_id: the id of the provider
  656. :return:
  657. """
  658. provider_type = provider_type
  659. provider_id = provider_id
  660. if provider_type == ToolProviderType.BUILT_IN:
  661. return (
  662. dify_config.CONSOLE_API_URL
  663. + "/console/api/workspaces/current/tool-provider/builtin/"
  664. + provider_id
  665. + "/icon"
  666. )
  667. elif provider_type == ToolProviderType.API:
  668. try:
  669. api_provider: ApiToolProvider | None = (
  670. db.session.query(ApiToolProvider)
  671. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  672. .first()
  673. )
  674. if not api_provider:
  675. raise ValueError("api tool not found")
  676. return json.loads(api_provider.icon)
  677. except:
  678. return {"background": "#252525", "content": "\ud83d\ude01"}
  679. elif provider_type == ToolProviderType.WORKFLOW:
  680. workflow_provider: WorkflowToolProvider | None = (
  681. db.session.query(WorkflowToolProvider)
  682. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  683. .first()
  684. )
  685. if workflow_provider is None:
  686. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  687. return json.loads(workflow_provider.icon)
  688. else:
  689. raise ValueError(f"provider type {provider_type} not found")
# Import-time side effect: run hardcoded (builtin) provider discovery once when this
# module is first imported, filling ToolManager's class-level provider and label caches.
ToolManager.load_hardcoded_providers_cache()