configuration.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289
  1. import os
  2. from typing import Any, Union
  3. from pydantic import BaseModel
  4. from yaml import FullLoader, load
  5. from core.helper import encrypter
  6. from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
  7. from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
  8. from core.tools.entities.tool_entities import (
  9. ModelToolConfiguration,
  10. ModelToolProviderConfiguration,
  11. ToolParameter,
  12. ToolProviderCredentials,
  13. )
  14. from core.tools.provider.tool_provider import ToolProviderController
  15. from core.tools.tool.tool import Tool
  16. class ToolConfigurationManager(BaseModel):
  17. tenant_id: str
  18. provider_controller: ToolProviderController
  19. def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
  20. """
  21. deep copy credentials
  22. """
  23. return {key: value for key, value in credentials.items()}
  24. def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  25. """
  26. encrypt tool credentials with tenant id
  27. return a deep copy of credentials with encrypted values
  28. """
  29. credentials = self._deep_copy(credentials)
  30. # get fields need to be decrypted
  31. fields = self.provider_controller.get_credentials_schema()
  32. for field_name, field in fields.items():
  33. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  34. if field_name in credentials:
  35. encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
  36. credentials[field_name] = encrypted
  37. return credentials
  38. def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
  39. """
  40. mask tool credentials
  41. return a deep copy of credentials with masked values
  42. """
  43. credentials = self._deep_copy(credentials)
  44. # get fields need to be decrypted
  45. fields = self.provider_controller.get_credentials_schema()
  46. for field_name, field in fields.items():
  47. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  48. if field_name in credentials:
  49. if len(credentials[field_name]) > 6:
  50. credentials[field_name] = \
  51. credentials[field_name][:2] + \
  52. '*' * (len(credentials[field_name]) - 4) +\
  53. credentials[field_name][-2:]
  54. else:
  55. credentials[field_name] = '*' * len(credentials[field_name])
  56. return credentials
  57. def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  58. """
  59. decrypt tool credentials with tenant id
  60. return a deep copy of credentials with decrypted values
  61. """
  62. cache = ToolProviderCredentialsCache(
  63. tenant_id=self.tenant_id,
  64. identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
  65. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  66. )
  67. cached_credentials = cache.get()
  68. if cached_credentials:
  69. return cached_credentials
  70. credentials = self._deep_copy(credentials)
  71. # get fields need to be decrypted
  72. fields = self.provider_controller.get_credentials_schema()
  73. for field_name, field in fields.items():
  74. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  75. if field_name in credentials:
  76. try:
  77. credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
  78. except:
  79. pass
  80. cache.set(credentials)
  81. return credentials
  82. def delete_tool_credentials_cache(self):
  83. cache = ToolProviderCredentialsCache(
  84. tenant_id=self.tenant_id,
  85. identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
  86. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  87. )
  88. cache.delete()
  89. class ToolParameterConfigurationManager(BaseModel):
  90. """
  91. Tool parameter configuration manager
  92. """
  93. tenant_id: str
  94. tool_runtime: Tool
  95. provider_name: str
  96. provider_type: str
  97. def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
  98. """
  99. deep copy parameters
  100. """
  101. return {key: value for key, value in parameters.items()}
  102. def _merge_parameters(self) -> list[ToolParameter]:
  103. """
  104. merge parameters
  105. """
  106. # get tool parameters
  107. tool_parameters = self.tool_runtime.parameters or []
  108. # get tool runtime parameters
  109. runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
  110. # override parameters
  111. current_parameters = tool_parameters.copy()
  112. for runtime_parameter in runtime_parameters:
  113. found = False
  114. for index, parameter in enumerate(current_parameters):
  115. if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
  116. current_parameters[index] = runtime_parameter
  117. found = True
  118. break
  119. if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  120. current_parameters.append(runtime_parameter)
  121. return current_parameters
  122. def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  123. """
  124. mask tool parameters
  125. return a deep copy of parameters with masked values
  126. """
  127. parameters = self._deep_copy(parameters)
  128. # override parameters
  129. current_parameters = self._merge_parameters()
  130. for parameter in current_parameters:
  131. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  132. if parameter.name in parameters:
  133. if len(parameters[parameter.name]) > 6:
  134. parameters[parameter.name] = \
  135. parameters[parameter.name][:2] + \
  136. '*' * (len(parameters[parameter.name]) - 4) +\
  137. parameters[parameter.name][-2:]
  138. else:
  139. parameters[parameter.name] = '*' * len(parameters[parameter.name])
  140. return parameters
  141. def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  142. """
  143. encrypt tool parameters with tenant id
  144. return a deep copy of parameters with encrypted values
  145. """
  146. # override parameters
  147. current_parameters = self._merge_parameters()
  148. for parameter in current_parameters:
  149. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  150. if parameter.name in parameters:
  151. encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
  152. parameters[parameter.name] = encrypted
  153. return parameters
  154. def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  155. """
  156. decrypt tool parameters with tenant id
  157. return a deep copy of parameters with decrypted values
  158. """
  159. cache = ToolParameterCache(
  160. tenant_id=self.tenant_id,
  161. provider=f'{self.provider_type}.{self.provider_name}',
  162. tool_name=self.tool_runtime.identity.name,
  163. cache_type=ToolParameterCacheType.PARAMETER
  164. )
  165. cached_parameters = cache.get()
  166. if cached_parameters:
  167. return cached_parameters
  168. # override parameters
  169. current_parameters = self._merge_parameters()
  170. has_secret_input = False
  171. for parameter in current_parameters:
  172. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  173. if parameter.name in parameters:
  174. try:
  175. has_secret_input = True
  176. parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
  177. except:
  178. pass
  179. if has_secret_input:
  180. cache.set(parameters)
  181. return parameters
  182. def delete_tool_parameters_cache(self):
  183. cache = ToolParameterCache(
  184. tenant_id=self.tenant_id,
  185. provider=f'{self.provider_type}.{self.provider_name}',
  186. tool_name=self.tool_runtime.identity.name,
  187. cache_type=ToolParameterCacheType.PARAMETER
  188. )
  189. cache.delete()
  190. class ModelToolConfigurationManager:
  191. """
  192. Model as tool configuration
  193. """
  194. _configurations: dict[str, ModelToolProviderConfiguration] = {}
  195. _model_configurations: dict[str, ModelToolConfiguration] = {}
  196. _inited = False
  197. @classmethod
  198. def _init_configuration(cls):
  199. """
  200. init configuration
  201. """
  202. absolute_path = os.path.abspath(os.path.dirname(__file__))
  203. model_tools_path = os.path.join(absolute_path, '..', 'model_tools')
  204. # get all .yaml file
  205. files = [f for f in os.listdir(model_tools_path) if f.endswith('.yaml')]
  206. for file in files:
  207. provider = file.split('.')[0]
  208. with open(os.path.join(model_tools_path, file), encoding='utf-8') as f:
  209. configurations = ModelToolProviderConfiguration(**load(f, Loader=FullLoader))
  210. models = configurations.models or []
  211. for model in models:
  212. model_key = f'{provider}.{model.model}'
  213. cls._model_configurations[model_key] = model
  214. cls._configurations[provider] = configurations
  215. cls._inited = True
  216. @classmethod
  217. def get_configuration(cls, provider: str) -> Union[ModelToolProviderConfiguration, None]:
  218. """
  219. get configuration by provider
  220. """
  221. if not cls._inited:
  222. cls._init_configuration()
  223. return cls._configurations.get(provider, None)
  224. @classmethod
  225. def get_all_configuration(cls) -> dict[str, ModelToolProviderConfiguration]:
  226. """
  227. get all configurations
  228. """
  229. if not cls._inited:
  230. cls._init_configuration()
  231. return cls._configurations
  232. @classmethod
  233. def get_model_configuration(cls, provider: str, model: str) -> Union[ModelToolConfiguration, None]:
  234. """
  235. get model configuration
  236. """
  237. key = f'{provider}.{model}'
  238. if not cls._inited:
  239. cls._init_configuration()
  240. return cls._model_configurations.get(key, None)