builtin_tools_manage_service.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. import json
  2. import logging
  3. from pathlib import Path
  4. from configs import dify_config
  5. from core.helper.position_helper import is_filtered
  6. from core.model_runtime.utils.encoders import jsonable_encoder
  7. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  8. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
  9. from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
  10. from core.tools.tool_label_manager import ToolLabelManager
  11. from core.tools.tool_manager import ToolManager
  12. from core.tools.utils.configuration import ProviderConfigEncrypter
  13. from extensions.ext_database import db
  14. from models.tools import BuiltinToolProvider
  15. from services.tools.tools_transform_service import ToolTransformService
  16. logger = logging.getLogger(__name__)
  17. class BuiltinToolManageService:
  18. @staticmethod
  19. def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[ToolApiEntity]:
  20. """
  21. list builtin tool provider tools
  22. :param user_id: the id of the user
  23. :param tenant_id: the id of the tenant
  24. :param provider: the name of the provider
  25. :return: the list of tools
  26. """
  27. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  28. tools = provider_controller.get_tools()
  29. tool_provider_configurations = ProviderConfigEncrypter(
  30. tenant_id=tenant_id,
  31. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  32. provider_type=provider_controller.provider_type.value,
  33. provider_identity=provider_controller.entity.identity.name,
  34. )
  35. # check if user has added the provider
  36. builtin_provider: BuiltinToolProvider | None = (
  37. db.session.query(BuiltinToolProvider)
  38. .filter(
  39. BuiltinToolProvider.tenant_id == tenant_id,
  40. BuiltinToolProvider.provider == provider,
  41. )
  42. .first()
  43. )
  44. credentials = {}
  45. if builtin_provider is not None:
  46. # get credentials
  47. credentials = builtin_provider.credentials
  48. credentials = tool_provider_configurations.decrypt(credentials)
  49. result = []
  50. for tool in tools:
  51. result.append(
  52. ToolTransformService.tool_to_user_tool(
  53. tool=tool,
  54. credentials=credentials,
  55. tenant_id=tenant_id,
  56. labels=ToolLabelManager.get_tool_labels(provider_controller),
  57. )
  58. )
  59. return result
  60. @staticmethod
  61. def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
  62. """
  63. list builtin provider credentials schema
  64. :param provider_name: the name of the provider
  65. :param tenant_id: the id of the tenant
  66. :return: the list of tool providers
  67. """
  68. provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
  69. return jsonable_encoder(provider.get_credentials_schema())
  70. @staticmethod
  71. def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
  72. """
  73. update builtin tool provider
  74. """
  75. # get if the provider exists
  76. provider: BuiltinToolProvider | None = (
  77. db.session.query(BuiltinToolProvider)
  78. .filter(
  79. BuiltinToolProvider.tenant_id == tenant_id,
  80. BuiltinToolProvider.provider == provider_name,
  81. )
  82. .first()
  83. )
  84. try:
  85. # get provider
  86. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  87. if not provider_controller.need_credentials:
  88. raise ValueError(f"provider {provider_name} does not need credentials")
  89. tool_configuration = ProviderConfigEncrypter(
  90. tenant_id=tenant_id,
  91. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  92. provider_type=provider_controller.provider_type.value,
  93. provider_identity=provider_controller.entity.identity.name,
  94. )
  95. # get original credentials if exists
  96. if provider is not None:
  97. original_credentials = tool_configuration.decrypt(provider.credentials)
  98. masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
  99. # check if the credential has changed, save the original credential
  100. for name, value in credentials.items():
  101. if name in masked_credentials and value == masked_credentials[name]:
  102. credentials[name] = original_credentials[name]
  103. # validate credentials
  104. provider_controller.validate_credentials(user_id, credentials)
  105. # encrypt credentials
  106. credentials = tool_configuration.encrypt(credentials)
  107. except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
  108. raise ValueError(str(e))
  109. if provider is None:
  110. # create provider
  111. provider = BuiltinToolProvider(
  112. tenant_id=tenant_id,
  113. user_id=user_id,
  114. provider=provider_name,
  115. encrypted_credentials=json.dumps(credentials),
  116. )
  117. db.session.add(provider)
  118. db.session.commit()
  119. else:
  120. provider.encrypted_credentials = json.dumps(credentials)
  121. db.session.add(provider)
  122. db.session.commit()
  123. # delete cache
  124. tool_configuration.delete_tool_credentials_cache()
  125. return {"result": "success"}
  126. @staticmethod
  127. def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
  128. """
  129. get builtin tool provider credentials
  130. """
  131. provider_obj: BuiltinToolProvider | None = (
  132. db.session.query(BuiltinToolProvider)
  133. .filter(
  134. BuiltinToolProvider.tenant_id == tenant_id,
  135. BuiltinToolProvider.provider == provider,
  136. )
  137. .first()
  138. )
  139. if provider_obj is None:
  140. return {}
  141. provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
  142. tool_configuration = ProviderConfigEncrypter(
  143. tenant_id=tenant_id,
  144. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  145. provider_type=provider_controller.provider_type.value,
  146. provider_identity=provider_controller.entity.identity.name,
  147. )
  148. credentials = tool_configuration.decrypt(provider_obj.credentials)
  149. credentials = tool_configuration.mask_tool_credentials(credentials)
  150. return credentials
  151. @staticmethod
  152. def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
  153. """
  154. delete tool provider
  155. """
  156. provider_obj: BuiltinToolProvider | None = (
  157. db.session.query(BuiltinToolProvider)
  158. .filter(
  159. BuiltinToolProvider.tenant_id == tenant_id,
  160. BuiltinToolProvider.provider == provider_name,
  161. )
  162. .first()
  163. )
  164. if provider_obj is None:
  165. raise ValueError(f"you have not added provider {provider_name}")
  166. db.session.delete(provider_obj)
  167. db.session.commit()
  168. # delete cache
  169. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  170. tool_configuration = ProviderConfigEncrypter(
  171. tenant_id=tenant_id,
  172. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  173. provider_type=provider_controller.provider_type.value,
  174. provider_identity=provider_controller.entity.identity.name,
  175. )
  176. tool_configuration.delete_tool_credentials_cache()
  177. return {"result": "success"}
  178. @staticmethod
  179. def get_builtin_tool_provider_icon(provider: str, tenant_id: str):
  180. """
  181. get tool provider icon and it's mimetype
  182. """
  183. icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider, tenant_id)
  184. icon_bytes = Path(icon_path).read_bytes()
  185. return icon_bytes, mime_type
  186. @staticmethod
  187. def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
  188. """
  189. list builtin tools
  190. """
  191. # get all builtin providers
  192. provider_controllers = ToolManager.list_builtin_providers(tenant_id)
  193. # get all user added providers
  194. db_providers: list[BuiltinToolProvider] = (
  195. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
  196. )
  197. # find provider
  198. find_provider = lambda provider: next(
  199. filter(lambda db_provider: db_provider.provider == provider, db_providers), None
  200. )
  201. result: list[ToolProviderApiEntity] = []
  202. for provider_controller in provider_controllers:
  203. try:
  204. # handle include, exclude
  205. if is_filtered(
  206. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  207. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  208. data=provider_controller,
  209. name_func=lambda x: x.identity.name,
  210. ):
  211. continue
  212. # convert provider controller to user provider
  213. user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
  214. provider_controller=provider_controller,
  215. db_provider=find_provider(provider_controller.entity.identity.name),
  216. decrypt_credentials=True,
  217. )
  218. # add icon
  219. ToolTransformService.repack_provider(user_builtin_provider)
  220. tools = provider_controller.get_tools()
  221. for tool in tools:
  222. user_builtin_provider.tools.append(
  223. ToolTransformService.tool_to_user_tool(
  224. tenant_id=tenant_id,
  225. tool=tool,
  226. credentials=user_builtin_provider.original_credentials,
  227. labels=ToolLabelManager.get_tool_labels(provider_controller),
  228. )
  229. )
  230. result.append(user_builtin_provider)
  231. except Exception as e:
  232. raise e
  233. return BuiltinToolProviderSort.sort(result)