tool_manager.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783
  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Union, cast
  8. from core.plugin.manager.tool import PluginToolManager
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.plugin_tool.provider import PluginToolProviderController
  11. from core.tools.plugin_tool.tool import PluginTool
  12. if TYPE_CHECKING:
  13. from core.workflow.nodes.tool.entities import ToolEntity
  14. from configs import dify_config
  15. from core.agent.entities import AgentToolEntity
  16. from core.app.entities.app_invoke_entities import InvokeFrom
  17. from core.helper.module_import_helper import load_single_subclass_from_source
  18. from core.helper.position_helper import is_filtered
  19. from core.model_runtime.utils.encoders import jsonable_encoder
  20. from core.tools.__base.tool import Tool
  21. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  22. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  23. from core.tools.builtin_tool.tool import BuiltinTool
  24. from core.tools.custom_tool.provider import ApiToolProviderController
  25. from core.tools.custom_tool.tool import ApiTool
  26. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  27. from core.tools.entities.common_entities import I18nObject
  28. from core.tools.entities.tool_entities import (
  29. ApiProviderAuthType,
  30. ToolInvokeFrom,
  31. ToolParameter,
  32. ToolProviderID,
  33. ToolProviderType,
  34. )
  35. from core.tools.errors import ToolProviderNotFoundError
  36. from core.tools.tool_label_manager import ToolLabelManager
  37. from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
  38. from core.tools.utils.tool_parameter_converter import ToolParameterConverter
  39. from core.tools.workflow_as_tool.tool import WorkflowTool
  40. from extensions.ext_database import db
  41. from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
  42. from services.tools.tools_transform_service import ToolTransformService
  # Module logger, named after this module.
  logger: logging.Logger = logging.getLogger(name=__name__)
  class ToolManager:
      """
      Resolve tool providers and build tool runtimes for builtin (hardcoded and plugin),
      API and workflow tools.
      """

      # Guards the one-time loading of the hardcoded builtin providers.
      _builtin_provider_lock = Lock()
      # provider name -> controller, for the builtin providers bundled under builtin_tool/providers
      _hardcoded_providers: dict[str, BuiltinToolProviderController] = {}
      # set to True once _list_hardcoded_providers has run to completion
      _builtin_providers_loaded = False
      # tool name -> label, for every tool of every hardcoded provider
      _builtin_tools_labels: dict[str, I18nObject] = {}
      # Fallback icon for API tool providers whose stored icon is missing or unreadable.
      # Spelled as a real code point: a "\ud83d\ude01" literal is two lone surrogates in
      # Python and cannot be encoded as UTF-8.
      _DEFAULT_API_TOOL_ICON: dict[str, str] = {"background": "#252525", "content": "\U0001f601"}

      @classmethod
      def get_builtin_provider(
          cls, provider: str, tenant_id: str
      ) -> BuiltinToolProviderController | PluginToolProviderController:
          """
          get the builtin provider

          Hardcoded (bundled) providers take precedence; any other name is resolved as a plugin provider.

          :param provider: the name of the provider
          :param tenant_id: the id of the tenant
          :return: the provider
          :raises ToolProviderNotFoundError: if the name is neither hardcoded nor a known plugin provider
          """
          if not cls._builtin_providers_loaded:
              # init the builtin providers
              cls.load_hardcoded_providers_cache()

          hardcoded_provider = cls._hardcoded_providers.get(provider)
          if hardcoded_provider is not None:
              return hardcoded_provider

          # get_plugin_provider raises ToolProviderNotFoundError when the provider does not exist
          return cls.get_plugin_provider(provider, tenant_id)

      @classmethod
      def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
          """
          get the plugin provider

          :param provider: the name of the provider
          :param tenant_id: the id of the tenant
          :return: a controller wrapping the provider's declaration
          :raises ToolProviderNotFoundError: if the plugin manager returns no such provider
          """
          manager = PluginToolManager()
          provider_entity = manager.fetch_tool_provider(tenant_id, provider)
          if not provider_entity:
              raise ToolProviderNotFoundError(f"plugin provider {provider} not found")

          return PluginToolProviderController(
              entity=provider_entity.declaration,
              tenant_id=tenant_id,
          )

      @classmethod
      def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
          """
          get the builtin tool

          :param provider: the name of the provider
          :param tool_name: the name of the tool
          :param tenant_id: the id of the tenant
          :return: the tool, or None if the provider has no such tool
          """
          provider_controller = cls.get_builtin_provider(provider, tenant_id)
          return provider_controller.get_tool(tool_name)

      @classmethod
      def get_tool_runtime(
          cls,
          provider_type: ToolProviderType,
          provider_id: str,
          tool_name: str,
          tenant_id: str,
          invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
          tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
      ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
          """
          get the tool runtime

          :param provider_type: the type of the provider
          :param provider_id: the id of the provider (builtin/plugin providers are looked up by name)
          :param tool_name: the name of the tool
          :param tenant_id: the id of the tenant
          :param invoke_from: where the invocation originates; stored on the runtime
          :param tool_invoke_from: which caller (agent, workflow, plugin) invokes the tool; stored on the runtime
          :return: the tool forked with a runtime that carries the tenant and the decrypted credentials
          :raises ToolProviderNotFoundError: if the provider, the tool or the stored credentials cannot be found
          :raises NotImplementedError: for APP providers
          """
          if provider_type == ToolProviderType.BUILT_IN:
              provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
              builtin_tool = provider_controller.get_tool(tool_name)
              if not builtin_tool:
                  raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")

              # providers that need no credentials get an empty credential set
              if not provider_controller.need_credentials:
                  return cast(
                      BuiltinTool,
                      builtin_tool.fork_tool_runtime(
                          runtime=ToolRuntime(
                              tenant_id=tenant_id,
                              credentials={},
                              invoke_from=invoke_from,
                              tool_invoke_from=tool_invoke_from,
                          )
                      ),
                  )

              # get credentials
              builtin_provider: BuiltinToolProvider | None
              if isinstance(provider_controller, PluginToolProviderController):
                  # match a stored row keyed by either the full provider id or its bare provider_name
                  provider_id_entity = ToolProviderID(provider_id)
                  builtin_provider = (
                      db.session.query(BuiltinToolProvider)
                      .filter(
                          BuiltinToolProvider.tenant_id == tenant_id,
                          (BuiltinToolProvider.provider == provider_id)
                          | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
                      )
                      .first()
                  )
              else:
                  builtin_provider = (
                      db.session.query(BuiltinToolProvider)
                      .filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_id)
                      .first()
                  )

              if builtin_provider is None:
                  raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")

              # decrypt the credentials
              credentials = builtin_provider.credentials
              tool_configuration = ProviderConfigEncrypter(
                  tenant_id=tenant_id,
                  config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
                  provider_type=provider_controller.provider_type.value,
                  provider_identity=provider_controller.entity.identity.name,
              )
              decrypted_credentials = tool_configuration.decrypt(credentials)

              return cast(
                  BuiltinTool,
                  builtin_tool.fork_tool_runtime(
                      runtime=ToolRuntime(
                          tenant_id=tenant_id,
                          credentials=decrypted_credentials,
                          runtime_parameters={},
                          invoke_from=invoke_from,
                          tool_invoke_from=tool_invoke_from,
                      )
                  ),
              )

          elif provider_type == ToolProviderType.API:
              api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)

              # decrypt the credentials
              tool_configuration = ProviderConfigEncrypter(
                  tenant_id=tenant_id,
                  config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
                  provider_type=api_provider.provider_type.value,
                  provider_identity=api_provider.entity.identity.name,
              )
              decrypted_credentials = tool_configuration.decrypt(credentials)

              return cast(
                  ApiTool,
                  api_provider.get_tool(tool_name).fork_tool_runtime(
                      runtime=ToolRuntime(
                          tenant_id=tenant_id,
                          credentials=decrypted_credentials,
                          invoke_from=invoke_from,
                          tool_invoke_from=tool_invoke_from,
                      )
                  ),
              )

          elif provider_type == ToolProviderType.WORKFLOW:
              workflow_provider = (
                  db.session.query(WorkflowToolProvider)
                  .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
                  .first()
              )
              if workflow_provider is None:
                  raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

              controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)

              # the first tool of the controller is the one that gets forked
              return cast(
                  WorkflowTool,
                  controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
                      runtime=ToolRuntime(
                          tenant_id=tenant_id,
                          credentials={},
                          invoke_from=invoke_from,
                          tool_invoke_from=tool_invoke_from,
                      )
                  ),
              )

          elif provider_type == ToolProviderType.APP:
              raise NotImplementedError("app provider not implemented")
          else:
              raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")

      @classmethod
      def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict) -> Union[str, int, float, bool]:
          """
          init runtime parameter

          Resolve one parameter value from the configured values, falling back to the rule's default.
          Validate it, then cast it to the rule's type.

          :param parameter_rule: the parameter declaration
          :param parameters: the configured parameter values, keyed by parameter name
          :return: the value as cast by ToolParameterConverter.cast_parameter_by_type
          :raises ValueError: if a required parameter has no value, or a select value is not one of the options
          """
          parameter_value = parameters.get(parameter_rule.name)
          # 0 (and False, which == 0) is a real value; only empty values fall back to the default
          if not parameter_value and parameter_value != 0:
              # get default value
              parameter_value = parameter_rule.default
              # Apply the same emptiness test to the default. A required parameter whose
              # default is 0 / False is then not reported as missing.
              if not parameter_value and parameter_value != 0 and parameter_rule.required:
                  raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")

          if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
              # check if tool_parameter_config in options
              options = [x.value for x in parameter_rule.options]
              if parameter_value is not None and parameter_value not in options:
                  raise ValueError(
                      f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
                  )

          return ToolParameterConverter.cast_parameter_by_type(parameter_value, parameter_rule.type)

      @classmethod
      def get_agent_tool_runtime(
          cls,
          tenant_id: str,
          app_id: str,
          agent_tool: AgentToolEntity,
          invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      ) -> Tool:
          """
          get the agent tool runtime

          Build the runtime and fill in the agent's FORM parameters. Values are decrypted with
          the "AGENT.<app_id>" identity.

          :raises ValueError: if the tool declares a FILE parameter, or a parameter fails validation
          """
          tool_entity = cls.get_tool_runtime(
              provider_type=agent_tool.provider_type,
              provider_id=agent_tool.provider_id,
              tool_name=agent_tool.tool_name,
              tenant_id=tenant_id,
              invoke_from=invoke_from,
              tool_invoke_from=ToolInvokeFrom.AGENT,
          )
          runtime_parameters = {}
          parameters = tool_entity.get_merged_runtime_parameters()
          for parameter in parameters:
              # check file types
              if parameter.type == ToolParameter.ToolParameterType.FILE:
                  raise ValueError(f"file type parameter {parameter.name} not supported in agent")

              if parameter.form == ToolParameter.ToolParameterForm.FORM:
                  # save tool parameter to tool entity memory
                  value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
                  runtime_parameters[parameter.name] = value

          # decrypt runtime parameters
          encryption_manager = ToolParameterConfigurationManager(
              tenant_id=tenant_id,
              tool_runtime=tool_entity,
              provider_name=agent_tool.provider_id,
              provider_type=agent_tool.provider_type,
              identity_id=f"AGENT.{app_id}",
          )
          runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
          if not tool_entity.runtime:
              raise Exception("tool missing runtime")

          tool_entity.runtime.runtime_parameters.update(runtime_parameters)
          return tool_entity

      @classmethod
      def get_workflow_tool_runtime(
          cls,
          tenant_id: str,
          app_id: str,
          node_id: str,
          workflow_tool: "ToolEntity",
          invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
      ) -> Tool:
          """
          get the workflow tool runtime

          Build the runtime and fill in the node's FORM parameters. When there are any, they are
          decrypted with the "WORKFLOW.<app_id>.<node_id>" identity.
          """
          tool_runtime = cls.get_tool_runtime(
              provider_type=workflow_tool.provider_type,
              provider_id=workflow_tool.provider_id,
              tool_name=workflow_tool.tool_name,
              tenant_id=tenant_id,
              invoke_from=invoke_from,
              tool_invoke_from=ToolInvokeFrom.WORKFLOW,
          )
          runtime_parameters = {}
          parameters = tool_runtime.get_merged_runtime_parameters()

          for parameter in parameters:
              # save tool parameter to tool entity memory
              if parameter.form == ToolParameter.ToolParameterForm.FORM:
                  value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
                  runtime_parameters[parameter.name] = value

          # decrypt runtime parameters
          encryption_manager = ToolParameterConfigurationManager(
              tenant_id=tenant_id,
              tool_runtime=tool_runtime,
              provider_name=workflow_tool.provider_id,
              provider_type=workflow_tool.provider_type,
              identity_id=f"WORKFLOW.{app_id}.{node_id}",
          )

          if runtime_parameters:
              runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)

          if not tool_runtime.runtime:
              raise Exception("tool missing runtime")

          tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
          return tool_runtime

      @classmethod
      def get_tool_runtime_from_plugin(
          cls,
          tool_type: ToolProviderType,
          tenant_id: str,
          provider: str,
          tool_name: str,
          tool_parameters: dict[str, Any],
      ) -> Tool:
          """
          get tool runtime from plugin

          Build the runtime for a plugin-originated call and fill in its FORM parameters from
          tool_parameters. No decryption is applied here.
          """
          tool_entity = cls.get_tool_runtime(
              provider_type=tool_type,
              provider_id=provider,
              tool_name=tool_name,
              tenant_id=tenant_id,
              invoke_from=InvokeFrom.SERVICE_API,
              tool_invoke_from=ToolInvokeFrom.PLUGIN,
          )
          runtime_parameters = {}
          parameters = tool_entity.get_merged_runtime_parameters()
          for parameter in parameters:
              if parameter.form == ToolParameter.ToolParameterForm.FORM:
                  # save tool parameter to tool entity memory
                  value = cls._init_runtime_parameter(parameter, tool_parameters)
                  runtime_parameters[parameter.name] = value

          if not tool_entity.runtime:
              raise Exception("tool missing runtime")

          tool_entity.runtime.runtime_parameters.update(runtime_parameters)
          return tool_entity

      @classmethod
      def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
          """
          get the absolute path of the icon of the builtin provider

          :param provider: the name of the provider
          :param tenant_id: the id of the tenant
          :return: the absolute path of the icon, the mime type of the icon
          :raises ToolProviderNotFoundError: if the provider or its icon file cannot be found
          """
          # get provider
          provider_controller = cls.get_builtin_provider(provider, tenant_id)

          absolute_path = path.join(
              path.dirname(path.realpath(__file__)),
              "builtin_tool",
              "providers",
              provider,
              "_assets",
              provider_controller.entity.identity.icon,
          )
          # check if the icon exists
          if not path.exists(absolute_path):
              raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")

          # get the mime type
          mime_type, _ = mimetypes.guess_type(absolute_path)
          mime_type = mime_type or "application/octet-stream"

          return absolute_path, mime_type

      @classmethod
      def list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
          """
          list the hardcoded (bundled) builtin providers, loading them on first use

          Loading is completed under the lock before anything is yielded. Providers are then
          yielded from a snapshot, so the lock is never held while the caller consumes the generator.
          """
          cls._ensure_hardcoded_providers_loaded()
          yield from list(cls._hardcoded_providers.values())

      @classmethod
      def _ensure_hardcoded_providers_loaded(cls) -> None:
          """Populate the hardcoded provider caches once; the flag is re-checked under the lock."""
          if cls._builtin_providers_loaded:
              return
          with cls._builtin_provider_lock:
              if cls._builtin_providers_loaded:
                  return
              # drain the loader: it fills the caches and sets _builtin_providers_loaded when exhausted
              for _ in cls._list_hardcoded_providers():
                  pass

      @classmethod
      def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
          """
          list all the plugin providers
          """
          manager = PluginToolManager()
          provider_entities = manager.fetch_tool_providers(tenant_id)
          return [
              PluginToolProviderController(
                  entity=provider.declaration,
                  tenant_id=tenant_id,
              )
              for provider in provider_entities
          ]

      @classmethod
      def list_builtin_providers(
          cls, tenant_id: str
      ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
          """
          list all the builtin providers: the hardcoded ones first, then the tenant's plugin providers
          """
          yield from cls.list_hardcoded_providers()
          # get plugin providers
          yield from cls.list_plugin_providers(tenant_id)

      @classmethod
      def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
          """
          load every builtin provider bundled under builtin_tool/providers

          Each provider is stored in _hardcoded_providers, and its tools' labels in
          _builtin_tools_labels, before it is yielded. A provider that fails to load is logged
          and skipped. _builtin_providers_loaded is only set once the generator is exhausted.
          """
          providers_dir = path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")
          # sorted for a deterministic load order (listdir order is arbitrary)
          for provider_path in sorted(listdir(providers_dir)):
              # skip dunder entries such as __init__.py / __pycache__
              if provider_path.startswith("__"):
                  continue
              provider_dir = path.join(providers_dir, provider_path)
              if not path.isdir(provider_dir):
                  continue

              # init provider
              try:
                  provider_class = load_single_subclass_from_source(
                      module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
                      script_path=path.join(provider_dir, f"{provider_path}.py"),
                      parent_type=BuiltinToolProviderController,
                  )
                  provider: BuiltinToolProviderController = provider_class()
                  cls._hardcoded_providers[provider.entity.identity.name] = provider
                  for tool in provider.get_tools():
                      cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
                  yield provider
              except Exception:
                  logger.exception("load builtin provider %s error", provider_path)
                  continue

          # set builtin providers loaded
          cls._builtin_providers_loaded = True

      @classmethod
      def load_hardcoded_providers_cache(cls):
          """Load the hardcoded builtin providers (and their tool labels) if not loaded yet."""
          cls._ensure_hardcoded_providers_loaded()

      @classmethod
      def clear_hardcoded_providers_cache(cls):
          """Drop the cached hardcoded providers and tool labels so that the next access reloads them."""
          with cls._builtin_provider_lock:
              cls._hardcoded_providers = {}
              cls._builtin_tools_labels = {}
              cls._builtin_providers_loaded = False

      @classmethod
      def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
          """
          get the tool label

          :param tool_name: the name of the tool
          :return: the label of the tool, or None if no hardcoded provider has that tool
          """
          if not cls._builtin_providers_loaded:
              # init the builtin providers
              cls.load_hardcoded_providers_cache()

          return cls._builtin_tools_labels.get(tool_name)

      @classmethod
      def list_providers_from_api(
          cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
      ) -> list[ToolProviderApiEntity]:
          """
          list the tool providers of a tenant as API entities

          :param user_id: the id of the requesting user (not used by this method)
          :param tenant_id: the id of the tenant
          :param typ: the provider kind to list (e.g. "builtin", "api", "workflow"); falsy lists all three
          :return: the providers, ordered by BuiltinToolProviderSort
          """
          result_providers: dict[str, ToolProviderApiEntity] = {}

          filters = [typ] if typ else ["builtin", "api", "workflow"]

          if "builtin" in filters:
              # get builtin providers (hardcoded + plugin)
              builtin_providers = cls.list_builtin_providers(tenant_id)

              # get db builtin providers
              db_builtin_providers: list[BuiltinToolProvider] = (
                  db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
              )

              # Index the stored rows by normalized provider id. Rows whose name ToolProviderID
              # rejects are keyed as "langgenius/<name>/<name>". The ORM objects are deliberately
              # left unmodified: assigning to them would mark them dirty, and a later autoflush or
              # commit would write the rewritten names back. The first matching row wins.
              db_builtin_providers_by_id: dict[str, BuiltinToolProvider] = {}
              for db_provider in db_builtin_providers:
                  normalized_id = db_provider.provider
                  try:
                      ToolProviderID(normalized_id)
                  except Exception:
                      normalized_id = f"langgenius/{normalized_id}/{normalized_id}"
                  db_builtin_providers_by_id.setdefault(normalized_id, db_provider)

              # append builtin providers
              for provider in builtin_providers:
                  # handle include, exclude
                  if is_filtered(
                      include_set=dify_config.POSITION_TOOL_INCLUDES_SET,  # type: ignore
                      exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,  # type: ignore
                      data=provider,
                      # controllers expose their identity through .entity, as everywhere else in this class
                      name_func=lambda x: x.entity.identity.name,
                  ):
                      continue

                  user_provider = ToolTransformService.builtin_provider_to_user_provider(
                      provider_controller=provider,
                      db_provider=db_builtin_providers_by_id.get(provider.entity.identity.name),
                      decrypt_credentials=False,
                  )

                  key_prefix = (
                      "plugin_provider" if isinstance(provider, PluginToolProviderController) else "builtin_provider"
                  )
                  result_providers[f"{key_prefix}.{user_provider.name}"] = user_provider

          # get db api providers
          if "api" in filters:
              db_api_providers: list[ApiToolProvider] = (
                  db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
              )

              api_provider_controllers = [
                  (db_api_provider, ToolTransformService.api_provider_to_controller(db_api_provider))
                  for db_api_provider in db_api_providers
              ]

              # get labels
              labels = ToolLabelManager.get_tools_labels([controller for _, controller in api_provider_controllers])

              for db_api_provider, api_controller in api_provider_controllers:
                  user_provider = ToolTransformService.api_provider_to_user_provider(
                      provider_controller=api_controller,
                      db_provider=db_api_provider,
                      decrypt_credentials=False,
                      labels=labels.get(api_controller.provider_id, []),
                  )
                  result_providers[f"api_provider.{user_provider.name}"] = user_provider

          if "workflow" in filters:
              # get workflow providers
              workflow_providers: list[WorkflowToolProvider] = (
                  db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
              )

              workflow_provider_controllers = []
              for workflow_provider in workflow_providers:
                  try:
                      workflow_provider_controllers.append(
                          ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
                      )
                  except Exception as e:
                      # best effort (e.g. the app has been deleted): skip the provider, but leave a trace
                      logger.warning("skip workflow tool provider %s: %s", workflow_provider.id, e)

              labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)

              for provider_controller in workflow_provider_controllers:
                  user_provider = ToolTransformService.workflow_provider_to_user_provider(
                      provider_controller=provider_controller,
                      labels=labels.get(provider_controller.provider_id, []),
                  )
                  result_providers[f"workflow_provider.{user_provider.name}"] = user_provider

          return BuiltinToolProviderSort.sort(list(result_providers.values()))

      @classmethod
      def get_api_provider_controller(
          cls, tenant_id: str, provider_id: str
      ) -> tuple[ApiToolProviderController, dict[str, Any]]:
          """
          get the api provider

          :param tenant_id: the id of the tenant
          :param provider_id: the id of the api provider
          :return: the provider controller, the credentials
          :raises ToolProviderNotFoundError: if the tenant has no api provider with that id
          """
          provider: ApiToolProvider | None = (
              db.session.query(ApiToolProvider)
              .filter(
                  ApiToolProvider.id == provider_id,
                  ApiToolProvider.tenant_id == tenant_id,
              )
              .first()
          )

          if provider is None:
              raise ToolProviderNotFoundError(f"api provider {provider_id} not found")

          # a missing auth_type is treated as "no auth" instead of raising KeyError
          auth_type = (
              ApiProviderAuthType.API_KEY
              if provider.credentials.get("auth_type") == "api_key"
              else ApiProviderAuthType.NONE
          )
          controller = ApiToolProviderController.from_db(provider, auth_type)
          controller.load_bundled_tools(provider.tools)

          return controller, provider.credentials

      @classmethod
      def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
          """
          get api provider

          :param provider: the name of the api provider
          :param tenant_id: the id of the tenant
          :return: a JSON-encodable description of the provider, with masked credentials
          :raises ValueError: if the tenant has not added the provider
          """
          provider_obj: ApiToolProvider | None = (
              db.session.query(ApiToolProvider)
              .filter(
                  ApiToolProvider.tenant_id == tenant_id,
                  ApiToolProvider.name == provider,
              )
              .first()
          )

          if provider_obj is None:
              raise ValueError(f"you have not added provider {provider}")

          # unreadable or missing stored credentials fall back to an empty dict
          try:
              credentials = json.loads(provider_obj.credentials_str) or {}
          except (TypeError, ValueError):
              credentials = {}

          # package tool provider controller; .get() so the empty-credentials fallback does not raise KeyError
          controller = ApiToolProviderController.from_db(
              provider_obj,
              ApiProviderAuthType.API_KEY if credentials.get("auth_type") == "api_key" else ApiProviderAuthType.NONE,
          )
          # init tool configuration
          tool_configuration = ProviderConfigEncrypter(
              tenant_id=tenant_id,
              config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
              provider_type=controller.provider_type.value,
              provider_identity=controller.entity.identity.name,
          )

          decrypted_credentials = tool_configuration.decrypt(credentials)
          masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)

          try:
              icon = json.loads(provider_obj.icon)
          except (TypeError, ValueError):
              icon = dict(cls._DEFAULT_API_TOOL_ICON)

          # add tool labels
          labels = ToolLabelManager.get_tool_labels(controller)

          return jsonable_encoder(
              {
                  "schema_type": provider_obj.schema_type,
                  "schema": provider_obj.schema,
                  "tools": provider_obj.tools,
                  "icon": icon,
                  "description": provider_obj.description,
                  "credentials": masked_credentials,
                  "privacy_policy": provider_obj.privacy_policy,
                  "custom_disclaimer": provider_obj.custom_disclaimer,
                  "labels": labels,
              }
          )

      @classmethod
      def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
          """
          get the tool icon

          :param tenant_id: the id of the tenant
          :param provider_type: the type of the provider
          :param provider_id: the id of the provider
          :return: the icon URL for builtin providers, otherwise the parsed icon dict
          :raises ToolProviderNotFoundError: if the workflow provider does not exist
          :raises ValueError: for unsupported provider types
          """
          if provider_type == ToolProviderType.BUILT_IN:
              return (
                  dify_config.CONSOLE_API_URL
                  + "/console/api/workspaces/current/tool-provider/builtin/"
                  + provider_id
                  + "/icon"
              )
          elif provider_type == ToolProviderType.API:
              # best effort: any failure (missing provider, unreadable icon, ...) yields the default icon
              try:
                  api_provider: ApiToolProvider | None = (
                      db.session.query(ApiToolProvider)
                      .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
                      .first()
                  )
                  if not api_provider:
                      raise ValueError("api tool not found")

                  return json.loads(api_provider.icon)
              except Exception:
                  return dict(cls._DEFAULT_API_TOOL_ICON)
          elif provider_type == ToolProviderType.WORKFLOW:
              workflow_provider: WorkflowToolProvider | None = (
                  db.session.query(WorkflowToolProvider)
                  .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
                  .first()
              )

              if workflow_provider is None:
                  raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")

              return json.loads(workflow_provider.icon)
          else:
              raise ValueError(f"provider type {provider_type} not found")
  # Module import side effect: eagerly populate the hardcoded builtin provider cache
  # (providers and their tool labels) as soon as this module is imported.
  ToolManager.load_hardcoded_providers_cache()