builtin_tools_manage_service.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. import json
  2. import logging
  3. from pathlib import Path
  4. from sqlalchemy.orm import Session
  5. from configs import dify_config
  6. from core.helper.position_helper import is_filtered
  7. from core.model_runtime.utils.encoders import jsonable_encoder
  8. from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
  9. from core.plugin.manager.exc import PluginDaemonClientSideError
  10. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  11. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
  12. from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
  13. from core.tools.tool_label_manager import ToolLabelManager
  14. from core.tools.tool_manager import ToolManager
  15. from core.tools.utils.configuration import ProviderConfigEncrypter
  16. from extensions.ext_database import db
  17. from models.tools import BuiltinToolProvider
  18. from services.tools.tools_transform_service import ToolTransformService
  19. logger = logging.getLogger(__name__)
  20. class BuiltinToolManageService:
  21. @staticmethod
  22. def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
  23. """
  24. list builtin tool provider tools
  25. :param user_id: the id of the user
  26. :param tenant_id: the id of the tenant
  27. :param provider: the name of the provider
  28. :return: the list of tools
  29. """
  30. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  31. tools = provider_controller.get_tools()
  32. tool_provider_configurations = ProviderConfigEncrypter(
  33. tenant_id=tenant_id,
  34. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  35. provider_type=provider_controller.provider_type.value,
  36. provider_identity=provider_controller.entity.identity.name,
  37. )
  38. # check if user has added the provider
  39. builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  40. credentials = {}
  41. if builtin_provider is not None:
  42. # get credentials
  43. credentials = builtin_provider.credentials
  44. credentials = tool_provider_configurations.decrypt(credentials)
  45. result: list[ToolApiEntity] = []
  46. for tool in tools or []:
  47. result.append(
  48. ToolTransformService.convert_tool_entity_to_api_entity(
  49. tool=tool,
  50. credentials=credentials,
  51. tenant_id=tenant_id,
  52. labels=ToolLabelManager.get_tool_labels(provider_controller),
  53. )
  54. )
  55. return result
  56. @staticmethod
  57. def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
  58. """
  59. get builtin tool provider info
  60. """
  61. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  62. tool_provider_configurations = ProviderConfigEncrypter(
  63. tenant_id=tenant_id,
  64. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  65. provider_type=provider_controller.provider_type.value,
  66. provider_identity=provider_controller.entity.identity.name,
  67. )
  68. # check if user has added the provider
  69. builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  70. credentials = {}
  71. if builtin_provider is not None:
  72. # get credentials
  73. credentials = builtin_provider.credentials
  74. credentials = tool_provider_configurations.decrypt(credentials)
  75. entity = ToolTransformService.builtin_provider_to_user_provider(
  76. provider_controller=provider_controller,
  77. db_provider=builtin_provider,
  78. decrypt_credentials=True,
  79. )
  80. entity.original_credentials = {}
  81. return entity
  82. @staticmethod
  83. def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
  84. """
  85. list builtin provider credentials schema
  86. :param provider_name: the name of the provider
  87. :param tenant_id: the id of the tenant
  88. :return: the list of tool providers
  89. """
  90. provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
  91. return jsonable_encoder(provider.get_credentials_schema())
  92. @staticmethod
  93. def update_builtin_tool_provider(
  94. session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
  95. ):
  96. """
  97. update builtin tool provider
  98. """
  99. # get if the provider exists
  100. provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  101. try:
  102. # get provider
  103. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  104. if not provider_controller.need_credentials:
  105. raise ValueError(f"provider {provider_name} does not need credentials")
  106. tool_configuration = ProviderConfigEncrypter(
  107. tenant_id=tenant_id,
  108. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  109. provider_type=provider_controller.provider_type.value,
  110. provider_identity=provider_controller.entity.identity.name,
  111. )
  112. # get original credentials if exists
  113. if provider is not None:
  114. original_credentials = tool_configuration.decrypt(provider.credentials)
  115. masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
  116. # check if the credential has changed, save the original credential
  117. for name, value in credentials.items():
  118. if name in masked_credentials and value == masked_credentials[name]:
  119. credentials[name] = original_credentials[name]
  120. # validate credentials
  121. provider_controller.validate_credentials(user_id, credentials)
  122. # encrypt credentials
  123. credentials = tool_configuration.encrypt(credentials)
  124. except (
  125. PluginDaemonClientSideError,
  126. ToolProviderNotFoundError,
  127. ToolNotFoundError,
  128. ToolProviderCredentialValidationError,
  129. ) as e:
  130. raise ValueError(str(e))
  131. if provider is None:
  132. # create provider
  133. provider = BuiltinToolProvider(
  134. tenant_id=tenant_id,
  135. user_id=user_id,
  136. provider=provider_name,
  137. encrypted_credentials=json.dumps(credentials),
  138. )
  139. db.session.add(provider)
  140. else:
  141. provider.encrypted_credentials = json.dumps(credentials)
  142. # delete cache
  143. tool_configuration.delete_tool_credentials_cache()
  144. db.session.commit()
  145. return {"result": "success"}
  146. @staticmethod
  147. def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
  148. """
  149. get builtin tool provider credentials
  150. """
  151. provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  152. if provider_obj is None:
  153. return {}
  154. provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
  155. tool_configuration = ProviderConfigEncrypter(
  156. tenant_id=tenant_id,
  157. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  158. provider_type=provider_controller.provider_type.value,
  159. provider_identity=provider_controller.entity.identity.name,
  160. )
  161. credentials = tool_configuration.decrypt(provider_obj.credentials)
  162. credentials = tool_configuration.mask_tool_credentials(credentials)
  163. return credentials
  164. @staticmethod
  165. def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
  166. """
  167. delete tool provider
  168. """
  169. provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  170. if provider_obj is None:
  171. raise ValueError(f"you have not added provider {provider_name}")
  172. db.session.delete(provider_obj)
  173. db.session.commit()
  174. # delete cache
  175. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  176. tool_configuration = ProviderConfigEncrypter(
  177. tenant_id=tenant_id,
  178. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  179. provider_type=provider_controller.provider_type.value,
  180. provider_identity=provider_controller.entity.identity.name,
  181. )
  182. tool_configuration.delete_tool_credentials_cache()
  183. return {"result": "success"}
  184. @staticmethod
  185. def get_builtin_tool_provider_icon(provider: str):
  186. """
  187. get tool provider icon and it's mimetype
  188. """
  189. icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
  190. icon_bytes = Path(icon_path).read_bytes()
  191. return icon_bytes, mime_type
  192. @staticmethod
  193. def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
  194. """
  195. list builtin tools
  196. """
  197. # get all builtin providers
  198. provider_controllers = ToolManager.list_builtin_providers(tenant_id)
  199. # get all user added providers
  200. db_providers: list[BuiltinToolProvider] = (
  201. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
  202. )
  203. # rewrite db_providers
  204. for db_provider in db_providers:
  205. db_provider.provider = str(ToolProviderID(db_provider.provider))
  206. # find provider
  207. def find_provider(provider):
  208. return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
  209. result: list[ToolProviderApiEntity] = []
  210. for provider_controller in provider_controllers:
  211. try:
  212. # handle include, exclude
  213. if is_filtered(
  214. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  215. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  216. data=provider_controller,
  217. name_func=lambda x: x.identity.name,
  218. ):
  219. continue
  220. # convert provider controller to user provider
  221. user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
  222. provider_controller=provider_controller,
  223. db_provider=find_provider(provider_controller.entity.identity.name),
  224. decrypt_credentials=True,
  225. )
  226. # add icon
  227. ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
  228. tools = provider_controller.get_tools()
  229. for tool in tools or []:
  230. user_builtin_provider.tools.append(
  231. ToolTransformService.convert_tool_entity_to_api_entity(
  232. tenant_id=tenant_id,
  233. tool=tool,
  234. credentials=user_builtin_provider.original_credentials,
  235. labels=ToolLabelManager.get_tool_labels(provider_controller),
  236. )
  237. )
  238. result.append(user_builtin_provider)
  239. except Exception as e:
  240. raise e
  241. return BuiltinToolProviderSort.sort(result)
  242. @staticmethod
  243. def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
  244. try:
  245. full_provider_name = provider_name
  246. provider_id_entity = GenericProviderID(provider_name)
  247. provider_name = provider_id_entity.provider_name
  248. if provider_id_entity.organization != "langgenius":
  249. provider_obj = (
  250. db.session.query(BuiltinToolProvider)
  251. .filter(
  252. BuiltinToolProvider.tenant_id == tenant_id,
  253. BuiltinToolProvider.provider == full_provider_name,
  254. )
  255. .first()
  256. )
  257. else:
  258. provider_obj = (
  259. db.session.query(BuiltinToolProvider)
  260. .filter(
  261. BuiltinToolProvider.tenant_id == tenant_id,
  262. (BuiltinToolProvider.provider == provider_name)
  263. | (BuiltinToolProvider.provider == full_provider_name),
  264. )
  265. .first()
  266. )
  267. if provider_obj is None:
  268. return None
  269. provider_obj.provider = GenericProviderID(provider_obj.provider).to_string()
  270. return provider_obj
  271. except Exception:
  272. # it's an old provider without organization
  273. return (
  274. db.session.query(BuiltinToolProvider)
  275. .filter(
  276. BuiltinToolProvider.tenant_id == tenant_id,
  277. (BuiltinToolProvider.provider == provider_name),
  278. )
  279. .first()
  280. )