tool_manager.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808
  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Union, cast
  8. from core.plugin.entities.plugin import GenericProviderID
  9. from core.plugin.manager.tool import PluginToolManager
  10. from core.tools.__base.tool_runtime import ToolRuntime
  11. from core.tools.plugin_tool.provider import PluginToolProviderController
  12. from core.tools.plugin_tool.tool import PluginTool
  13. if TYPE_CHECKING:
  14. from core.workflow.nodes.tool.entities import ToolEntity
  15. from configs import dify_config
  16. from core.agent.entities import AgentToolEntity
  17. from core.app.entities.app_invoke_entities import InvokeFrom
  18. from core.helper.module_import_helper import load_single_subclass_from_source
  19. from core.helper.position_helper import is_filtered
  20. from core.model_runtime.utils.encoders import jsonable_encoder
  21. from core.tools.__base.tool import Tool
  22. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  23. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  24. from core.tools.builtin_tool.tool import BuiltinTool
  25. from core.tools.custom_tool.provider import ApiToolProviderController
  26. from core.tools.custom_tool.tool import ApiTool
  27. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  28. from core.tools.entities.common_entities import I18nObject
  29. from core.tools.entities.tool_entities import (
  30. ApiProviderAuthType,
  31. ToolInvokeFrom,
  32. ToolParameter,
  33. ToolProviderType,
  34. )
  35. from core.tools.errors import ToolProviderNotFoundError
  36. from core.tools.tool_label_manager import ToolLabelManager
  37. from core.tools.utils.configuration import (
  38. ProviderConfigEncrypter,
  39. ToolParameterConfigurationManager,
  40. )
  41. from core.tools.workflow_as_tool.tool import WorkflowTool
  42. from extensions.ext_database import db
  43. from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
  44. from services.tools.tools_transform_service import ToolTransformService
  45. logger = logging.getLogger(__name__)
  46. class ToolManager:
  47. _builtin_provider_lock = Lock()
  48. _hardcoded_providers = {}
  49. _builtin_providers_loaded = False
  50. _builtin_tools_labels = {}
  51. @classmethod
  52. def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
  53. """
  54. get the hardcoded provider
  55. """
  56. if len(cls._hardcoded_providers) == 0:
  57. # init the builtin providers
  58. cls.load_hardcoded_providers_cache()
  59. return cls._hardcoded_providers[provider]
  60. @classmethod
  61. def get_builtin_provider(
  62. cls, provider: str, tenant_id: str
  63. ) -> BuiltinToolProviderController | PluginToolProviderController:
  64. """
  65. get the builtin provider
  66. :param provider: the name of the provider
  67. :param tenant_id: the id of the tenant
  68. :return: the provider
  69. """
  70. # split provider to
  71. if len(cls._hardcoded_providers) == 0:
  72. # init the builtin providers
  73. cls.load_hardcoded_providers_cache()
  74. if provider not in cls._hardcoded_providers:
  75. # get plugin provider
  76. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  77. if plugin_provider:
  78. return plugin_provider
  79. return cls._hardcoded_providers[provider]
  80. @classmethod
  81. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  82. """
  83. get the plugin provider
  84. """
  85. manager = PluginToolManager()
  86. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  87. if not provider_entity:
  88. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  89. return PluginToolProviderController(
  90. entity=provider_entity.declaration,
  91. plugin_id=provider_entity.plugin_id,
  92. tenant_id=tenant_id,
  93. )
  94. @classmethod
  95. def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
  96. """
  97. get the builtin tool
  98. :param provider: the name of the provider
  99. :param tool_name: the name of the tool
  100. :param tenant_id: the id of the tenant
  101. :return: the provider, the tool
  102. """
  103. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  104. tool = provider_controller.get_tool(tool_name)
  105. return tool
  106. @classmethod
  107. def get_tool_runtime(
  108. cls,
  109. provider_type: ToolProviderType,
  110. provider_id: str,
  111. tool_name: str,
  112. tenant_id: str,
  113. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  114. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  115. ) -> Union[BuiltinTool, ApiTool, WorkflowTool]:
  116. """
  117. get the tool runtime
  118. :param provider_type: the type of the provider
  119. :param provider_name: the name of the provider
  120. :param tool_name: the name of the tool
  121. :return: the tool
  122. """
  123. if provider_type == ToolProviderType.BUILT_IN:
  124. # check if the builtin tool need credentials
  125. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  126. builtin_tool = provider_controller.get_tool(tool_name)
  127. if not builtin_tool:
  128. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  129. if not provider_controller.need_credentials:
  130. return cast(
  131. BuiltinTool,
  132. builtin_tool.fork_tool_runtime(
  133. runtime=ToolRuntime(
  134. tenant_id=tenant_id,
  135. credentials={},
  136. invoke_from=invoke_from,
  137. tool_invoke_from=tool_invoke_from,
  138. )
  139. ),
  140. )
  141. if isinstance(provider_controller, PluginToolProviderController):
  142. provider_id_entity = GenericProviderID(provider_id)
  143. # get credentials
  144. builtin_provider: BuiltinToolProvider | None = (
  145. db.session.query(BuiltinToolProvider)
  146. .filter(
  147. BuiltinToolProvider.tenant_id == tenant_id,
  148. (BuiltinToolProvider.provider == provider_id)
  149. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  150. )
  151. .first()
  152. )
  153. if builtin_provider is None:
  154. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  155. else:
  156. builtin_provider: BuiltinToolProvider | None = (
  157. db.session.query(BuiltinToolProvider)
  158. .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  159. .first()
  160. )
  161. if builtin_provider is None:
  162. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  163. # decrypt the credentials
  164. credentials = builtin_provider.credentials
  165. tool_configuration = ProviderConfigEncrypter(
  166. tenant_id=tenant_id,
  167. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  168. provider_type=provider_controller.provider_type.value,
  169. provider_identity=provider_controller.entity.identity.name,
  170. )
  171. decrypted_credentials = tool_configuration.decrypt(credentials)
  172. return cast(
  173. BuiltinTool,
  174. builtin_tool.fork_tool_runtime(
  175. runtime=ToolRuntime(
  176. tenant_id=tenant_id,
  177. credentials=decrypted_credentials,
  178. runtime_parameters={},
  179. invoke_from=invoke_from,
  180. tool_invoke_from=tool_invoke_from,
  181. )
  182. ),
  183. )
  184. elif provider_type == ToolProviderType.API:
  185. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  186. # decrypt the credentials
  187. tool_configuration = ProviderConfigEncrypter(
  188. tenant_id=tenant_id,
  189. config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
  190. provider_type=api_provider.provider_type.value,
  191. provider_identity=api_provider.entity.identity.name,
  192. )
  193. decrypted_credentials = tool_configuration.decrypt(credentials)
  194. return cast(
  195. ApiTool,
  196. api_provider.get_tool(tool_name).fork_tool_runtime(
  197. runtime=ToolRuntime(
  198. tenant_id=tenant_id,
  199. credentials=decrypted_credentials,
  200. invoke_from=invoke_from,
  201. tool_invoke_from=tool_invoke_from,
  202. )
  203. ),
  204. )
  205. elif provider_type == ToolProviderType.WORKFLOW:
  206. workflow_provider = (
  207. db.session.query(WorkflowToolProvider)
  208. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  209. .first()
  210. )
  211. if workflow_provider is None:
  212. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  213. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  214. return cast(
  215. WorkflowTool,
  216. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  217. runtime=ToolRuntime(
  218. tenant_id=tenant_id,
  219. credentials={},
  220. invoke_from=invoke_from,
  221. tool_invoke_from=tool_invoke_from,
  222. )
  223. ),
  224. )
  225. elif provider_type == ToolProviderType.APP:
  226. raise NotImplementedError("app provider not implemented")
  227. else:
  228. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  229. @classmethod
  230. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
  231. """
  232. init runtime parameter
  233. """
  234. parameter_value = parameters.get(parameter_rule.name)
  235. if not parameter_value and parameter_value != 0:
  236. # get default value
  237. parameter_value = parameter_rule.default
  238. if not parameter_value and parameter_rule.required:
  239. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  240. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  241. # check if tool_parameter_config in options
  242. options = [x.value for x in parameter_rule.options]
  243. if parameter_value is not None and parameter_value not in options:
  244. raise ValueError(
  245. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
  246. )
  247. return parameter_rule.type.cast_value(parameter_value)
  248. @classmethod
  249. def get_agent_tool_runtime(
  250. cls,
  251. tenant_id: str,
  252. app_id: str,
  253. agent_tool: AgentToolEntity,
  254. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  255. ) -> Tool:
  256. """
  257. get the agent tool runtime
  258. """
  259. tool_entity = cls.get_tool_runtime(
  260. provider_type=agent_tool.provider_type,
  261. provider_id=agent_tool.provider_id,
  262. tool_name=agent_tool.tool_name,
  263. tenant_id=tenant_id,
  264. invoke_from=invoke_from,
  265. tool_invoke_from=ToolInvokeFrom.AGENT,
  266. )
  267. runtime_parameters = {}
  268. parameters = tool_entity.get_merged_runtime_parameters()
  269. for parameter in parameters:
  270. # check file types
  271. if (
  272. parameter.type
  273. in {
  274. ToolParameter.ToolParameterType.SYSTEM_FILES,
  275. ToolParameter.ToolParameterType.FILE,
  276. ToolParameter.ToolParameterType.FILES,
  277. }
  278. and parameter.required
  279. ):
  280. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  281. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  282. # save tool parameter to tool entity memory
  283. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  284. runtime_parameters[parameter.name] = value
  285. # decrypt runtime parameters
  286. encryption_manager = ToolParameterConfigurationManager(
  287. tenant_id=tenant_id,
  288. tool_runtime=tool_entity,
  289. provider_name=agent_tool.provider_id,
  290. provider_type=agent_tool.provider_type,
  291. identity_id=f"AGENT.{app_id}",
  292. )
  293. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  294. if not tool_entity.runtime:
  295. raise Exception("tool missing runtime")
  296. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  297. return tool_entity
  298. @classmethod
  299. def get_workflow_tool_runtime(
  300. cls,
  301. tenant_id: str,
  302. app_id: str,
  303. node_id: str,
  304. workflow_tool: "ToolEntity",
  305. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  306. ) -> Tool:
  307. """
  308. get the workflow tool runtime
  309. """
  310. tool_runtime = cls.get_tool_runtime(
  311. provider_type=workflow_tool.provider_type,
  312. provider_id=workflow_tool.provider_id,
  313. tool_name=workflow_tool.tool_name,
  314. tenant_id=tenant_id,
  315. invoke_from=invoke_from,
  316. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  317. )
  318. runtime_parameters = {}
  319. parameters = tool_runtime.get_merged_runtime_parameters()
  320. for parameter in parameters:
  321. # save tool parameter to tool entity memory
  322. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  323. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  324. runtime_parameters[parameter.name] = value
  325. # decrypt runtime parameters
  326. encryption_manager = ToolParameterConfigurationManager(
  327. tenant_id=tenant_id,
  328. tool_runtime=tool_runtime,
  329. provider_name=workflow_tool.provider_id,
  330. provider_type=workflow_tool.provider_type,
  331. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  332. )
  333. if runtime_parameters:
  334. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  335. if not tool_runtime.runtime:
  336. raise Exception("tool missing runtime")
  337. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  338. return tool_runtime
  339. @classmethod
  340. def get_tool_runtime_from_plugin(
  341. cls,
  342. tool_type: ToolProviderType,
  343. tenant_id: str,
  344. provider: str,
  345. tool_name: str,
  346. tool_parameters: dict[str, Any],
  347. ) -> Tool:
  348. """
  349. get tool runtime from plugin
  350. """
  351. tool_entity = cls.get_tool_runtime(
  352. provider_type=tool_type,
  353. provider_id=provider,
  354. tool_name=tool_name,
  355. tenant_id=tenant_id,
  356. invoke_from=InvokeFrom.SERVICE_API,
  357. tool_invoke_from=ToolInvokeFrom.PLUGIN,
  358. )
  359. runtime_parameters = {}
  360. parameters = tool_entity.get_merged_runtime_parameters()
  361. for parameter in parameters:
  362. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  363. # save tool parameter to tool entity memory
  364. value = cls._init_runtime_parameter(parameter, tool_parameters)
  365. runtime_parameters[parameter.name] = value
  366. if not tool_entity.runtime:
  367. raise Exception("tool missing runtime")
  368. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  369. return tool_entity
  370. @classmethod
  371. def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
  372. """
  373. get the absolute path of the icon of the hardcoded provider
  374. :param provider: the name of the provider
  375. :param tenant_id: the id of the tenant
  376. :return: the absolute path of the icon, the mime type of the icon
  377. """
  378. # get provider
  379. provider_controller = cls.get_hardcoded_provider(provider)
  380. absolute_path = path.join(
  381. path.dirname(path.realpath(__file__)),
  382. "builtin_tool",
  383. "providers",
  384. provider,
  385. "_assets",
  386. provider_controller.entity.identity.icon,
  387. )
  388. # check if the icon exists
  389. if not path.exists(absolute_path):
  390. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  391. # get the mime type
  392. mime_type, _ = mimetypes.guess_type(absolute_path)
  393. mime_type = mime_type or "application/octet-stream"
  394. return absolute_path, mime_type
  395. @classmethod
  396. def list_hardcoded_providers(cls):
  397. # use cache first
  398. if cls._builtin_providers_loaded:
  399. yield from list(cls._hardcoded_providers.values())
  400. return
  401. with cls._builtin_provider_lock:
  402. if cls._builtin_providers_loaded:
  403. yield from list(cls._hardcoded_providers.values())
  404. return
  405. yield from cls._list_hardcoded_providers()
  406. @classmethod
  407. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  408. """
  409. list all the plugin providers
  410. """
  411. manager = PluginToolManager()
  412. provider_entities = manager.fetch_tool_providers(tenant_id)
  413. return [
  414. PluginToolProviderController(
  415. entity=provider.declaration,
  416. plugin_id=provider.plugin_id,
  417. tenant_id=tenant_id,
  418. )
  419. for provider in provider_entities
  420. ]
  421. @classmethod
  422. def list_builtin_providers(
  423. cls, tenant_id: str
  424. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  425. """
  426. list all the builtin providers
  427. """
  428. yield from cls.list_hardcoded_providers()
  429. # get plugin providers
  430. yield from cls.list_plugin_providers(tenant_id)
  431. @classmethod
  432. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  433. """
  434. list all the builtin providers
  435. """
  436. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  437. if provider_path.startswith("__"):
  438. continue
  439. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  440. if provider_path.startswith("__"):
  441. continue
  442. # init provider
  443. try:
  444. provider_class = load_single_subclass_from_source(
  445. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  446. script_path=path.join(
  447. path.dirname(path.realpath(__file__)),
  448. "builtin_tool",
  449. "providers",
  450. provider_path,
  451. f"{provider_path}.py",
  452. ),
  453. parent_type=BuiltinToolProviderController,
  454. )
  455. provider: BuiltinToolProviderController = provider_class()
  456. cls._hardcoded_providers[provider.entity.identity.name] = provider
  457. for tool in provider.get_tools():
  458. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  459. yield provider
  460. except Exception as e:
  461. logger.exception(f"load builtin provider {provider} error: {e}")
  462. continue
  463. # set builtin providers loaded
  464. cls._builtin_providers_loaded = True
  465. @classmethod
  466. def load_hardcoded_providers_cache(cls):
  467. for _ in cls.list_hardcoded_providers():
  468. pass
  469. @classmethod
  470. def clear_hardcoded_providers_cache(cls):
  471. cls._hardcoded_providers = {}
  472. cls._builtin_providers_loaded = False
  473. @classmethod
  474. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  475. """
  476. get the tool label
  477. :param tool_name: the name of the tool
  478. :return: the label of the tool
  479. """
  480. if len(cls._builtin_tools_labels) == 0:
  481. # init the builtin providers
  482. cls.load_hardcoded_providers_cache()
  483. if tool_name not in cls._builtin_tools_labels:
  484. return None
  485. return cls._builtin_tools_labels[tool_name]
  486. @classmethod
  487. def list_providers_from_api(
  488. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  489. ) -> list[ToolProviderApiEntity]:
  490. result_providers: dict[str, ToolProviderApiEntity] = {}
  491. filters = []
  492. if not typ:
  493. filters.extend(["builtin", "api", "workflow"])
  494. else:
  495. filters.append(typ)
  496. if "builtin" in filters:
  497. # get builtin providers
  498. builtin_providers = cls.list_builtin_providers(tenant_id)
  499. # get db builtin providers
  500. db_builtin_providers: list[BuiltinToolProvider] = (
  501. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  502. )
  503. # rewrite db_builtin_providers
  504. for db_provider in db_builtin_providers:
  505. tool_provider_id = GenericProviderID(db_provider.provider)
  506. db_provider.provider = tool_provider_id.to_string()
  507. find_db_builtin_provider = lambda provider: next(
  508. (x for x in db_builtin_providers if x.provider == provider), None
  509. )
  510. # append builtin providers
  511. for provider in builtin_providers:
  512. # handle include, exclude
  513. if is_filtered(
  514. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  515. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  516. data=provider,
  517. name_func=lambda x: x.identity.name,
  518. ):
  519. continue
  520. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  521. provider_controller=provider,
  522. db_provider=find_db_builtin_provider(provider.entity.identity.name),
  523. decrypt_credentials=False,
  524. )
  525. if isinstance(provider, PluginToolProviderController):
  526. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  527. else:
  528. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  529. # get db api providers
  530. if "api" in filters:
  531. db_api_providers: list[ApiToolProvider] = (
  532. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  533. )
  534. api_provider_controllers = [
  535. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  536. for provider in db_api_providers
  537. ]
  538. # get labels
  539. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  540. for api_provider_controller in api_provider_controllers:
  541. user_provider = ToolTransformService.api_provider_to_user_provider(
  542. provider_controller=api_provider_controller["controller"],
  543. db_provider=api_provider_controller["provider"],
  544. decrypt_credentials=False,
  545. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  546. )
  547. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  548. if "workflow" in filters:
  549. # get workflow providers
  550. workflow_providers: list[WorkflowToolProvider] = (
  551. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  552. )
  553. workflow_provider_controllers = []
  554. for provider in workflow_providers:
  555. try:
  556. workflow_provider_controllers.append(
  557. ToolTransformService.workflow_provider_to_controller(db_provider=provider)
  558. )
  559. except Exception as e:
  560. # app has been deleted
  561. pass
  562. labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
  563. for provider_controller in workflow_provider_controllers:
  564. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  565. provider_controller=provider_controller,
  566. labels=labels.get(provider_controller.provider_id, []),
  567. )
  568. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  569. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  570. @classmethod
  571. def get_api_provider_controller(
  572. cls, tenant_id: str, provider_id: str
  573. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  574. """
  575. get the api provider
  576. :param provider_name: the name of the provider
  577. :return: the provider controller, the credentials
  578. """
  579. provider: ApiToolProvider | None = (
  580. db.session.query(ApiToolProvider)
  581. .filter(
  582. ApiToolProvider.id == provider_id,
  583. ApiToolProvider.tenant_id == tenant_id,
  584. )
  585. .first()
  586. )
  587. if provider is None:
  588. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  589. controller = ApiToolProviderController.from_db(
  590. provider,
  591. ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  592. )
  593. controller.load_bundled_tools(provider.tools)
  594. return controller, provider.credentials
  595. @classmethod
  596. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  597. """
  598. get api provider
  599. """
  600. """
  601. get tool provider
  602. """
  603. provider_name = provider
  604. provider_obj: ApiToolProvider = (
  605. db.session.query(ApiToolProvider)
  606. .filter(
  607. ApiToolProvider.tenant_id == tenant_id,
  608. ApiToolProvider.name == provider,
  609. )
  610. .first()
  611. )
  612. if provider_obj is None:
  613. raise ValueError(f"you have not added provider {provider_name}")
  614. try:
  615. credentials = json.loads(provider_obj.credentials_str) or {}
  616. except:
  617. credentials = {}
  618. # package tool provider controller
  619. controller = ApiToolProviderController.from_db(
  620. provider_obj,
  621. ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  622. )
  623. # init tool configuration
  624. tool_configuration = ProviderConfigEncrypter(
  625. tenant_id=tenant_id,
  626. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  627. provider_type=controller.provider_type.value,
  628. provider_identity=controller.entity.identity.name,
  629. )
  630. decrypted_credentials = tool_configuration.decrypt(credentials)
  631. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  632. try:
  633. icon = json.loads(provider_obj.icon)
  634. except:
  635. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  636. # add tool labels
  637. labels = ToolLabelManager.get_tool_labels(controller)
  638. return jsonable_encoder(
  639. {
  640. "schema_type": provider_obj.schema_type,
  641. "schema": provider_obj.schema,
  642. "tools": provider_obj.tools,
  643. "icon": icon,
  644. "description": provider_obj.description,
  645. "credentials": masked_credentials,
  646. "privacy_policy": provider_obj.privacy_policy,
  647. "custom_disclaimer": provider_obj.custom_disclaimer,
  648. "labels": labels,
  649. }
  650. )
  651. @classmethod
  652. def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
  653. """
  654. get the tool icon
  655. :param tenant_id: the id of the tenant
  656. :param provider_type: the type of the provider
  657. :param provider_id: the id of the provider
  658. :return:
  659. """
  660. provider_type = provider_type
  661. provider_id = provider_id
  662. if provider_type == ToolProviderType.BUILT_IN:
  663. return (
  664. dify_config.CONSOLE_API_URL
  665. + "/console/api/workspaces/current/tool-provider/builtin/"
  666. + provider_id
  667. + "/icon"
  668. )
  669. elif provider_type == ToolProviderType.API:
  670. try:
  671. api_provider: ApiToolProvider | None = (
  672. db.session.query(ApiToolProvider)
  673. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  674. .first()
  675. )
  676. if not api_provider:
  677. raise ValueError("api tool not found")
  678. return json.loads(api_provider.icon)
  679. except:
  680. return {"background": "#252525", "content": "\ud83d\ude01"}
  681. elif provider_type == ToolProviderType.WORKFLOW:
  682. workflow_provider: WorkflowToolProvider | None = (
  683. db.session.query(WorkflowToolProvider)
  684. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  685. .first()
  686. )
  687. if workflow_provider is None:
  688. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  689. return json.loads(workflow_provider.icon)
  690. else:
  691. raise ValueError(f"provider type {provider_type} not found")
# Populate the hardcoded (builtin) provider cache once, when this module is imported.
ToolManager.load_hardcoded_providers_cache()