builtin_tools_manage_service.py 13 KB

(line numbers 1–333)
  1. import json
  2. import logging
  3. from pathlib import Path
  4. from sqlalchemy.orm import Session
  5. from configs import dify_config
  6. from core.helper.position_helper import is_filtered
  7. from core.model_runtime.utils.encoders import jsonable_encoder
  8. from core.plugin.entities.plugin import GenericProviderID
  9. from core.plugin.manager.exc import PluginDaemonClientSideError
  10. from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
  11. from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
  12. from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
  13. from core.tools.tool_label_manager import ToolLabelManager
  14. from core.tools.tool_manager import ToolManager
  15. from core.tools.utils.configuration import ProviderConfigEncrypter
  16. from extensions.ext_database import db
  17. from models.tools import BuiltinToolProvider
  18. from services.tools.tools_transform_service import ToolTransformService
  19. logger = logging.getLogger(__name__)
  20. class BuiltinToolManageService:
  21. @staticmethod
  22. def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
  23. """
  24. list builtin tool provider tools
  25. :param user_id: the id of the user
  26. :param tenant_id: the id of the tenant
  27. :param provider: the name of the provider
  28. :return: the list of tools
  29. """
  30. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  31. tools = provider_controller.get_tools()
  32. tool_provider_configurations = ProviderConfigEncrypter(
  33. tenant_id=tenant_id,
  34. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  35. provider_type=provider_controller.provider_type.value,
  36. provider_identity=provider_controller.entity.identity.name,
  37. )
  38. # check if user has added the provider
  39. builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  40. credentials = {}
  41. if builtin_provider is not None:
  42. # get credentials
  43. credentials = builtin_provider.credentials
  44. credentials = tool_provider_configurations.decrypt(credentials)
  45. result: list[ToolApiEntity] = []
  46. for tool in tools or []:
  47. result.append(
  48. ToolTransformService.convert_tool_entity_to_api_entity(
  49. tool=tool,
  50. credentials=credentials,
  51. tenant_id=tenant_id,
  52. labels=ToolLabelManager.get_tool_labels(provider_controller),
  53. )
  54. )
  55. return result
  56. @staticmethod
  57. def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
  58. """
  59. get builtin tool provider info
  60. """
  61. provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
  62. tool_provider_configurations = ProviderConfigEncrypter(
  63. tenant_id=tenant_id,
  64. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  65. provider_type=provider_controller.provider_type.value,
  66. provider_identity=provider_controller.entity.identity.name,
  67. )
  68. # check if user has added the provider
  69. builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
  70. credentials = {}
  71. if builtin_provider is not None:
  72. # get credentials
  73. credentials = builtin_provider.credentials
  74. credentials = tool_provider_configurations.decrypt(credentials)
  75. entity = ToolTransformService.builtin_provider_to_user_provider(
  76. provider_controller=provider_controller,
  77. db_provider=builtin_provider,
  78. decrypt_credentials=True,
  79. )
  80. entity.original_credentials = {}
  81. return entity
  82. @staticmethod
  83. def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
  84. """
  85. list builtin provider credentials schema
  86. :param provider_name: the name of the provider
  87. :param tenant_id: the id of the tenant
  88. :return: the list of tool providers
  89. """
  90. provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
  91. return jsonable_encoder(provider.get_credentials_schema())
  92. @staticmethod
  93. def update_builtin_tool_provider(
  94. session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
  95. ):
  96. """
  97. update builtin tool provider
  98. """
  99. # get if the provider exists
  100. provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  101. try:
  102. # get provider
  103. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  104. if not provider_controller.need_credentials:
  105. raise ValueError(f"provider {provider_name} does not need credentials")
  106. tool_configuration = ProviderConfigEncrypter(
  107. tenant_id=tenant_id,
  108. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  109. provider_type=provider_controller.provider_type.value,
  110. provider_identity=provider_controller.entity.identity.name,
  111. )
  112. # get original credentials if exists
  113. if provider is not None:
  114. original_credentials = tool_configuration.decrypt(provider.credentials)
  115. masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
  116. # check if the credential has changed, save the original credential
  117. for name, value in credentials.items():
  118. if name in masked_credentials and value == masked_credentials[name]:
  119. credentials[name] = original_credentials[name]
  120. # validate credentials
  121. provider_controller.validate_credentials(user_id, credentials)
  122. # encrypt credentials
  123. credentials = tool_configuration.encrypt(credentials)
  124. except (
  125. PluginDaemonClientSideError,
  126. ToolProviderNotFoundError,
  127. ToolNotFoundError,
  128. ToolProviderCredentialValidationError,
  129. ) as e:
  130. raise ValueError(str(e))
  131. if provider is None:
  132. # create provider
  133. provider = BuiltinToolProvider(
  134. tenant_id=tenant_id,
  135. user_id=user_id,
  136. provider=provider_name,
  137. encrypted_credentials=json.dumps(credentials),
  138. )
  139. session.add(provider)
  140. else:
  141. provider.encrypted_credentials = json.dumps(credentials)
  142. # delete cache
  143. tool_configuration.delete_tool_credentials_cache()
  144. return {"result": "success"}
  145. @staticmethod
  146. def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
  147. """
  148. get builtin tool provider credentials
  149. """
  150. provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  151. if provider_obj is None:
  152. return {}
  153. provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
  154. tool_configuration = ProviderConfigEncrypter(
  155. tenant_id=tenant_id,
  156. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  157. provider_type=provider_controller.provider_type.value,
  158. provider_identity=provider_controller.entity.identity.name,
  159. )
  160. credentials = tool_configuration.decrypt(provider_obj.credentials)
  161. credentials = tool_configuration.mask_tool_credentials(credentials)
  162. return credentials
  163. @staticmethod
  164. def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
  165. """
  166. delete tool provider
  167. """
  168. provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
  169. if provider_obj is None:
  170. raise ValueError(f"you have not added provider {provider_name}")
  171. db.session.delete(provider_obj)
  172. db.session.commit()
  173. # delete cache
  174. provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
  175. tool_configuration = ProviderConfigEncrypter(
  176. tenant_id=tenant_id,
  177. config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
  178. provider_type=provider_controller.provider_type.value,
  179. provider_identity=provider_controller.entity.identity.name,
  180. )
  181. tool_configuration.delete_tool_credentials_cache()
  182. return {"result": "success"}
  183. @staticmethod
  184. def get_builtin_tool_provider_icon(provider: str):
  185. """
  186. get tool provider icon and it's mimetype
  187. """
  188. icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider)
  189. icon_bytes = Path(icon_path).read_bytes()
  190. return icon_bytes, mime_type
  191. @staticmethod
  192. def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
  193. """
  194. list builtin tools
  195. """
  196. # get all builtin providers
  197. provider_controllers = ToolManager.list_builtin_providers(tenant_id)
  198. # get all user added providers
  199. db_providers: list[BuiltinToolProvider] = (
  200. db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
  201. )
  202. # rewrite db_providers
  203. for db_provider in db_providers:
  204. try:
  205. GenericProviderID(db_provider.provider)
  206. except Exception:
  207. db_provider.provider = f"langgenius/{db_provider.provider}/{db_provider.provider}"
  208. # find provider
  209. def find_provider(provider):
  210. return next(filter(lambda db_provider: db_provider.provider == provider, db_providers), None)
  211. result: list[ToolProviderApiEntity] = []
  212. for provider_controller in provider_controllers:
  213. try:
  214. # handle include, exclude
  215. if is_filtered(
  216. include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
  217. exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
  218. data=provider_controller,
  219. name_func=lambda x: x.identity.name,
  220. ):
  221. continue
  222. # convert provider controller to user provider
  223. user_builtin_provider = ToolTransformService.builtin_provider_to_user_provider(
  224. provider_controller=provider_controller,
  225. db_provider=find_provider(provider_controller.entity.identity.name),
  226. decrypt_credentials=True,
  227. )
  228. # add icon
  229. ToolTransformService.repack_provider(tenant_id=tenant_id, provider=user_builtin_provider)
  230. tools = provider_controller.get_tools()
  231. for tool in tools or []:
  232. user_builtin_provider.tools.append(
  233. ToolTransformService.convert_tool_entity_to_api_entity(
  234. tenant_id=tenant_id,
  235. tool=tool,
  236. credentials=user_builtin_provider.original_credentials,
  237. labels=ToolLabelManager.get_tool_labels(provider_controller),
  238. )
  239. )
  240. result.append(user_builtin_provider)
  241. except Exception as e:
  242. raise e
  243. return BuiltinToolProviderSort.sort(result)
  244. @staticmethod
  245. def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
  246. try:
  247. full_provider_name = provider_name
  248. provider_id_entity = GenericProviderID(provider_name)
  249. provider_name = provider_id_entity.provider_name
  250. if provider_id_entity.organization != "langgenius":
  251. provider_obj = (
  252. db.session.query(BuiltinToolProvider)
  253. .filter(
  254. BuiltinToolProvider.tenant_id == tenant_id,
  255. BuiltinToolProvider.provider == full_provider_name,
  256. )
  257. .first()
  258. )
  259. else:
  260. provider_obj = (
  261. db.session.query(BuiltinToolProvider)
  262. .filter(
  263. BuiltinToolProvider.tenant_id == tenant_id,
  264. (BuiltinToolProvider.provider == provider_name)
  265. | (BuiltinToolProvider.provider == full_provider_name),
  266. )
  267. .first()
  268. )
  269. if provider_obj is None:
  270. return None
  271. provider_obj.provider = GenericProviderID(provider_obj.provider).to_string()
  272. return provider_obj
  273. except Exception:
  274. # it's an old provider without organization
  275. return (
  276. db.session.query(BuiltinToolProvider)
  277. .filter(
  278. BuiltinToolProvider.tenant_id == tenant_id,
  279. (BuiltinToolProvider.provider == provider_name),
  280. )
  281. .first()
  282. )