tool_manager.py 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791
  1. import json
  2. import logging
  3. import mimetypes
  4. from collections.abc import Generator
  5. from os import listdir, path
  6. from threading import Lock
  7. from typing import TYPE_CHECKING, Any, Union, cast
  8. from core.plugin.manager.tool import PluginToolManager
  9. from core.tools.__base.tool_runtime import ToolRuntime
  10. from core.tools.plugin_tool.provider import PluginToolProviderController
  11. from core.tools.plugin_tool.tool import PluginTool
  12. if TYPE_CHECKING:
  13. from core.workflow.nodes.tool.entities import ToolEntity
  14. from configs import dify_config
  15. from core.agent.entities import AgentToolEntity
  16. from core.app.entities.app_invoke_entities import InvokeFrom
  17. from core.helper.module_import_helper import load_single_subclass_from_source
  18. from core.helper.position_helper import is_filtered
  19. from core.model_runtime.utils.encoders import jsonable_encoder
  20. from core.tools.__base.tool import Tool
  21. from core.tools.builtin_tool.provider import BuiltinToolProviderController
  22. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  23. from core.tools.builtin_tool.tool import BuiltinTool
  24. from core.tools.custom_tool.provider import ApiToolProviderController
  25. from core.tools.custom_tool.tool import ApiTool
  26. from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
  27. from core.tools.entities.common_entities import I18nObject
  28. from core.tools.entities.tool_entities import (
  29. ApiProviderAuthType,
  30. ToolInvokeFrom,
  31. ToolParameter,
  32. ToolProviderID,
  33. ToolProviderType,
  34. )
  35. from core.tools.errors import ToolProviderNotFoundError
  36. from core.tools.tool_label_manager import ToolLabelManager
  37. from core.tools.utils.configuration import (
  38. ProviderConfigEncrypter,
  39. ToolParameterConfigurationManager,
  40. )
  41. from core.tools.workflow_as_tool.tool import WorkflowTool
  42. from extensions.ext_database import db
  43. from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
  44. from services.tools.tools_transform_service import ToolTransformService
  45. logger = logging.getLogger(__name__)
  46. class ToolManager:
  47. _builtin_provider_lock = Lock()
  48. _hardcoded_providers = {}
  49. _builtin_providers_loaded = False
  50. _builtin_tools_labels = {}
  51. @classmethod
  52. def get_builtin_provider(
  53. cls, provider: str, tenant_id: str
  54. ) -> BuiltinToolProviderController | PluginToolProviderController:
  55. """
  56. get the builtin provider
  57. :param provider: the name of the provider
  58. :param tenant_id: the id of the tenant
  59. :return: the provider
  60. """
  61. # split provider to
  62. if len(cls._hardcoded_providers) == 0:
  63. # init the builtin providers
  64. cls.load_hardcoded_providers_cache()
  65. if provider not in cls._hardcoded_providers:
  66. # get plugin provider
  67. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  68. if plugin_provider:
  69. return plugin_provider
  70. return cls._hardcoded_providers[provider]
  71. @classmethod
  72. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  73. """
  74. get the plugin provider
  75. """
  76. manager = PluginToolManager()
  77. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  78. if not provider_entity:
  79. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  80. return PluginToolProviderController(
  81. entity=provider_entity.declaration,
  82. plugin_id=provider_entity.plugin_id,
  83. tenant_id=tenant_id,
  84. )
  85. @classmethod
  86. def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
  87. """
  88. get the builtin tool
  89. :param provider: the name of the provider
  90. :param tool_name: the name of the tool
  91. :param tenant_id: the id of the tenant
  92. :return: the provider, the tool
  93. """
  94. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  95. tool = provider_controller.get_tool(tool_name)
  96. return tool
  97. @classmethod
  98. def get_tool_runtime(
  99. cls,
  100. provider_type: ToolProviderType,
  101. provider_id: str,
  102. tool_name: str,
  103. tenant_id: str,
  104. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  105. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  106. ) -> Union[BuiltinTool, ApiTool, WorkflowTool]:
  107. """
  108. get the tool runtime
  109. :param provider_type: the type of the provider
  110. :param provider_name: the name of the provider
  111. :param tool_name: the name of the tool
  112. :return: the tool
  113. """
  114. if provider_type == ToolProviderType.BUILT_IN:
  115. # check if the builtin tool need credentials
  116. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  117. builtin_tool = provider_controller.get_tool(tool_name)
  118. if not builtin_tool:
  119. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  120. if not provider_controller.need_credentials:
  121. return cast(
  122. BuiltinTool,
  123. builtin_tool.fork_tool_runtime(
  124. runtime=ToolRuntime(
  125. tenant_id=tenant_id,
  126. credentials={},
  127. invoke_from=invoke_from,
  128. tool_invoke_from=tool_invoke_from,
  129. )
  130. ),
  131. )
  132. if isinstance(provider_controller, PluginToolProviderController):
  133. provider_id_entity = ToolProviderID(provider_id)
  134. # get credentials
  135. builtin_provider: BuiltinToolProvider | None = (
  136. db.session.query(BuiltinToolProvider)
  137. .filter(
  138. BuiltinToolProvider.tenant_id == tenant_id,
  139. (BuiltinToolProvider.provider == provider_id)
  140. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  141. )
  142. .first()
  143. )
  144. if builtin_provider is None:
  145. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  146. else:
  147. builtin_provider: BuiltinToolProvider | None = (
  148. db.session.query(BuiltinToolProvider)
  149. .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  150. .first()
  151. )
  152. if builtin_provider is None:
  153. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  154. # decrypt the credentials
  155. credentials = builtin_provider.credentials
  156. tool_configuration = ProviderConfigEncrypter(
  157. tenant_id=tenant_id,
  158. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  159. provider_type=provider_controller.provider_type.value,
  160. provider_identity=provider_controller.entity.identity.name,
  161. )
  162. decrypted_credentials = tool_configuration.decrypt(credentials)
  163. return cast(
  164. BuiltinTool,
  165. builtin_tool.fork_tool_runtime(
  166. runtime=ToolRuntime(
  167. tenant_id=tenant_id,
  168. credentials=decrypted_credentials,
  169. runtime_parameters={},
  170. invoke_from=invoke_from,
  171. tool_invoke_from=tool_invoke_from,
  172. )
  173. ),
  174. )
  175. elif provider_type == ToolProviderType.API:
  176. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  177. # decrypt the credentials
  178. tool_configuration = ProviderConfigEncrypter(
  179. tenant_id=tenant_id,
  180. config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
  181. provider_type=api_provider.provider_type.value,
  182. provider_identity=api_provider.entity.identity.name,
  183. )
  184. decrypted_credentials = tool_configuration.decrypt(credentials)
  185. return cast(
  186. ApiTool,
  187. api_provider.get_tool(tool_name).fork_tool_runtime(
  188. runtime=ToolRuntime(
  189. tenant_id=tenant_id,
  190. credentials=decrypted_credentials,
  191. invoke_from=invoke_from,
  192. tool_invoke_from=tool_invoke_from,
  193. )
  194. ),
  195. )
  196. elif provider_type == ToolProviderType.WORKFLOW:
  197. workflow_provider = (
  198. db.session.query(WorkflowToolProvider)
  199. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  200. .first()
  201. )
  202. if workflow_provider is None:
  203. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  204. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  205. return cast(
  206. WorkflowTool,
  207. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  208. runtime=ToolRuntime(
  209. tenant_id=tenant_id,
  210. credentials={},
  211. invoke_from=invoke_from,
  212. tool_invoke_from=tool_invoke_from,
  213. )
  214. ),
  215. )
  216. elif provider_type == ToolProviderType.APP:
  217. raise NotImplementedError("app provider not implemented")
  218. else:
  219. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  220. @classmethod
  221. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
  222. """
  223. init runtime parameter
  224. """
  225. parameter_value = parameters.get(parameter_rule.name)
  226. if not parameter_value and parameter_value != 0:
  227. # get default value
  228. parameter_value = parameter_rule.default
  229. if not parameter_value and parameter_rule.required:
  230. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  231. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  232. # check if tool_parameter_config in options
  233. options = [x.value for x in parameter_rule.options]
  234. if parameter_value is not None and parameter_value not in options:
  235. raise ValueError(
  236. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
  237. )
  238. return parameter_rule.type.cast_value(parameter_value)
  239. @classmethod
  240. def get_agent_tool_runtime(
  241. cls,
  242. tenant_id: str,
  243. app_id: str,
  244. agent_tool: AgentToolEntity,
  245. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  246. ) -> Tool:
  247. """
  248. get the agent tool runtime
  249. """
  250. tool_entity = cls.get_tool_runtime(
  251. provider_type=agent_tool.provider_type,
  252. provider_id=agent_tool.provider_id,
  253. tool_name=agent_tool.tool_name,
  254. tenant_id=tenant_id,
  255. invoke_from=invoke_from,
  256. tool_invoke_from=ToolInvokeFrom.AGENT,
  257. )
  258. runtime_parameters = {}
  259. parameters = tool_entity.get_merged_runtime_parameters()
  260. for parameter in parameters:
  261. # check file types
  262. if parameter.type in {
  263. ToolParameter.ToolParameterType.SYSTEM_FILES,
  264. ToolParameter.ToolParameterType.FILE,
  265. ToolParameter.ToolParameterType.FILES,
  266. }:
  267. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  268. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  269. # save tool parameter to tool entity memory
  270. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  271. runtime_parameters[parameter.name] = value
  272. # decrypt runtime parameters
  273. encryption_manager = ToolParameterConfigurationManager(
  274. tenant_id=tenant_id,
  275. tool_runtime=tool_entity,
  276. provider_name=agent_tool.provider_id,
  277. provider_type=agent_tool.provider_type,
  278. identity_id=f"AGENT.{app_id}",
  279. )
  280. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  281. if not tool_entity.runtime:
  282. raise Exception("tool missing runtime")
  283. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  284. return tool_entity
  285. @classmethod
  286. def get_workflow_tool_runtime(
  287. cls,
  288. tenant_id: str,
  289. app_id: str,
  290. node_id: str,
  291. workflow_tool: "ToolEntity",
  292. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  293. ) -> Tool:
  294. """
  295. get the workflow tool runtime
  296. """
  297. tool_runtime = cls.get_tool_runtime(
  298. provider_type=workflow_tool.provider_type,
  299. provider_id=workflow_tool.provider_id,
  300. tool_name=workflow_tool.tool_name,
  301. tenant_id=tenant_id,
  302. invoke_from=invoke_from,
  303. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  304. )
  305. runtime_parameters = {}
  306. parameters = tool_runtime.get_merged_runtime_parameters()
  307. for parameter in parameters:
  308. # save tool parameter to tool entity memory
  309. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  310. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  311. runtime_parameters[parameter.name] = value
  312. # decrypt runtime parameters
  313. encryption_manager = ToolParameterConfigurationManager(
  314. tenant_id=tenant_id,
  315. tool_runtime=tool_runtime,
  316. provider_name=workflow_tool.provider_id,
  317. provider_type=workflow_tool.provider_type,
  318. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  319. )
  320. if runtime_parameters:
  321. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  322. if not tool_runtime.runtime:
  323. raise Exception("tool missing runtime")
  324. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  325. return tool_runtime
  326. @classmethod
  327. def get_tool_runtime_from_plugin(
  328. cls,
  329. tool_type: ToolProviderType,
  330. tenant_id: str,
  331. provider: str,
  332. tool_name: str,
  333. tool_parameters: dict[str, Any],
  334. ) -> Tool:
  335. """
  336. get tool runtime from plugin
  337. """
  338. tool_entity = cls.get_tool_runtime(
  339. provider_type=tool_type,
  340. provider_id=provider,
  341. tool_name=tool_name,
  342. tenant_id=tenant_id,
  343. invoke_from=InvokeFrom.SERVICE_API,
  344. tool_invoke_from=ToolInvokeFrom.PLUGIN,
  345. )
  346. runtime_parameters = {}
  347. parameters = tool_entity.get_merged_runtime_parameters()
  348. for parameter in parameters:
  349. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  350. # save tool parameter to tool entity memory
  351. value = cls._init_runtime_parameter(parameter, tool_parameters)
  352. runtime_parameters[parameter.name] = value
  353. if not tool_entity.runtime:
  354. raise Exception("tool missing runtime")
  355. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  356. return tool_entity
  357. @classmethod
  358. def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
  359. """
  360. get the absolute path of the icon of the builtin provider
  361. :param provider: the name of the provider
  362. :param tenant_id: the id of the tenant
  363. :return: the absolute path of the icon, the mime type of the icon
  364. """
  365. # get provider
  366. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  367. absolute_path = path.join(
  368. path.dirname(path.realpath(__file__)),
  369. "builtin_tool",
  370. "providers",
  371. provider,
  372. "_assets",
  373. provider_controller.entity.identity.icon,
  374. )
  375. # check if the icon exists
  376. if not path.exists(absolute_path):
  377. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  378. # get the mime type
  379. mime_type, _ = mimetypes.guess_type(absolute_path)
  380. mime_type = mime_type or "application/octet-stream"
  381. return absolute_path, mime_type
  382. @classmethod
  383. def list_hardcoded_providers(cls):
  384. # use cache first
  385. if cls._builtin_providers_loaded:
  386. yield from list(cls._hardcoded_providers.values())
  387. return
  388. with cls._builtin_provider_lock:
  389. if cls._builtin_providers_loaded:
  390. yield from list(cls._hardcoded_providers.values())
  391. return
  392. yield from cls._list_hardcoded_providers()
  393. @classmethod
  394. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  395. """
  396. list all the plugin providers
  397. """
  398. manager = PluginToolManager()
  399. provider_entities = manager.fetch_tool_providers(tenant_id)
  400. return [
  401. PluginToolProviderController(
  402. entity=provider.declaration,
  403. plugin_id=provider.plugin_id,
  404. tenant_id=tenant_id,
  405. )
  406. for provider in provider_entities
  407. ]
  408. @classmethod
  409. def list_builtin_providers(
  410. cls, tenant_id: str
  411. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  412. """
  413. list all the builtin providers
  414. """
  415. yield from cls.list_hardcoded_providers()
  416. # get plugin providers
  417. yield from cls.list_plugin_providers(tenant_id)
  418. @classmethod
  419. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  420. """
  421. list all the builtin providers
  422. """
  423. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  424. if provider_path.startswith("__"):
  425. continue
  426. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  427. if provider_path.startswith("__"):
  428. continue
  429. # init provider
  430. try:
  431. provider_class = load_single_subclass_from_source(
  432. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  433. script_path=path.join(
  434. path.dirname(path.realpath(__file__)),
  435. "builtin_tool",
  436. "providers",
  437. provider_path,
  438. f"{provider_path}.py",
  439. ),
  440. parent_type=BuiltinToolProviderController,
  441. )
  442. provider: BuiltinToolProviderController = provider_class()
  443. cls._hardcoded_providers[provider.entity.identity.name] = provider
  444. for tool in provider.get_tools():
  445. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  446. yield provider
  447. except Exception as e:
  448. logger.error(f"load builtin provider error: {e}")
  449. continue
  450. # set builtin providers loaded
  451. cls._builtin_providers_loaded = True
  452. @classmethod
  453. def load_hardcoded_providers_cache(cls):
  454. for _ in cls.list_hardcoded_providers():
  455. pass
  456. @classmethod
  457. def clear_hardcoded_providers_cache(cls):
  458. cls._hardcoded_providers = {}
  459. cls._builtin_providers_loaded = False
  460. @classmethod
  461. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  462. """
  463. get the tool label
  464. :param tool_name: the name of the tool
  465. :return: the label of the tool
  466. """
  467. if len(cls._builtin_tools_labels) == 0:
  468. # init the builtin providers
  469. cls.load_hardcoded_providers_cache()
  470. if tool_name not in cls._builtin_tools_labels:
  471. return None
  472. return cls._builtin_tools_labels[tool_name]
  473. @classmethod
  474. def list_providers_from_api(
  475. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  476. ) -> list[ToolProviderApiEntity]:
  477. result_providers: dict[str, ToolProviderApiEntity] = {}
  478. filters = []
  479. if not typ:
  480. filters.extend(["builtin", "api", "workflow"])
  481. else:
  482. filters.append(typ)
  483. if "builtin" in filters:
  484. # get builtin providers
  485. builtin_providers = cls.list_builtin_providers(tenant_id)
  486. # get db builtin providers
  487. db_builtin_providers: list[BuiltinToolProvider] = (
  488. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  489. )
  490. # rewrite db_builtin_providers
  491. for db_provider in db_builtin_providers:
  492. try:
  493. ToolProviderID(db_provider.provider)
  494. except Exception:
  495. db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
  496. find_db_builtin_provider = lambda provider: next(
  497. (x for x in db_builtin_providers if x.provider == provider), None
  498. )
  499. # append builtin providers
  500. for provider in builtin_providers:
  501. # handle include, exclude
  502. if is_filtered(
  503. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  504. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  505. data=provider,
  506. name_func=lambda x: x.identity.name,
  507. ):
  508. continue
  509. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  510. provider_controller=provider,
  511. db_provider=find_db_builtin_provider(provider.entity.identity.name),
  512. decrypt_credentials=False,
  513. )
  514. if isinstance(provider, PluginToolProviderController):
  515. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  516. else:
  517. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  518. # get db api providers
  519. if "api" in filters:
  520. db_api_providers: list[ApiToolProvider] = (
  521. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  522. )
  523. api_provider_controllers = [
  524. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  525. for provider in db_api_providers
  526. ]
  527. # get labels
  528. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  529. for api_provider_controller in api_provider_controllers:
  530. user_provider = ToolTransformService.api_provider_to_user_provider(
  531. provider_controller=api_provider_controller["controller"],
  532. db_provider=api_provider_controller["provider"],
  533. decrypt_credentials=False,
  534. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  535. )
  536. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  537. if "workflow" in filters:
  538. # get workflow providers
  539. workflow_providers: list[WorkflowToolProvider] = (
  540. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  541. )
  542. workflow_provider_controllers = []
  543. for provider in workflow_providers:
  544. try:
  545. workflow_provider_controllers.append(
  546. ToolTransformService.workflow_provider_to_controller(db_provider=provider)
  547. )
  548. except Exception as e:
  549. # app has been deleted
  550. pass
  551. labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
  552. for provider_controller in workflow_provider_controllers:
  553. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  554. provider_controller=provider_controller,
  555. labels=labels.get(provider_controller.provider_id, []),
  556. )
  557. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  558. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  559. @classmethod
  560. def get_api_provider_controller(
  561. cls, tenant_id: str, provider_id: str
  562. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  563. """
  564. get the api provider
  565. :param provider_name: the name of the provider
  566. :return: the provider controller, the credentials
  567. """
  568. provider: ApiToolProvider | None = (
  569. db.session.query(ApiToolProvider)
  570. .filter(
  571. ApiToolProvider.id == provider_id,
  572. ApiToolProvider.tenant_id == tenant_id,
  573. )
  574. .first()
  575. )
  576. if provider is None:
  577. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  578. controller = ApiToolProviderController.from_db(
  579. provider,
  580. ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  581. )
  582. controller.load_bundled_tools(provider.tools)
  583. return controller, provider.credentials
  584. @classmethod
  585. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  586. """
  587. get api provider
  588. """
  589. provider_obj: ApiToolProvider | None = (
  590. db.session.query(ApiToolProvider)
  591. .filter(
  592. ApiToolProvider.tenant_id == tenant_id,
  593. ApiToolProvider.name == provider,
  594. )
  595. .first()
  596. )
  597. if provider_obj is None:
  598. raise ValueError(f"you have not added provider {provider}")
  599. try:
  600. credentials = json.loads(provider_obj.credentials_str) or {}
  601. except:
  602. credentials = {}
  603. # package tool provider controller
  604. controller = ApiToolProviderController.from_db(
  605. provider_obj,
  606. ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  607. )
  608. # init tool configuration
  609. tool_configuration = ProviderConfigEncrypter(
  610. tenant_id=tenant_id,
  611. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  612. provider_type=controller.provider_type.value,
  613. provider_identity=controller.entity.identity.name,
  614. )
  615. decrypted_credentials = tool_configuration.decrypt(credentials)
  616. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  617. try:
  618. icon = json.loads(provider_obj.icon)
  619. except:
  620. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  621. # add tool labels
  622. labels = ToolLabelManager.get_tool_labels(controller)
  623. return jsonable_encoder(
  624. {
  625. "schema_type": provider_obj.schema_type,
  626. "schema": provider_obj.schema,
  627. "tools": provider_obj.tools,
  628. "icon": icon,
  629. "description": provider_obj.description,
  630. "credentials": masked_credentials,
  631. "privacy_policy": provider_obj.privacy_policy,
  632. "custom_disclaimer": provider_obj.custom_disclaimer,
  633. "labels": labels,
  634. }
  635. )
  636. @classmethod
  637. def get_tool_icon(cls, tenant_id: str, provider_type: ToolProviderType, provider_id: str) -> Union[str, dict]:
  638. """
  639. get the tool icon
  640. :param tenant_id: the id of the tenant
  641. :param provider_type: the type of the provider
  642. :param provider_id: the id of the provider
  643. :return:
  644. """
  645. provider_type = provider_type
  646. provider_id = provider_id
  647. if provider_type == ToolProviderType.BUILT_IN:
  648. return (
  649. dify_config.CONSOLE_API_URL
  650. + "/console/api/workspaces/current/tool-provider/builtin/"
  651. + provider_id
  652. + "/icon"
  653. )
  654. elif provider_type == ToolProviderType.API:
  655. try:
  656. api_provider: ApiToolProvider | None = (
  657. db.session.query(ApiToolProvider)
  658. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  659. .first()
  660. )
  661. if not api_provider:
  662. raise ValueError("api tool not found")
  663. return json.loads(api_provider.icon)
  664. except:
  665. return {"background": "#252525", "content": "\ud83d\ude01"}
  666. elif provider_type == ToolProviderType.WORKFLOW:
  667. workflow_provider: WorkflowToolProvider | None = (
  668. db.session.query(WorkflowToolProvider)
  669. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  670. .first()
  671. )
  672. if workflow_provider is None:
  673. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  674. return json.loads(workflow_provider.icon)
  675. else:
  676. raise ValueError(f"provider type {provider_type} not found")
  677. ToolManager.load_hardcoded_providers_cache()