# tool_manager.py
  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Union, cast
  8. from core.plugin.manager.tool import PluginToolManager
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.plugin_tool.provider import PluginToolProviderController
  11. from core.tools.plugin_tool.tool import PluginTool
  12. if TYPE_CHECKING:
  13. from core.workflow.nodes.tool.entities import ToolEntity
  14. from configs import dify_config
  15. from core.agent.entities import AgentToolEntity
  16. from core.app.entities.app_invoke_entities import InvokeFrom
  17. from core.helper.module_import_helper import load_single_subclass_from_source
  18. from core.helper.position_helper import is_filtered
  19. from core.model_runtime.utils.encoders import jsonable_encoder
  20. from core.tools.__base.tool import Tool
  21. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  22. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  23. from core.tools.builtin_tool.tool import BuiltinTool
  24. from core.tools.custom_tool.provider import ApiToolProviderController
  25. from core.tools.custom_tool.tool import ApiTool
  26. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  27. from core.tools.entities.common_entities import I18nObject
  28. from core.tools.entities.tool_entities import (
  29. ApiProviderAuthType,
  30. ToolInvokeFrom,
  31. ToolParameter,
  32. ToolProviderID,
  33. ToolProviderType,
  34. )
  35. from core.tools.errors import ToolProviderNotFoundError
  36. from core.tools.tool_label_manager import ToolLabelManager
  37. from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
  38. from core.tools.utils.tool_parameter_converter import ToolParameterConverter
  39. from core.tools.workflow_as_tool.tool import WorkflowTool
  40. from extensions.ext_database import db
  41. from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
  42. from services.tools.tools_transform_service import ToolTransformService
# Module-level logger; ToolManager uses it to report builtin provider load failures.
logger = logging.getLogger(__name__)
  44. class ToolManager:
  45. _builtin_provider_lock = Lock()
  46. _hardcoded_providers = {}
  47. _builtin_providers_loaded = False
  48. _builtin_tools_labels = {}
  49. @classmethod
  50. def get_builtin_provider(
  51. cls, provider: str, tenant_id: str
  52. ) -> BuiltinToolProviderController | PluginToolProviderController:
  53. """
  54. get the builtin provider
  55. :param provider: the name of the provider
  56. :param tenant_id: the id of the tenant
  57. :return: the provider
  58. """
  59. # split provider to
  60. if len(cls._hardcoded_providers) == 0:
  61. # init the builtin providers
  62. cls.load_hardcoded_providers_cache()
  63. if provider not in cls._hardcoded_providers:
  64. # get plugin provider
  65. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  66. if plugin_provider:
  67. return plugin_provider
  68. return cls._hardcoded_providers[provider]
  69. @classmethod
  70. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  71. """
  72. get the plugin provider
  73. """
  74. manager = PluginToolManager()
  75. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  76. if not provider_entity:
  77. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  78. return PluginToolProviderController(
  79. entity=provider_entity.declaration,
  80. tenant_id=tenant_id,
  81. )
  82. @classmethod
  83. def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
  84. """
  85. get the builtin tool
  86. :param provider: the name of the provider
  87. :param tool_name: the name of the tool
  88. :param tenant_id: the id of the tenant
  89. :return: the provider, the tool
  90. """
  91. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  92. tool = provider_controller.get_tool(tool_name)
  93. return tool
  94. @classmethod
  95. def get_tool_runtime(
  96. cls,
  97. provider_type: ToolProviderType,
  98. provider_id: str,
  99. tool_name: str,
  100. tenant_id: str,
  101. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  102. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  103. ) -> Union[BuiltinTool, ApiTool, WorkflowTool]:
  104. """
  105. get the tool runtime
  106. :param provider_type: the type of the provider
  107. :param provider_name: the name of the provider
  108. :param tool_name: the name of the tool
  109. :return: the tool
  110. """
  111. if provider_type == ToolProviderType.BUILT_IN:
  112. # check if the builtin tool need credentials
  113. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  114. builtin_tool = provider_controller.get_tool(tool_name)
  115. if not builtin_tool:
  116. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  117. if not provider_controller.need_credentials:
  118. return cast(
  119. BuiltinTool,
  120. builtin_tool.fork_tool_runtime(
  121. runtime=ToolRuntime(
  122. tenant_id=tenant_id,
  123. credentials={},
  124. invoke_from=invoke_from,
  125. tool_invoke_from=tool_invoke_from,
  126. )
  127. ),
  128. )
  129. if isinstance(provider_controller, PluginToolProviderController):
  130. provider_id_entity = ToolProviderID(provider_id)
  131. # get credentials
  132. builtin_provider: BuiltinToolProvider | None = (
  133. db.session.query(BuiltinToolProvider)
  134. .filter(
  135. BuiltinToolProvider.tenant_id == tenant_id,
  136. (BuiltinToolProvider.provider == provider_id)
  137. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  138. )
  139. .first()
  140. )
  141. if builtin_provider is None:
  142. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  143. else:
  144. builtin_provider: BuiltinToolProvider | None = (
  145. db.session.query(BuiltinToolProvider)
  146. .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  147. .first()
  148. )
  149. if builtin_provider is None:
  150. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  151. # decrypt the credentials
  152. credentials = builtin_provider.credentials
  153. tool_configuration = ProviderConfigEncrypter(
  154. tenant_id=tenant_id,
  155. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  156. provider_type=provider_controller.provider_type.value,
  157. provider_identity=provider_controller.entity.identity.name,
  158. )
  159. decrypted_credentials = tool_configuration.decrypt(credentials)
  160. return cast(
  161. BuiltinTool,
  162. builtin_tool.fork_tool_runtime(
  163. runtime=ToolRuntime(
  164. tenant_id=tenant_id,
  165. credentials=decrypted_credentials,
  166. runtime_parameters={},
  167. invoke_from=invoke_from,
  168. tool_invoke_from=tool_invoke_from,
  169. )
  170. ),
  171. )
  172. elif provider_type == ToolProviderType.API:
  173. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  174. # decrypt the credentials
  175. tool_configuration = ProviderConfigEncrypter(
  176. tenant_id=tenant_id,
  177. config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
  178. provider_type=api_provider.provider_type.value,
  179. provider_identity=api_provider.entity.identity.name,
  180. )
  181. decrypted_credentials = tool_configuration.decrypt(credentials)
  182. return cast(
  183. ApiTool,
  184. api_provider.get_tool(tool_name).fork_tool_runtime(
  185. runtime=ToolRuntime(
  186. tenant_id=tenant_id,
  187. credentials=decrypted_credentials,
  188. invoke_from=invoke_from,
  189. tool_invoke_from=tool_invoke_from,
  190. )
  191. ),
  192. )
  193. elif provider_type == ToolProviderType.WORKFLOW:
  194. workflow_provider = (
  195. db.session.query(WorkflowToolProvider)
  196. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  197. .first()
  198. )
  199. if workflow_provider is None:
  200. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  201. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  202. return cast(
  203. WorkflowTool,
  204. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  205. runtime=ToolRuntime(
  206. tenant_id=tenant_id,
  207. credentials={},
  208. invoke_from=invoke_from,
  209. tool_invoke_from=tool_invoke_from,
  210. )
  211. ),
  212. )
  213. elif provider_type == ToolProviderType.APP:
  214. raise NotImplementedError("app provider not implemented")
  215. else:
  216. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  217. @classmethod
  218. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
  219. """
  220. init runtime parameter
  221. """
  222. parameter_value = parameters.get(parameter_rule.name)
  223. if not parameter_value and parameter_value != 0:
  224. # get default value
  225. parameter_value = parameter_rule.default
  226. if not parameter_value and parameter_rule.required:
  227. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  228. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  229. # check if tool_parameter_config in options
  230. options = [x.value for x in parameter_rule.options]
  231. if parameter_value is not None and parameter_value not in options:
  232. raise ValueError(
  233. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
  234. )
  235. return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)
  236. @classmethod
  237. def get_agent_tool_runtime(
  238. cls,
  239. tenant_id: str,
  240. app_id: str,
  241. agent_tool: AgentToolEntity,
  242. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  243. ) -> Tool:
  244. """
  245. get the agent tool runtime
  246. """
  247. tool_entity = cls.get_tool_runtime(
  248. provider_type=agent_tool.provider_type,
  249. provider_id=agent_tool.provider_id,
  250. tool_name=agent_tool.tool_name,
  251. tenant_id=tenant_id,
  252. invoke_from=invoke_from,
  253. tool_invoke_from=ToolInvokeFrom.AGENT,
  254. )
  255. runtime_parameters = {}
  256. parameters = tool_entity.get_merged_runtime_parameters()
  257. for parameter in parameters:
  258. # check file types
  259. if parameter.type == ToolParameter.ToolParameterType.FILE:
  260. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  261. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  262. # save tool parameter to tool entity memory
  263. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  264. runtime_parameters[parameter.name] = value
  265. # decrypt runtime parameters
  266. encryption_manager = ToolParameterConfigurationManager(
  267. tenant_id=tenant_id,
  268. tool_runtime=tool_entity,
  269. provider_name=agent_tool.provider_id,
  270. provider_type=agent_tool.provider_type,
  271. identity_id=f"AGENT.{app_id}",
  272. )
  273. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  274. if not tool_entity.runtime:
  275. raise Exception("tool missing runtime")
  276. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  277. return tool_entity
  278. @classmethod
  279. def get_workflow_tool_runtime(
  280. cls,
  281. tenant_id: str,
  282. app_id: str,
  283. node_id: str,
  284. workflow_tool: "ToolEntity",
  285. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  286. ) -> Tool:
  287. """
  288. get the workflow tool runtime
  289. """
  290. tool_runtime = cls.get_tool_runtime(
  291. provider_type=workflow_tool.provider_type,
  292. provider_id=workflow_tool.provider_id,
  293. tool_name=workflow_tool.tool_name,
  294. tenant_id=tenant_id,
  295. invoke_from=invoke_from,
  296. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  297. )
  298. runtime_parameters = {}
  299. parameters = tool_runtime.get_merged_runtime_parameters()
  300. for parameter in parameters:
  301. # save tool parameter to tool entity memory
  302. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  303. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  304. runtime_parameters[parameter.name] = value
  305. # decrypt runtime parameters
  306. encryption_manager = ToolParameterConfigurationManager(
  307. tenant_id=tenant_id,
  308. tool_runtime=tool_runtime,
  309. provider_name=workflow_tool.provider_id,
  310. provider_type=workflow_tool.provider_type,
  311. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  312. )
  313. if runtime_parameters:
  314. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  315. if not tool_runtime.runtime:
  316. raise Exception("tool missing runtime")
  317. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  318. return tool_runtime
  319. @classmethod
  320. def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
  321. """
  322. get the absolute path of the icon of the builtin provider
  323. :param provider: the name of the provider
  324. :param tenant_id: the id of the tenant
  325. :return: the absolute path of the icon, the mime type of the icon
  326. """
  327. # get provider
  328. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  329. absolute_path = path.join(
  330. path.dirname(path.realpath(__file__)),
  331. "builtin_tool",
  332. "providers",
  333. provider,
  334. "_assets",
  335. provider_controller.entity.identity.icon,
  336. )
  337. # check if the icon exists
  338. if not path.exists(absolute_path):
  339. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  340. # get the mime type
  341. mime_type, _ = mimetypes.guess_type(absolute_path)
  342. mime_type = mime_type or "application/octet-stream"
  343. return absolute_path, mime_type
  344. @classmethod
  345. def list_hardcoded_providers(cls):
  346. # use cache first
  347. if cls._builtin_providers_loaded:
  348. yield from list(cls._hardcoded_providers.values())
  349. return
  350. with cls._builtin_provider_lock:
  351. if cls._builtin_providers_loaded:
  352. yield from list(cls._hardcoded_providers.values())
  353. return
  354. yield from cls._list_hardcoded_providers()
  355. @classmethod
  356. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  357. """
  358. list all the plugin providers
  359. """
  360. manager = PluginToolManager()
  361. provider_entities = manager.fetch_tool_providers(tenant_id)
  362. return [
  363. PluginToolProviderController(
  364. entity=provider.declaration,
  365. tenant_id=tenant_id,
  366. )
  367. for provider in provider_entities
  368. ]
  369. @classmethod
  370. def list_builtin_providers(
  371. cls, tenant_id: str
  372. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  373. """
  374. list all the builtin providers
  375. """
  376. yield from cls.list_hardcoded_providers()
  377. # get plugin providers
  378. yield from cls.list_plugin_providers(tenant_id)
  379. @classmethod
  380. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  381. """
  382. list all the builtin providers
  383. """
  384. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  385. if provider_path.startswith("__"):
  386. continue
  387. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  388. if provider_path.startswith("__"):
  389. continue
  390. # init provider
  391. try:
  392. provider_class = load_single_subclass_from_source(
  393. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  394. script_path=path.join(
  395. path.dirname(path.realpath(__file__)),
  396. "builtin_tool",
  397. "providers",
  398. provider_path,
  399. f"{provider_path}.py",
  400. ),
  401. parent_type=BuiltinToolProviderController,
  402. )
  403. provider: BuiltinToolProviderController = provider_class()
  404. cls._hardcoded_providers[provider.entity.identity.name] = provider
  405. for tool in provider.get_tools():
  406. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  407. yield provider
  408. except Exception as e:
  409. logger.error(f"load builtin provider error: {e}")
  410. continue
  411. # set builtin providers loaded
  412. cls._builtin_providers_loaded = True
  413. @classmethod
  414. def load_hardcoded_providers_cache(cls):
  415. for _ in cls.list_hardcoded_providers():
  416. pass
  417. @classmethod
  418. def clear_hardcoded_providers_cache(cls):
  419. cls._hardcoded_providers = {}
  420. cls._builtin_providers_loaded = False
  421. @classmethod
  422. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  423. """
  424. get the tool label
  425. :param tool_name: the name of the tool
  426. :return: the label of the tool
  427. """
  428. if len(cls._builtin_tools_labels) == 0:
  429. # init the builtin providers
  430. cls.load_hardcoded_providers_cache()
  431. if tool_name not in cls._builtin_tools_labels:
  432. return None
  433. return cls._builtin_tools_labels[tool_name]
  434. @classmethod
  435. def list_providers_from_api(
  436. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  437. ) -> list[ToolProviderApiEntity]:
  438. result_providers: dict[str, ToolProviderApiEntity] = {}
  439. filters = []
  440. if not typ:
  441. filters.extend(["builtin", "api", "workflow"])
  442. else:
  443. filters.append(typ)
  444. if "builtin" in filters:
  445. # get builtin providers
  446. builtin_providers = cls.list_builtin_providers(tenant_id)
  447. # get db builtin providers
  448. db_builtin_providers: list[BuiltinToolProvider] = (
  449. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  450. )
  451. # rewrite db_builtin_providers
  452. for db_provider in db_builtin_providers:
  453. try:
  454. ToolProviderID(db_provider.provider)
  455. except Exception:
  456. db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
  457. find_db_builtin_provider = lambda provider: next(
  458. (x for x in db_builtin_providers if x.provider == provider), None
  459. )
  460. # append builtin providers
  461. for provider in builtin_providers:
  462. # handle include, exclude
  463. if is_filtered(
  464. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  465. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  466. data=provider,
  467. name_func=lambda x: x.identity.name,
  468. ):
  469. continue
  470. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  471. provider_controller=provider,
  472. db_provider=find_db_builtin_provider(provider.entity.identity.name),
  473. decrypt_credentials=False,
  474. )
  475. if isinstance(provider, PluginToolProviderController):
  476. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  477. else:
  478. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  479. # get db api providers
  480. if "api" in filters:
  481. db_api_providers: list[ApiToolProvider] = (
  482. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  483. )
  484. api_provider_controllers = [
  485. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  486. for provider in db_api_providers
  487. ]
  488. # get labels
  489. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  490. for api_provider_controller in api_provider_controllers:
  491. user_provider = ToolTransformService.api_provider_to_user_provider(
  492. provider_controller=api_provider_controller["controller"],
  493. db_provider=api_provider_controller["provider"],
  494. decrypt_credentials=False,
  495. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  496. )
  497. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  498. if "workflow" in filters:
  499. # get workflow providers
  500. workflow_providers: list[WorkflowToolProvider] = (
  501. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  502. )
  503. workflow_provider_controllers = []
  504. for provider in workflow_providers:
  505. try:
  506. workflow_provider_controllers.append(
  507. ToolTransformService.workflow_provider_to_controller(db_provider=provider)
  508. )
  509. except Exception as e:
  510. # app has been deleted
  511. pass
  512. labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
  513. for provider_controller in workflow_provider_controllers:
  514. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  515. provider_controller=provider_controller,
  516. labels=labels.get(provider_controller.provider_id, []),
  517. )
  518. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  519. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  520. @classmethod
  521. def get_api_provider_controller(
  522. cls, tenant_id: str, provider_id: str
  523. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  524. """
  525. get the api provider
  526. :param provider_name: the name of the provider
  527. :return: the provider controller, the credentials
  528. """
  529. provider: ApiToolProvider | None = (
  530. db.session.query(ApiToolProvider)
  531. .filter(
  532. ApiToolProvider.id == provider_id,
  533. ApiToolProvider.tenant_id == tenant_id,
  534. )
  535. .first()
  536. )
  537. if provider is None:
  538. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  539. controller = ApiToolProviderController.from_db(
  540. provider,
  541. ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  542. )
  543. controller.load_bundled_tools(provider.tools)
  544. return controller, provider.credentials
  545. @classmethod
  546. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  547. """
  548. get api provider
  549. """
  550. provider_obj: ApiToolProvider | None = (
  551. db.session.query(ApiToolProvider)
  552. .filter(
  553. ApiToolProvider.tenant_id == tenant_id,
  554. ApiToolProvider.name == provider,
  555. )
  556. .first()
  557. )
  558. if provider_obj is None:
  559. raise ValueError(f"you have not added provider {provider}")
  560. try:
  561. credentials = json.loads(provider_obj.credentials_str) or {}
  562. except:
  563. credentials = {}
  564. # package tool provider controller
  565. controller = ApiToolProviderController.from_db(
  566. provider_obj,
  567. ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  568. )
  569. # init tool configuration
  570. tool_configuration = ProviderConfigEncrypter(
  571. tenant_id=tenant_id,
  572. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  573. provider_type=controller.provider_type.value,
  574. provider_identity=controller.entity.identity.name,
  575. )
  576. decrypted_credentials = tool_configuration.decrypt(credentials)
  577. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  578. try:
  579. icon = json.loads(provider_obj.icon)
  580. except:
  581. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  582. # add tool labels
  583. labels = ToolLabelManager.get_tool_labels(controller)
  584. return jsonable_encoder(
  585. {
  586. "schema_type": provider_obj.schema_type,
  587. "schema": provider_obj.schema,
  588. "tools": provider_obj.tools,
  589. "icon": icon,
  590. "description": provider_obj.description,
  591. "credentials": masked_credentials,
  592. "privacy_policy": provider_obj.privacy_policy,
  593. "custom_disclaimer": provider_obj.custom_disclaimer,
  594. "labels": labels,
  595. }
  596. )
  597. @classmethod
  598. def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
  599. """
  600. get the tool icon
  601. :param tenant_id: the id of the tenant
  602. :param provider_type: the type of the provider
  603. :param provider_id: the id of the provider
  604. :return:
  605. """
  606. provider_type = provider_type
  607. provider_id = provider_id
  608. if provider_type == ToolProviderType.BUILT_IN:
  609. return (
  610. dify_config.CONSOLE_API_URL
  611. + "/console/api/workspaces/current/tool-provider/builtin/"
  612. + provider_id
  613. + "/icon"
  614. )
  615. elif provider_type == ToolProviderType.API:
  616. try:
  617. api_provider: ApiToolProvider | None = (
  618. db.session.query(ApiToolProvider)
  619. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  620. .first()
  621. )
  622. if not api_provider:
  623. raise ValueError("api tool not found")
  624. return json.loads(api_provider.icon)
  625. except:
  626. return {"background": "#252525", "content": "\ud83d\ude01"}
  627. elif provider_type == ToolProviderType.WORKFLOW:
  628. workflow_provider: WorkflowToolProvider | None = (
  629. db.session.query(WorkflowToolProvider)
  630. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  631. .first()
  632. )
  633. if workflow_provider is None:
  634. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  635. return json.loads(workflow_provider.icon)
  636. else:
  637. raise ValueError(f"provider type {provider_type} not found")
# Import-time side effect: discover and cache the hardcoded builtin providers (and their
# tool labels) as soon as this module is imported, presumably so the first lookup does not
# pay the discovery cost. Importing this module therefore scans builtin_tool/providers.
ToolManager.load_hardcoded_providers_cache()