# tool_manager.py (33 KB)
import json
import logging
import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Union, cast

from yarl import URL

import contexts
from core.plugin.entities.plugin import GenericProviderID
from core.plugin.manager.tool import PluginToolManager
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool

if TYPE_CHECKING:
    from core.workflow.nodes.tool.entities import ToolEntity

from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.helper.module_import_helper import load_single_subclass_from_source
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import Tool
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ApiProviderAuthType,
    ToolInvokeFrom,
    ToolParameter,
    ToolProviderType,
)
from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
    ProviderConfigEncrypter,
    ToolParameterConfigurationManager,
)
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
  47. logger = logging.getLogger(__name__)
  48. class ToolManager:
  49. _builtin_provider_lock = Lock()
  50. _hardcoded_providers = {}
  51. _builtin_providers_loaded = False
  52. _builtin_tools_labels = {}
  53. @classmethod
  54. def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
  55. """
  56. get the hardcoded provider
  57. """
  58. if len(cls._hardcoded_providers) == 0:
  59. # init the builtin providers
  60. cls.load_hardcoded_providers_cache()
  61. return cls._hardcoded_providers[provider]
  62. @classmethod
  63. def get_builtin_provider(
  64. cls, provider: str, tenant_id: str
  65. ) -> BuiltinToolProviderController | PluginToolProviderController:
  66. """
  67. get the builtin provider
  68. :param provider: the name of the provider
  69. :param tenant_id: the id of the tenant
  70. :return: the provider
  71. """
  72. # split provider to
  73. if len(cls._hardcoded_providers) == 0:
  74. # init the builtin providers
  75. cls.load_hardcoded_providers_cache()
  76. if provider not in cls._hardcoded_providers:
  77. # get plugin provider
  78. plugin_provider = cls.get_plugin_provider(provider, tenant_id)
  79. if plugin_provider:
  80. return plugin_provider
  81. return cls._hardcoded_providers[provider]
  82. @classmethod
  83. def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
  84. """
  85. get the plugin provider
  86. """
  87. # check if context is set
  88. try:
  89. contexts.plugin_tool_providers.get()
  90. except LookupError:
  91. contexts.plugin_tool_providers.set({})
  92. contexts.plugin_tool_providers_lock.set(Lock())
  93. with contexts.plugin_tool_providers_lock.get():
  94. plugin_tool_providers = contexts.plugin_tool_providers.get()
  95. if provider in plugin_tool_providers:
  96. return plugin_tool_providers[provider]
  97. manager = PluginToolManager()
  98. provider_entity = manager.fetch_tool_provider(tenant_id, provider)
  99. if not provider_entity:
  100. raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
  101. controller = PluginToolProviderController(
  102. entity=provider_entity.declaration,
  103. plugin_id=provider_entity.plugin_id,
  104. plugin_unique_identifier=provider_entity.plugin_unique_identifier,
  105. tenant_id=tenant_id,
  106. )
  107. plugin_tool_providers[provider] = controller
  108. return controller
  109. @classmethod
  110. def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
  111. """
  112. get the builtin tool
  113. :param provider: the name of the provider
  114. :param tool_name: the name of the tool
  115. :param tenant_id: the id of the tenant
  116. :return: the provider, the tool
  117. """
  118. provider_controller = cls.get_builtin_provider(provider, tenant_id)
  119. tool = provider_controller.get_tool(tool_name)
  120. return tool
  121. @classmethod
  122. def get_tool_runtime(
  123. cls,
  124. provider_type: ToolProviderType,
  125. provider_id: str,
  126. tool_name: str,
  127. tenant_id: str,
  128. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  129. tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
  130. ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
  131. """
  132. get the tool runtime
  133. :param provider_type: the type of the provider
  134. :param provider_name: the name of the provider
  135. :param tool_name: the name of the tool
  136. :return: the tool
  137. """
  138. if provider_type == ToolProviderType.BUILT_IN:
  139. # check if the builtin tool need credentials
  140. provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
  141. builtin_tool = provider_controller.get_tool(tool_name)
  142. if not builtin_tool:
  143. raise ToolProviderNotFoundError(f"builtin tool {tool_name} not found")
  144. if not provider_controller.need_credentials:
  145. return cast(
  146. BuiltinTool,
  147. builtin_tool.fork_tool_runtime(
  148. runtime=ToolRuntime(
  149. tenant_id=tenant_id,
  150. credentials={},
  151. invoke_from=invoke_from,
  152. tool_invoke_from=tool_invoke_from,
  153. )
  154. ),
  155. )
  156. if isinstance(provider_controller, PluginToolProviderController):
  157. provider_id_entity = GenericProviderID(provider_id)
  158. # get credentials
  159. builtin_provider: BuiltinToolProvider | None = (
  160. db.session.query(BuiltinToolProvider)
  161. .filter(
  162. BuiltinToolProvider.tenant_id == tenant_id,
  163. (BuiltinToolProvider.provider == provider_id)
  164. | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
  165. )
  166. .first()
  167. )
  168. if builtin_provider is None:
  169. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  170. else:
  171. builtin_provider: BuiltinToolProvider | None = (
  172. db.session.query(BuiltinToolProvider)
  173. .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
  174. .first()
  175. )
  176. if builtin_provider is None:
  177. raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
  178. # decrypt the credentials
  179. credentials = builtin_provider.credentials
  180. tool_configuration = ProviderConfigEncrypter(
  181. tenant_id=tenant_id,
  182. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  183. provider_type=provider_controller.provider_type.value,
  184. provider_identity=provider_controller.entity.identity.name,
  185. )
  186. decrypted_credentials = tool_configuration.decrypt(credentials)
  187. return cast(
  188. BuiltinTool,
  189. builtin_tool.fork_tool_runtime(
  190. runtime=ToolRuntime(
  191. tenant_id=tenant_id,
  192. credentials=decrypted_credentials,
  193. runtime_parameters={},
  194. invoke_from=invoke_from,
  195. tool_invoke_from=tool_invoke_from,
  196. )
  197. ),
  198. )
  199. elif provider_type == ToolProviderType.API:
  200. api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
  201. # decrypt the credentials
  202. tool_configuration = ProviderConfigEncrypter(
  203. tenant_id=tenant_id,
  204. config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
  205. provider_type=api_provider.provider_type.value,
  206. provider_identity=api_provider.entity.identity.name,
  207. )
  208. decrypted_credentials = tool_configuration.decrypt(credentials)
  209. return cast(
  210. ApiTool,
  211. api_provider.get_tool(tool_name).fork_tool_runtime(
  212. runtime=ToolRuntime(
  213. tenant_id=tenant_id,
  214. credentials=decrypted_credentials,
  215. invoke_from=invoke_from,
  216. tool_invoke_from=tool_invoke_from,
  217. )
  218. ),
  219. )
  220. elif provider_type == ToolProviderType.WORKFLOW:
  221. workflow_provider = (
  222. db.session.query(WorkflowToolProvider)
  223. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  224. .first()
  225. )
  226. if workflow_provider is None:
  227. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  228. controller = ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider)
  229. return cast(
  230. WorkflowTool,
  231. controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(
  232. runtime=ToolRuntime(
  233. tenant_id=tenant_id,
  234. credentials={},
  235. invoke_from=invoke_from,
  236. tool_invoke_from=tool_invoke_from,
  237. )
  238. ),
  239. )
  240. elif provider_type == ToolProviderType.APP:
  241. raise NotImplementedError("app provider not implemented")
  242. elif provider_type == ToolProviderType.PLUGIN:
  243. return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
  244. else:
  245. raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
  246. @classmethod
  247. def _init_runtime_parameter(cls, parameter_rule: ToolParameter, parameters: dict):
  248. """
  249. init runtime parameter
  250. """
  251. parameter_value = parameters.get(parameter_rule.name)
  252. if not parameter_value and parameter_value != 0:
  253. # get default value
  254. parameter_value = parameter_rule.default
  255. if not parameter_value and parameter_rule.required:
  256. raise ValueError(f"tool parameter {parameter_rule.name} not found in tool config")
  257. if parameter_rule.type == ToolParameter.ToolParameterType.SELECT:
  258. # check if tool_parameter_config in options
  259. options = [x.value for x in parameter_rule.options]
  260. if parameter_value is not None and parameter_value not in options:
  261. raise ValueError(
  262. f"tool parameter {parameter_rule.name} value {parameter_value} not in options {options}"
  263. )
  264. return parameter_rule.type.cast_value(parameter_value)
  265. @classmethod
  266. def get_agent_tool_runtime(
  267. cls,
  268. tenant_id: str,
  269. app_id: str,
  270. agent_tool: AgentToolEntity,
  271. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  272. ) -> Tool:
  273. """
  274. get the agent tool runtime
  275. """
  276. tool_entity = cls.get_tool_runtime(
  277. provider_type=agent_tool.provider_type,
  278. provider_id=agent_tool.provider_id,
  279. tool_name=agent_tool.tool_name,
  280. tenant_id=tenant_id,
  281. invoke_from=invoke_from,
  282. tool_invoke_from=ToolInvokeFrom.AGENT,
  283. )
  284. runtime_parameters = {}
  285. parameters = tool_entity.get_merged_runtime_parameters()
  286. for parameter in parameters:
  287. # check file types
  288. if (
  289. parameter.type
  290. in {
  291. ToolParameter.ToolParameterType.SYSTEM_FILES,
  292. ToolParameter.ToolParameterType.FILE,
  293. ToolParameter.ToolParameterType.FILES,
  294. }
  295. and parameter.required
  296. ):
  297. raise ValueError(f"file type parameter {parameter.name} not supported in agent")
  298. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  299. # save tool parameter to tool entity memory
  300. value = cls._init_runtime_parameter(parameter, agent_tool.tool_parameters)
  301. runtime_parameters[parameter.name] = value
  302. # decrypt runtime parameters
  303. encryption_manager = ToolParameterConfigurationManager(
  304. tenant_id=tenant_id,
  305. tool_runtime=tool_entity,
  306. provider_name=agent_tool.provider_id,
  307. provider_type=agent_tool.provider_type,
  308. identity_id=f"AGENT.{app_id}",
  309. )
  310. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  311. if not tool_entity.runtime:
  312. raise Exception("tool missing runtime")
  313. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  314. return tool_entity
  315. @classmethod
  316. def get_workflow_tool_runtime(
  317. cls,
  318. tenant_id: str,
  319. app_id: str,
  320. node_id: str,
  321. workflow_tool: "ToolEntity",
  322. invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
  323. ) -> Tool:
  324. """
  325. get the workflow tool runtime
  326. """
  327. tool_runtime = cls.get_tool_runtime(
  328. provider_type=workflow_tool.provider_type,
  329. provider_id=workflow_tool.provider_id,
  330. tool_name=workflow_tool.tool_name,
  331. tenant_id=tenant_id,
  332. invoke_from=invoke_from,
  333. tool_invoke_from=ToolInvokeFrom.WORKFLOW,
  334. )
  335. runtime_parameters = {}
  336. parameters = tool_runtime.get_merged_runtime_parameters()
  337. for parameter in parameters:
  338. # save tool parameter to tool entity memory
  339. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  340. value = cls._init_runtime_parameter(parameter, workflow_tool.tool_configurations)
  341. runtime_parameters[parameter.name] = value
  342. # decrypt runtime parameters
  343. encryption_manager = ToolParameterConfigurationManager(
  344. tenant_id=tenant_id,
  345. tool_runtime=tool_runtime,
  346. provider_name=workflow_tool.provider_id,
  347. provider_type=workflow_tool.provider_type,
  348. identity_id=f"WORKFLOW.{app_id}.{node_id}",
  349. )
  350. if runtime_parameters:
  351. runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
  352. if not tool_runtime.runtime:
  353. raise Exception("tool missing runtime")
  354. tool_runtime.runtime.runtime_parameters.update(runtime_parameters)
  355. return tool_runtime
  356. @classmethod
  357. def get_tool_runtime_from_plugin(
  358. cls,
  359. tool_type: ToolProviderType,
  360. tenant_id: str,
  361. provider: str,
  362. tool_name: str,
  363. tool_parameters: dict[str, Any],
  364. ) -> Tool:
  365. """
  366. get tool runtime from plugin
  367. """
  368. tool_entity = cls.get_tool_runtime(
  369. provider_type=tool_type,
  370. provider_id=provider,
  371. tool_name=tool_name,
  372. tenant_id=tenant_id,
  373. invoke_from=InvokeFrom.SERVICE_API,
  374. tool_invoke_from=ToolInvokeFrom.PLUGIN,
  375. )
  376. runtime_parameters = {}
  377. parameters = tool_entity.get_merged_runtime_parameters()
  378. for parameter in parameters:
  379. if parameter.form == ToolParameter.ToolParameterForm.FORM:
  380. # save tool parameter to tool entity memory
  381. value = cls._init_runtime_parameter(parameter, tool_parameters)
  382. runtime_parameters[parameter.name] = value
  383. if not tool_entity.runtime:
  384. raise Exception("tool missing runtime")
  385. tool_entity.runtime.runtime_parameters.update(runtime_parameters)
  386. return tool_entity
  387. @classmethod
  388. def get_hardcoded_provider_icon(cls, provider: str) -> tuple[str, str]:
  389. """
  390. get the absolute path of the icon of the hardcoded provider
  391. :param provider: the name of the provider
  392. :param tenant_id: the id of the tenant
  393. :return: the absolute path of the icon, the mime type of the icon
  394. """
  395. # get provider
  396. provider_controller = cls.get_hardcoded_provider(provider)
  397. absolute_path = path.join(
  398. path.dirname(path.realpath(__file__)),
  399. "builtin_tool",
  400. "providers",
  401. provider,
  402. "_assets",
  403. provider_controller.entity.identity.icon,
  404. )
  405. # check if the icon exists
  406. if not path.exists(absolute_path):
  407. raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found")
  408. # get the mime type
  409. mime_type, _ = mimetypes.guess_type(absolute_path)
  410. mime_type = mime_type or "application/octet-stream"
  411. return absolute_path, mime_type
  412. @classmethod
  413. def list_hardcoded_providers(cls):
  414. # use cache first
  415. if cls._builtin_providers_loaded:
  416. yield from list(cls._hardcoded_providers.values())
  417. return
  418. with cls._builtin_provider_lock:
  419. if cls._builtin_providers_loaded:
  420. yield from list(cls._hardcoded_providers.values())
  421. return
  422. yield from cls._list_hardcoded_providers()
  423. @classmethod
  424. def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
  425. """
  426. list all the plugin providers
  427. """
  428. manager = PluginToolManager()
  429. provider_entities = manager.fetch_tool_providers(tenant_id)
  430. return [
  431. PluginToolProviderController(
  432. entity=provider.declaration,
  433. plugin_id=provider.plugin_id,
  434. plugin_unique_identifier=provider.plugin_unique_identifier,
  435. tenant_id=tenant_id,
  436. )
  437. for provider in provider_entities
  438. ]
  439. @classmethod
  440. def list_builtin_providers(
  441. cls, tenant_id: str
  442. ) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
  443. """
  444. list all the builtin providers
  445. """
  446. yield from cls.list_hardcoded_providers()
  447. # get plugin providers
  448. yield from cls.list_plugin_providers(tenant_id)
  449. @classmethod
  450. def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
  451. """
  452. list all the builtin providers
  453. """
  454. for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers")):
  455. if provider_path.startswith("__"):
  456. continue
  457. if path.isdir(path.join(path.dirname(path.realpath(__file__)), "builtin_tool", "providers", provider_path)):
  458. if provider_path.startswith("__"):
  459. continue
  460. # init provider
  461. try:
  462. provider_class = load_single_subclass_from_source(
  463. module_name=f"core.tools.builtin_tool.providers.{provider_path}.{provider_path}",
  464. script_path=path.join(
  465. path.dirname(path.realpath(__file__)),
  466. "builtin_tool",
  467. "providers",
  468. provider_path,
  469. f"{provider_path}.py",
  470. ),
  471. parent_type=BuiltinToolProviderController,
  472. )
  473. provider: BuiltinToolProviderController = provider_class()
  474. cls._hardcoded_providers[provider.entity.identity.name] = provider
  475. for tool in provider.get_tools():
  476. cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
  477. yield provider
  478. except Exception as e:
  479. logger.exception(f"load builtin provider {provider}")
  480. continue
  481. # set builtin providers loaded
  482. cls._builtin_providers_loaded = True
  483. @classmethod
  484. def load_hardcoded_providers_cache(cls):
  485. for _ in cls.list_hardcoded_providers():
  486. pass
  487. @classmethod
  488. def clear_hardcoded_providers_cache(cls):
  489. cls._hardcoded_providers = {}
  490. cls._builtin_providers_loaded = False
  491. @classmethod
  492. def get_tool_label(cls, tool_name: str) -> Union[I18nObject, None]:
  493. """
  494. get the tool label
  495. :param tool_name: the name of the tool
  496. :return: the label of the tool
  497. """
  498. if len(cls._builtin_tools_labels) == 0:
  499. # init the builtin providers
  500. cls.load_hardcoded_providers_cache()
  501. if tool_name not in cls._builtin_tools_labels:
  502. return None
  503. return cls._builtin_tools_labels[tool_name]
  504. @classmethod
  505. def list_providers_from_api(
  506. cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
  507. ) -> list[ToolProviderApiEntity]:
  508. result_providers: dict[str, ToolProviderApiEntity] = {}
  509. filters = []
  510. if not typ:
  511. filters.extend(["builtin", "api", "workflow"])
  512. else:
  513. filters.append(typ)
  514. if "builtin" in filters:
  515. # get builtin providers
  516. builtin_providers = cls.list_builtin_providers(tenant_id)
  517. # get db builtin providers
  518. db_builtin_providers: list[BuiltinToolProvider] = (
  519. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
  520. )
  521. # rewrite db_builtin_providers
  522. for db_provider in db_builtin_providers:
  523. tool_provider_id = GenericProviderID(db_provider.provider)
  524. db_provider.provider = tool_provider_id.to_string()
  525. find_db_builtin_provider = lambda provider: next(
  526. (x for x in db_builtin_providers if x.provider == provider), None
  527. )
  528. # append builtin providers
  529. for provider in builtin_providers:
  530. # handle include, exclude
  531. if is_filtered(
  532. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  533. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  534. data=provider,
  535. name_func=lambda x: x.identity.name,
  536. ):
  537. continue
  538. user_provider = ToolTransformService.builtin_provider_to_user_provider(
  539. provider_controller=provider,
  540. db_provider=find_db_builtin_provider(provider.entity.identity.name),
  541. decrypt_credentials=False,
  542. )
  543. if isinstance(provider, PluginToolProviderController):
  544. result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
  545. else:
  546. result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
  547. # get db api providers
  548. if "api" in filters:
  549. db_api_providers: list[ApiToolProvider] = (
  550. db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
  551. )
  552. api_provider_controllers = [
  553. {"provider": provider, "controller": ToolTransformService.api_provider_to_controller(provider)}
  554. for provider in db_api_providers
  555. ]
  556. # get labels
  557. labels = ToolLabelManager.get_tools_labels([x["controller"] for x in api_provider_controllers])
  558. for api_provider_controller in api_provider_controllers:
  559. user_provider = ToolTransformService.api_provider_to_user_provider(
  560. provider_controller=api_provider_controller["controller"],
  561. db_provider=api_provider_controller["provider"],
  562. decrypt_credentials=False,
  563. labels=labels.get(api_provider_controller["controller"].provider_id, []),
  564. )
  565. result_providers[f"api_provider.{user_provider.name}"] = user_provider
  566. if "workflow" in filters:
  567. # get workflow providers
  568. workflow_providers: list[WorkflowToolProvider] = (
  569. db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
  570. )
  571. workflow_provider_controllers = []
  572. for provider in workflow_providers:
  573. try:
  574. workflow_provider_controllers.append(
  575. ToolTransformService.workflow_provider_to_controller(db_provider=provider)
  576. )
  577. except Exception as e:
  578. # app has been deleted
  579. pass
  580. labels = ToolLabelManager.get_tools_labels(workflow_provider_controllers)
  581. for provider_controller in workflow_provider_controllers:
  582. user_provider = ToolTransformService.workflow_provider_to_user_provider(
  583. provider_controller=provider_controller,
  584. labels=labels.get(provider_controller.provider_id, []),
  585. )
  586. result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
  587. return BuiltinToolProviderSort.sort(list(result_providers.values()))
  588. @classmethod
  589. def get_api_provider_controller(
  590. cls, tenant_id: str, provider_id: str
  591. ) -> tuple[ApiToolProviderController, dict[str, Any]]:
  592. """
  593. get the api provider
  594. :param provider_name: the name of the provider
  595. :return: the provider controller, the credentials
  596. """
  597. provider: ApiToolProvider | None = (
  598. db.session.query(ApiToolProvider)
  599. .filter(
  600. ApiToolProvider.id == provider_id,
  601. ApiToolProvider.tenant_id == tenant_id,
  602. )
  603. .first()
  604. )
  605. if provider is None:
  606. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  607. controller = ApiToolProviderController.from_db(
  608. provider,
  609. ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  610. )
  611. controller.load_bundled_tools(provider.tools)
  612. return controller, provider.credentials
  613. @classmethod
  614. def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
  615. """
  616. get api provider
  617. """
  618. """
  619. get tool provider
  620. """
  621. provider_name = provider
  622. provider_obj: ApiToolProvider = (
  623. db.session.query(ApiToolProvider)
  624. .filter(
  625. ApiToolProvider.tenant_id == tenant_id,
  626. ApiToolProvider.name == provider,
  627. )
  628. .first()
  629. )
  630. if provider_obj is None:
  631. raise ValueError(f"you have not added provider {provider_name}")
  632. try:
  633. credentials = json.loads(provider_obj.credentials_str) or {}
  634. except:
  635. credentials = {}
  636. # package tool provider controller
  637. controller = ApiToolProviderController.from_db(
  638. provider_obj,
  639. ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
  640. )
  641. # init tool configuration
  642. tool_configuration = ProviderConfigEncrypter(
  643. tenant_id=tenant_id,
  644. config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
  645. provider_type=controller.provider_type.value,
  646. provider_identity=controller.entity.identity.name,
  647. )
  648. decrypted_credentials = tool_configuration.decrypt(credentials)
  649. masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
  650. try:
  651. icon = json.loads(provider_obj.icon)
  652. except:
  653. icon = {"background": "#252525", "content": "\ud83d\ude01"}
  654. # add tool labels
  655. labels = ToolLabelManager.get_tool_labels(controller)
  656. return jsonable_encoder(
  657. {
  658. "schema_type": provider_obj.schema_type,
  659. "schema": provider_obj.schema,
  660. "tools": provider_obj.tools,
  661. "icon": icon,
  662. "description": provider_obj.description,
  663. "credentials": masked_credentials,
  664. "privacy_policy": provider_obj.privacy_policy,
  665. "custom_disclaimer": provider_obj.custom_disclaimer,
  666. "labels": labels,
  667. }
  668. )
  669. @classmethod
  670. def generate_builtin_tool_icon_url(cls, provider_id: str) -> str:
  671. return (
  672. dify_config.CONSOLE_API_URL
  673. + "/console/api/workspaces/current/tool-provider/builtin/"
  674. + provider_id
  675. + "/icon"
  676. )
  677. @classmethod
  678. def generate_plugin_tool_icon_url(cls, tenant_id: str, filename: str) -> str:
  679. return str(
  680. URL(dify_config.CONSOLE_API_URL)
  681. / "console"
  682. / "api"
  683. / "workspaces"
  684. / "current"
  685. / "plugin"
  686. / "icon"
  687. % {"tenant_id": tenant_id, "filename": filename}
  688. )
  689. @classmethod
  690. def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
  691. try:
  692. workflow_provider: WorkflowToolProvider | None = (
  693. db.session.query(WorkflowToolProvider)
  694. .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
  695. .first()
  696. )
  697. if workflow_provider is None:
  698. raise ToolProviderNotFoundError(f"workflow provider {provider_id} not found")
  699. return json.loads(workflow_provider.icon)
  700. except:
  701. return {"background": "#252525", "content": "\ud83d\ude01"}
  702. @classmethod
  703. def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict:
  704. try:
  705. api_provider: ApiToolProvider | None = (
  706. db.session.query(ApiToolProvider)
  707. .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
  708. .first()
  709. )
  710. if api_provider is None:
  711. raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
  712. return json.loads(api_provider.icon)
  713. except:
  714. return {"background": "#252525", "content": "\ud83d\ude01"}
  715. @classmethod
  716. def get_tool_icon(
  717. cls,
  718. tenant_id: str,
  719. provider_type: ToolProviderType,
  720. provider_id: str,
  721. ) -> Union[str, dict]:
  722. """
  723. get the tool icon
  724. :param tenant_id: the id of the tenant
  725. :param provider_type: the type of the provider
  726. :param provider_id: the id of the provider
  727. :return:
  728. """
  729. provider_type = provider_type
  730. provider_id = provider_id
  731. if provider_type == ToolProviderType.BUILT_IN:
  732. provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
  733. if isinstance(provider, PluginToolProviderController):
  734. try:
  735. return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
  736. except:
  737. return {"background": "#252525", "content": "\ud83d\ude01"}
  738. return cls.generate_builtin_tool_icon_url(provider_id)
  739. elif provider_type == ToolProviderType.API:
  740. return cls.generate_api_tool_icon_url(tenant_id, provider_id)
  741. elif provider_type == ToolProviderType.WORKFLOW:
  742. return cls.generate_workflow_tool_icon_url(tenant_id, provider_id)
  743. elif provider_type == ToolProviderType.PLUGIN:
  744. provider = ToolManager.get_builtin_provider(provider_id, tenant_id)
  745. if isinstance(provider, PluginToolProviderController):
  746. try:
  747. return cls.generate_plugin_tool_icon_url(tenant_id, provider.entity.identity.icon)
  748. except:
  749. return {"background": "#252525", "content": "\ud83d\ude01"}
  750. raise ValueError(f"plugin provider {provider_id} not found")
  751. else:
  752. raise ValueError(f"provider type {provider_type} not found")
# Import-time side effect: eagerly scan and instantiate every hardcoded (builtin) provider,
# filling ToolManager's class-level caches. Presumably this avoids paying that cost on first use.
ToolManager.load_hardcoded_providers_cache()