builtin_tools_manage_service.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. import json
  2. import logging
  3. from pathlib import Path
  4. from configs import dify_config
  5. from core.helper.position_helper import is_filtered
  6. from core.model_runtime.utils.encoders import jsonable_encoder
  7. from core.plugin.entities.plugin import GenericProviderID
  8. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  9. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
  10. from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
  11. from core.tools.tool_label_manager import ToolLabelManager
  12. from core.tools.tool_manager import ToolManager
  13. from core.tools.utils.configuration import ProviderConfigEncrypter
  14. from extensions.ext_database import db
  15. from models.tools import BuiltinToolProvider
  16. from services.tools.tools_transform_service import ToolTransformService
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
  18. class BuiltinToolManageService:
  19. @staticmethod
  20. def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[ToolApiEntity]:
  21. """
  22. list builtin tool provider tools
  23. :param user_id: the id of the user
  24. :param tenant_id: the id of the tenant
  25. :param provider: the name of the provider
  26. :return: the list of tools
  27. """
  28. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  29. tools = provider_controller.get_tools()
  30. tool_provider_configurations = ProviderConfigEncrypter(
  31. tenant_id=tenant_id,
  32. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  33. provider_type=provider_controller.provider_type.value,
  34. provider_identity=provider_controller.entity.identity.name,
  35. )
  36. # check if user has added the provider
  37. builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  38. credentials = {}
  39. if builtin_provider is not None:
  40. # get credentials
  41. credentials = builtin_provider.credentials
  42. credentials = tool_provider_configurations.decrypt(credentials)
  43. result = []
  44. for tool in tools:
  45. result.append(
  46. ToolTransformService.convert_tool_entity_to_api_entity(
  47. tool=tool,
  48. credentials=credentials,
  49. tenant_id=tenant_id,
  50. labels=ToolLabelManager.get_tool_labels(provider_controller),
  51. )
  52. )
  53. return result
  54. @staticmethod
  55. def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
  56. """
  57. get builtin tool provider info
  58. """
  59. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  60. tool_provider_configurations = ProviderConfigEncrypter(
  61. tenant_id=tenant_id,
  62. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  63. provider_type=provider_controller.provider_type.value,
  64. provider_identity=provider_controller.entity.identity.name,
  65. )
  66. # check if user has added the provider
  67. builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  68. credentials = {}
  69. if builtin_provider is not None:
  70. # get credentials
  71. credentials = builtin_provider.credentials
  72. credentials = tool_provider_configurations.decrypt(credentials)
  73. entity = ToolTransformService.builtin_provider_to_user_provider(
  74. provider_controller=provider_controller,
  75. db_provider=builtin_provider,
  76. decrypt_credentials=True,
  77. )
  78. entity.original_credentials = {}
  79. return entity
  80. @staticmethod
  81. def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
  82. """
  83. list builtin provider credentials schema
  84. :param provider_name: the name of the provider
  85. :param tenant_id: the id of the tenant
  86. :return: the list of tool providers
  87. """
  88. provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
  89. return jsonable_encoder(provider.get_credentials_schema())
  90. @staticmethod
  91. def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
  92. """
  93. update builtin tool provider
  94. """
  95. # get if the provider exists
  96. provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  97. try:
  98. # get provider
  99. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  100. if not provider_controller.need_credentials:
  101. raise ValueError(f"provider {provider_name} does not need credentials")
  102. tool_configuration = ProviderConfigEncrypter(
  103. tenant_id=tenant_id,
  104. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  105. provider_type=provider_controller.provider_type.value,
  106. provider_identity=provider_controller.entity.identity.name,
  107. )
  108. # get original credentials if exists
  109. if provider is not None:
  110. original_credentials = tool_configuration.decrypt(provider.credentials)
  111. masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
  112. # check if the credential has changed, save the original credential
  113. for name, value in credentials.items():
  114. if name in masked_credentials and value == masked_credentials[name]:
  115. credentials[name] = original_credentials[name]
  116. # validate credentials
  117. provider_controller.validate_credentials(user_id, credentials)
  118. # encrypt credentials
  119. credentials = tool_configuration.encrypt(credentials)
  120. except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
  121. raise ValueError(str(e))
  122. if provider is None:
  123. # create provider
  124. provider = BuiltinToolProvider(
  125. tenant_id=tenant_id,
  126. user_id=user_id,
  127. provider=provider_name,
  128. encrypted_credentials=json.dumps(credentials),
  129. )
  130. db.session.add(provider)
  131. db.session.commit()
  132. else:
  133. provider.encrypted_credentials = json.dumps(credentials)
  134. db.session.add(provider)
  135. db.session.commit()
  136. # delete cache
  137. tool_configuration.delete_tool_credentials_cache()
  138. return {"result": "success"}
  139. @staticmethod
  140. def get_builtin_tool_provider_credentials(user_id: str, tenant_id: str, provider: str):
  141. """
  142. get builtin tool provider credentials
  143. """
  144. provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  145. if provider_obj is None:
  146. return {}
  147. provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
  148. tool_configuration = ProviderConfigEncrypter(
  149. tenant_id=tenant_id,
  150. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  151. provider_type=provider_controller.provider_type.value,
  152. provider_identity=provider_controller.entity.identity.name,
  153. )
  154. credentials = tool_configuration.decrypt(provider_obj.credentials)
  155. credentials = tool_configuration.mask_tool_credentials(credentials)
  156. return credentials
  157. @staticmethod
  158. def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
  159. """
  160. delete tool provider
  161. """
  162. provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  163. if provider_obj is None:
  164. raise ValueError(f"you have not added provider {provider_name}")
  165. db.session.delete(provider_obj)
  166. db.session.commit()
  167. # delete cache
  168. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  169. tool_configuration = ProviderConfigEncrypter(
  170. tenant_id=tenant_id,
  171. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  172. provider_type=provider_controller.provider_type.value,
  173. provider_identity=provider_controller.entity.identity.name,
  174. )
  175. tool_configuration.delete_tool_credentials_cache()
  176. return {"result": "success"}
  177. @staticmethod
  178. def get_builtin_tool_provider_icon(provider: str):
  179. """
  180. get tool provider icon and it's mimetype
  181. """
  182. icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
  183. icon_bytes = Path(icon_path).read_bytes()
  184. return icon_bytes, mime_type
  185. @staticmethod
  186. def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
  187. """
  188. list builtin tools
  189. """
  190. # get all builtin providers
  191. provider_controllers = ToolManager.list_builtin_providers(tenant_id)
  192. # get all user added providers
  193. db_providers: list[BuiltinToolProvider] = (
  194. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
  195. )
  196. # rewrite db_providers
  197. for db_provider in db_providers:
  198. try:
  199. GenericProviderID(db_provider.provider)
  200. except Exception:
  201. db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
  202. # find provider
  203. find_provider = lambda provider: next(
  204. filter(lambda db_provider: db_provider.provider == provider, db_providers), None
  205. )
  206. result: list[ToolProviderApiEntity] = []
  207. for provider_controller in provider_controllers:
  208. try:
  209. # handle include, exclude
  210. if is_filtered(
  211. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  212. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  213. data=provider_controller,
  214. name_func=lambda x: x.identity.name,
  215. ):
  216. continue
  217. # convert provider controller to user provider
  218. user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
  219. provider_controller=provider_controller,
  220. db_provider=find_provider(provider_controller.entity.identity.name),
  221. decrypt_credentials=True,
  222. )
  223. # add icon
  224. ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
  225. tools = provider_controller.get_tools()
  226. for tool in tools:
  227. user_builtin_provider.tools.append(
  228. ToolTransformService.convert_tool_entity_to_api_entity(
  229. tenant_id=tenant_id,
  230. tool=tool,
  231. credentials=user_builtin_provider.original_credentials,
  232. labels=ToolLabelManager.get_tool_labels(provider_controller),
  233. )
  234. )
  235. result.append(user_builtin_provider)
  236. except Exception as e:
  237. raise e
  238. return BuiltinToolProviderSort.sort(result)
  239. @staticmethod
  240. def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
  241. try:
  242. full_provider_name = provider_name
  243. provider_id_entity = GenericProviderID(provider_name)
  244. provider_name = provider_id_entity.provider_name
  245. if provider_id_entity.organization != "langgenius":
  246. return None
  247. provider_obj = (
  248. db.session.query(BuiltinToolProvider)
  249. .filter(
  250. BuiltinToolProvider.tenant_id == tenant_id,
  251. (BuiltinToolProvider.provider == provider_name)
  252. | (BuiltinToolProvider.provider == full_provider_name),
  253. )
  254. .first()
  255. )
  256. if provider_obj is None:
  257. return None
  258. provider_obj.provider = GenericProviderID(provider_obj.provider).to_string()
  259. return provider_obj
  260. except Exception:
  261. # it's an old provider without organization
  262. return (
  263. db.session.query(BuiltinToolProvider)
  264. .filter(
  265. BuiltinToolProvider.tenant_id == tenant_id,
  266. (BuiltinToolProvider.provider == provider_name),
  267. )
  268. .first()
  269. )