builtin_tools_manage_service.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. import json
  2. import logging
  3. from pathlib import Path
  4. from configs import dify_config
  5. from core.helper.position_helper import is_filtered
  6. from core.model_runtime.utils.encoders import jsonable_encoder
  7. from core.plugin.entities.plugin import GenericProviderID
  8. from core.plugin.manager.exc import PluginDaemonClientSideError
  9. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  10. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
  11. from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
  12. from core.tools.tool_label_manager import ToolLabelManager
  13. from core.tools.tool_manager import ToolManager
  14. from core.tools.utils.configuration import ProviderConfigEncrypter
  15. from extensions.ext_database import db
  16. from models.tools import BuiltinToolProvider
  17. from services.tools.tools_transform_service import ToolTransformService
  18. logger = logging.getLogger(__name__)
  19. class BuiltinToolManageService:
  @staticmethod
  def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[ToolApiEntity]:
      """
      List the tools exposed by a builtin tool provider.

      :param user_id: the id of the user (not read by this method)
      :param tenant_id: the id of the tenant
      :param provider: the name of the provider
      :return: one API entity per tool returned by the provider controller
      """
      provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
      tools = provider_controller.get_tools()

      tool_provider_configurations = ProviderConfigEncrypter(
          tenant_id=tenant_id,
          config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
          provider_type=provider_controller.provider_type.value,
          provider_identity=provider_controller.entity.identity.name,
      )

      # Credentials stay empty unless the tenant has a stored record for this provider.
      builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
      credentials = {}
      if builtin_provider is not None:
          credentials = tool_provider_configurations.decrypt(builtin_provider.credentials)

      # The labels depend only on the controller, so look them up once
      # instead of once per tool.
      labels = ToolLabelManager.get_tool_labels(provider_controller)

      return [
          ToolTransformService.convert_tool_entity_to_api_entity(
              tool=tool,
              credentials=credentials,
              tenant_id=tenant_id,
              labels=labels,
          )
          for tool in tools
      ]
  @staticmethod
  def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
      """
      Get the user-facing description of a builtin tool provider.

      :param user_id: the id of the user (not read by this method)
      :param tenant_id: the id of the tenant
      :param provider: the name of the provider
      :return: the entity built by ToolTransformService.builtin_provider_to_user_provider,
          with its ``original_credentials`` reset to an empty dict
      """
      provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)

      # The stored record is None when the tenant has not added this provider;
      # it is handed to the transform service as-is in either case.
      builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)

      # No separate decryption is done here: the transform call is asked to
      # decrypt (decrypt_credentials=True), and a locally decrypted copy would
      # never be read.
      entity = ToolTransformService.builtin_provider_to_user_provider(
          provider_controller=provider_controller,
          db_provider=builtin_provider,
          decrypt_credentials=True,
      )

      # NOTE(review): presumably cleared so unmasked credentials are not
      # returned to the caller - confirm.
      entity.original_credentials = {}

      return entity
  @staticmethod
  def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
      """
      Return the credentials schema of a builtin provider.

      :param provider_name: the name of the provider
      :param tenant_id: the id of the tenant
      :return: the provider's credentials schema, passed through jsonable_encoder
      """
      credentials_schema = ToolManager.get_builtin_provider(provider_name, tenant_id).get_credentials_schema()
      return jsonable_encoder(credentials_schema)
  @staticmethod
  def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
      """
      Create or update the stored credentials of a builtin tool provider.

      When a record already exists, any submitted value that equals the masked
      form of the stored value is treated as "unchanged" and replaced by the
      stored value before validation.

      :param user_id: the id of the user; passed to credential validation and
          stored on a newly created record
      :param tenant_id: the id of the tenant
      :param provider_name: the name of the provider
      :param credentials: the submitted credentials; this dict is not modified
      :return: ``{"result": "success"}``
      :raises ValueError: if the provider does not need credentials, or if the
          provider lookup or credential validation fails
      """
      # Existing record for this tenant, or None if the provider was never added.
      provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)

      try:
          provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
          if not provider_controller.need_credentials:
              raise ValueError(f"provider {provider_name} does not need credentials")

          tool_configuration = ProviderConfigEncrypter(
              tenant_id=tenant_id,
              config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
              provider_type=provider_controller.provider_type.value,
              provider_identity=provider_controller.entity.identity.name,
          )

          # Work on a shallow copy so restoring stored values below never
          # writes into the caller's dict.
          credentials = dict(credentials)

          if provider is not None:
              original_credentials = tool_configuration.decrypt(provider.credentials)
              masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
              # A value identical to its masked form was not edited: keep the stored one.
              for name, value in credentials.items():
                  if name in masked_credentials and value == masked_credentials[name]:
                      credentials[name] = original_credentials[name]

          provider_controller.validate_credentials(user_id, credentials)
          credentials = tool_configuration.encrypt(credentials)
      except (
          PluginDaemonClientSideError,
          ToolProviderNotFoundError,
          ToolNotFoundError,
          ToolProviderCredentialValidationError,
      ) as e:
          # Surface domain errors as ValueError, keeping the original as the cause.
          raise ValueError(str(e)) from e

      encrypted_credentials = json.dumps(credentials)
      if provider is None:
          provider = BuiltinToolProvider(
              tenant_id=tenant_id,
              user_id=user_id,
              provider=provider_name,
              encrypted_credentials=encrypted_credentials,
          )
      else:
          provider.encrypted_credentials = encrypted_credentials

      db.session.add(provider)
      db.session.commit()

      # Drop the credentials cache after the new credentials are committed.
      tool_configuration.delete_tool_credentials_cache()

      return {"result": "success"}
  @staticmethod
  def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
      """
      Return the masked credentials stored for a builtin tool provider.

      :param user_id: the id of the user (not read by this method)
      :param tenant_id: the id of the tenant
      :param provider: the name of the provider
      :return: the decrypted-then-masked credentials, or an empty dict when the
          tenant has no stored record for the provider
      """
      stored_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
      if stored_provider is None:
          return {}

      # The controller is looked up by the name held on the stored record.
      controller = ToolManager.get_builtin_provider(stored_provider.provider, tenant_id)
      encrypter = ProviderConfigEncrypter(
          tenant_id=tenant_id,
          config=[schema_item.to_basic_provider_config() for schema_item in controller.get_credentials_schema()],
          provider_type=controller.provider_type.value,
          provider_identity=controller.entity.identity.name,
      )

      decrypted = encrypter.decrypt(stored_provider.credentials)
      return encrypter.mask_tool_credentials(decrypted)
  @staticmethod
  def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
      """
      Delete a tenant's stored builtin tool provider and clear its credentials cache.

      :param user_id: the id of the user (not read by this method)
      :param tenant_id: the id of the tenant
      :param provider_name: the name of the provider
      :return: ``{"result": "success"}``
      :raises ValueError: if the tenant has no stored record for the provider
      """
      stored_provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
      if stored_provider is None:
          raise ValueError(f"you have not added provider {provider_name}")

      # The record is deleted and committed first; the cache is cleared afterwards.
      db.session.delete(stored_provider)
      db.session.commit()

      controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
      encrypter = ProviderConfigEncrypter(
          tenant_id=tenant_id,
          config=[schema_item.to_basic_provider_config() for schema_item in controller.get_credentials_schema()],
          provider_type=controller.provider_type.value,
          provider_identity=controller.entity.identity.name,
      )
      encrypter.delete_tool_credentials_cache()

      return {"result": "success"}
  @staticmethod
  def get_builtin_tool_provider_icon(provider: str):
      """
      Get a tool provider's icon and its MIME type.

      :param provider: the name of the provider
      :return: a ``(icon_bytes, mime_type)`` tuple; the bytes are read from the
          path reported by ToolManager.get_hardcoded_provider_icon
      """
      icon_location, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
      return Path(icon_location).read_bytes(), mime_type
  @staticmethod
  def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
      """
      List all builtin tool providers for a tenant, each with its tools attached.

      Providers rejected by the configured include/exclude position sets are
      skipped. Any exception raised while converting a provider propagates to
      the caller.

      :param user_id: the id of the user (not read by this method)
      :param tenant_id: the id of the tenant
      :return: the provider entities, ordered by BuiltinToolProviderSort.sort
      """
      provider_controllers = ToolManager.list_builtin_providers(tenant_id)

      # Every provider record the tenant has added.
      db_providers: list[BuiltinToolProvider] = (
          db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
      )

      # Index the records by provider name for O(1) lookup below. Names that
      # GenericProviderID rejects are rewritten to "langgenius/<name>/<name>".
      # setdefault keeps the FIRST record seen for a name, matching a
      # first-match linear search over db_providers.
      # NOTE(review): the rewrite assigns to the loaded model objects;
      # presumably a later flush on this session would persist it - confirm.
      db_provider_by_name: dict[str, BuiltinToolProvider] = {}
      for db_provider in db_providers:
          try:
              GenericProviderID(db_provider.provider)
          except Exception:
              db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
          db_provider_by_name.setdefault(db_provider.provider, db_provider)

      result: list[ToolProviderApiEntity] = []

      for provider_controller in provider_controllers:
          # handle include, exclude
          if is_filtered(
              include_set=dify_config.POSITION_TOOL_INCLUDES_SET,  # type: ignore
              exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,  # type: ignore
              data=provider_controller,
              name_func=lambda x: x.identity.name,
          ):
              continue

          # convert provider controller to user provider
          user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
              provider_controller=provider_controller,
              db_provider=db_provider_by_name.get(provider_controller.entity.identity.name),
              decrypt_credentials=True,
          )

          # add icon
          ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)

          tools = provider_controller.get_tools()
          # The labels depend only on the controller: fetch them once per
          # provider rather than once per tool.
          labels = ToolLabelManager.get_tool_labels(provider_controller)
          for tool in tools:
              user_builtin_provider.tools.append(
                  ToolTransformService.convert_tool_entity_to_api_entity(
                      tenant_id=tenant_id,
                      tool=tool,
                      credentials=user_builtin_provider.original_credentials,
                      labels=labels,
                  )
              )

          result.append(user_builtin_provider)

      return BuiltinToolProviderSort.sort(result)
  @staticmethod
  def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
      """
      Fetch the tenant's stored record for a builtin provider, if any.

      The name is parsed with GenericProviderID. A provider whose organization
      is not "langgenius" yields None. Otherwise the tenant's first record whose
      ``provider`` column equals either the parsed provider name or the name as
      given is returned, with its ``provider`` attribute rewritten to the
      GenericProviderID string form.

      If anything in that path raises, the lookup falls back to a plain query
      on ``provider_name`` as it stands at that point (the given name, or the
      parsed short name if parsing had already succeeded).

      :param provider_name: the provider name or provider id string
      :param tenant_id: the id of the tenant
      :return: the matching record, or None when there is none
      """
      try:
          requested_name = provider_name
          parsed_id = GenericProviderID(provider_name)
          # Rebinding provider_name is deliberate: the fallback query in the
          # except branch reads it.
          provider_name = parsed_id.provider_name
          if parsed_id.organization != "langgenius":
              return None

          name_matches = (BuiltinToolProvider.provider == provider_name) | (
              BuiltinToolProvider.provider == requested_name
          )
          record = (
              db.session.query(BuiltinToolProvider)
              .filter(BuiltinToolProvider.tenant_id == tenant_id, name_matches)
              .first()
          )

          if record is not None:
              record.provider = GenericProviderID(record.provider).to_string()
          return record
      except Exception:
          # it's an old provider without organization
          legacy_query = db.session.query(BuiltinToolProvider).filter(
              BuiltinToolProvider.tenant_id == tenant_id,
              (BuiltinToolProvider.provider == provider_name),
          )
          return legacy_query.first()