builtin_tools_manage_service.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331
  1. import json
  2. import logging
  3. from pathlib import Path
  4. from sqlalchemy.orm import Session
  5. from configs import dify_config
  6. from core.helper.position_helper import is_filtered
  7. from core.model_runtime.utils.encoders import jsonable_encoder
  8. from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
  9. from core.plugin.manager.exc import PluginDaemonClientSideError
  10. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  11. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
  12. from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
  13. from core.tools.tool_label_manager import ToolLabelManager
  14. from core.tools.tool_manager import ToolManager
  15. from core.tools.utils.configuration import ProviderConfigEncrypter
  16. from extensions.ext_database import db
  17. from models.tools import BuiltinToolProvider
  18. from services.tools.tools_transform_service import ToolTransformService
  19. logger = logging.getLogger(__name__)
  20. class BuiltinToolManageService:
  21. @staticmethod
  22. def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
  23. """
  24. list builtin tool provider tools
  25. :param user_id: the id of the user
  26. :param tenant_id: the id of the tenant
  27. :param provider: the name of the provider
  28. :return: the list of tools
  29. """
  30. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  31. tools = provider_controller.get_tools()
  32. tool_provider_configurations = ProviderConfigEncrypter(
  33. tenant_id=tenant_id,
  34. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  35. provider_type=provider_controller.provider_type.value,
  36. provider_identity=provider_controller.entity.identity.name,
  37. )
  38. # check if user has added the provider
  39. builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  40. credentials = {}
  41. if builtin_provider is not None:
  42. # get credentials
  43. credentials = builtin_provider.credentials
  44. credentials = tool_provider_configurations.decrypt(credentials)
  45. result: list[ToolApiEntity] = []
  46. for tool in tools or []:
  47. result.append(
  48. ToolTransformService.convert_tool_entity_to_api_entity(
  49. tool=tool,
  50. credentials=credentials,
  51. tenant_id=tenant_id,
  52. labels=ToolLabelManager.get_tool_labels(provider_controller),
  53. )
  54. )
  55. return result
  56. @staticmethod
  57. def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
  58. """
  59. get builtin tool provider info
  60. """
  61. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  62. tool_provider_configurations = ProviderConfigEncrypter(
  63. tenant_id=tenant_id,
  64. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  65. provider_type=provider_controller.provider_type.value,
  66. provider_identity=provider_controller.entity.identity.name,
  67. )
  68. # check if user has added the provider
  69. builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  70. credentials = {}
  71. if builtin_provider is not None:
  72. # get credentials
  73. credentials = builtin_provider.credentials
  74. credentials = tool_provider_configurations.decrypt(credentials)
  75. entity = ToolTransformService.builtin_provider_to_user_provider(
  76. provider_controller=provider_controller,
  77. db_provider=builtin_provider,
  78. decrypt_credentials=True,
  79. )
  80. entity.original_credentials = {}
  81. return entity
  82. @staticmethod
  83. def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
  84. """
  85. list builtin provider credentials schema
  86. :param provider_name: the name of the provider
  87. :param tenant_id: the id of the tenant
  88. :return: the list of tool providers
  89. """
  90. provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
  91. return jsonable_encoder(provider.get_credentials_schema())
  92. @staticmethod
  93. def update_builtin_tool_provider(
  94. session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
  95. ):
  96. """
  97. update builtin tool provider
  98. """
  99. # get if the provider exists
  100. provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  101. try:
  102. # get provider
  103. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  104. if not provider_controller.need_credentials:
  105. raise ValueError(f"provider {provider_name} does not need credentials")
  106. tool_configuration = ProviderConfigEncrypter(
  107. tenant_id=tenant_id,
  108. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  109. provider_type=provider_controller.provider_type.value,
  110. provider_identity=provider_controller.entity.identity.name,
  111. )
  112. # get original credentials if exists
  113. if provider is not None:
  114. original_credentials = tool_configuration.decrypt(provider.credentials)
  115. masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
  116. # check if the credential has changed, save the original credential
  117. for name, value in credentials.items():
  118. if name in masked_credentials and value == masked_credentials[name]:
  119. credentials[name] = original_credentials[name]
  120. # validate credentials
  121. provider_controller.validate_credentials(user_id, credentials)
  122. # encrypt credentials
  123. credentials = tool_configuration.encrypt(credentials)
  124. except (
  125. PluginDaemonClientSideError,
  126. ToolProviderNotFoundError,
  127. ToolNotFoundError,
  128. ToolProviderCredentialValidationError,
  129. ) as e:
  130. raise ValueError(str(e))
  131. if provider is None:
  132. # create provider
  133. provider = BuiltinToolProvider(
  134. tenant_id=tenant_id,
  135. user_id=user_id,
  136. provider=provider_name,
  137. encrypted_credentials=json.dumps(credentials),
  138. )
  139. db.session.add(provider)
  140. else:
  141. provider.encrypted_credentials = json.dumps(credentials)
  142. # delete cache
  143. tool_configuration.delete_tool_credentials_cache()
  144. db.session.commit()
  145. return {"result": "success"}
  146. @staticmethod
  147. def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
  148. """
  149. get builtin tool provider credentials
  150. """
  151. provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  152. if provider_obj is None:
  153. return {}
  154. provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
  155. tool_configuration = ProviderConfigEncrypter(
  156. tenant_id=tenant_id,
  157. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  158. provider_type=provider_controller.provider_type.value,
  159. provider_identity=provider_controller.entity.identity.name,
  160. )
  161. credentials = tool_configuration.decrypt(provider_obj.credentials)
  162. credentials = tool_configuration.mask_tool_credentials(credentials)
  163. return credentials
  164. @staticmethod
  165. def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
  166. """
  167. delete tool provider
  168. """
  169. provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  170. if provider_obj is None:
  171. raise ValueError(f"you have not added provider {provider_name}")
  172. db.session.delete(provider_obj)
  173. db.session.commit()
  174. # delete cache
  175. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  176. tool_configuration = ProviderConfigEncrypter(
  177. tenant_id=tenant_id,
  178. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  179. provider_type=provider_controller.provider_type.value,
  180. provider_identity=provider_controller.entity.identity.name,
  181. )
  182. tool_configuration.delete_tool_credentials_cache()
  183. return {"result": "success"}
  184. @staticmethod
  185. def get_builtin_tool_provider_icon(provider: str):
  186. """
  187. get tool provider icon and it's mimetype
  188. """
  189. icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
  190. icon_bytes = Path(icon_path).read_bytes()
  191. return icon_bytes, mime_type
  192. @staticmethod
  193. def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
  194. """
  195. list builtin tools
  196. """
  197. # get all builtin providers
  198. provider_controllers = ToolManager.list_builtin_providers(tenant_id)
  199. with db.session.no_autoflush:
  200. # get all user added providers
  201. db_providers: list[BuiltinToolProvider] = (
  202. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
  203. )
  204. # rewrite db_providers
  205. for db_provider in db_providers:
  206. db_provider.provider = str(ToolProviderID(db_provider.provider))
  207. # find provider
  208. def find_provider(provider):
  209. return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
  210. result: list[ToolProviderApiEntity] = []
  211. for provider_controller in provider_controllers:
  212. try:
  213. # handle include, exclude
  214. if is_filtered(
  215. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  216. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  217. data=provider_controller,
  218. name_func=lambda x: x.identity.name,
  219. ):
  220. continue
  221. # convert provider controller to user provider
  222. user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
  223. provider_controller=provider_controller,
  224. db_provider=find_provider(provider_controller.entity.identity.name),
  225. decrypt_credentials=True,
  226. )
  227. # add icon
  228. ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
  229. tools = provider_controller.get_tools()
  230. for tool in tools or []:
  231. user_builtin_provider.tools.append(
  232. ToolTransformService.convert_tool_entity_to_api_entity(
  233. tenant_id=tenant_id,
  234. tool=tool,
  235. credentials=user_builtin_provider.original_credentials,
  236. labels=ToolLabelManager.get_tool_labels(provider_controller),
  237. )
  238. )
  239. result.append(user_builtin_provider)
  240. except Exception as e:
  241. raise e
  242. return BuiltinToolProviderSort.sort(result)
  243. @staticmethod
  244. def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
  245. try:
  246. full_provider_name = provider_name
  247. provider_id_entity = GenericProviderID(provider_name)
  248. provider_name = provider_id_entity.provider_name
  249. if provider_id_entity.organization != "langgenius":
  250. provider_obj = (
  251. db.session.query(BuiltinToolProvider)
  252. .filter(
  253. BuiltinToolProvider.tenant_id == tenant_id,
  254. BuiltinToolProvider.provider == full_provider_name,
  255. )
  256. .first()
  257. )
  258. else:
  259. provider_obj = (
  260. db.session.query(BuiltinToolProvider)
  261. .filter(
  262. BuiltinToolProvider.tenant_id == tenant_id,
  263. (BuiltinToolProvider.provider == provider_name)
  264. | (BuiltinToolProvider.provider == full_provider_name),
  265. )
  266. .first()
  267. )
  268. if provider_obj is None:
  269. return None
  270. provider_obj.provider = GenericProviderID(provider_obj.provider).to_string()
  271. return provider_obj
  272. except Exception:
  273. # it's an old provider without organization
  274. return (
  275. db.session.query(BuiltinToolProvider)
  276. .filter(
  277. BuiltinToolProvider.tenant_id == tenant_id,
  278. (BuiltinToolProvider.provider == provider_name),
  279. )
  280. .first()
  281. )