configuration.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. from collections.abc import Mapping
  2. from copy import deepcopy
  3. from typing import Any
  4. from pydantic import BaseModel
  5. from core.entities.provider_entities import BasicProviderConfig
  6. from core.helper import encrypter
  7. from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
  8. from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
  9. from core.tools.entities.tool_entities import (
  10. ToolParameter,
  11. ToolProviderType,
  12. )
  13. from core.tools.tool.tool import Tool
  class ProviderConfigEncrypter(BaseModel):
      """
      Encrypt, decrypt and mask the secret fields of a provider's credentials.

      Only fields whose entry in ``config`` has type
      ``BasicProviderConfig.Type.SECRET_INPUT`` are touched; all other fields
      are passed through unchanged. The public methods operate on a deep copy
      of the input, so the caller's dict is never modified.
      """

      tenant_id: str
      config: Mapping[str, BasicProviderConfig]  # field name -> field definition
      provider_type: str
      provider_identity: str

      def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
          """
          deep copy data
          """
          return deepcopy(data)

      def _secret_field_names(self) -> list[str]:
          """Return, in ``config`` order, the names of fields typed SECRET_INPUT."""
          return [
              field_name
              for field_name, field in self.config.items()
              if field.type == BasicProviderConfig.Type.SECRET_INPUT
          ]

      @staticmethod
      def _mask_secret(value: str) -> str:
          """
          Mask a secret value.

          Values longer than six characters keep their first and last two
          characters with everything in between replaced by ``*``; shorter
          values are replaced entirely by ``*`` (same length as the input).
          """
          if len(value) > 6:
              return value[:2] + '*' * (len(value) - 4) + value[-2:]
          return '*' * len(value)

      def _credentials_cache(self) -> ToolProviderCredentialsCache:
          """Build the credentials-cache handle for this tenant / provider pair."""
          return ToolProviderCredentialsCache(
              tenant_id=self.tenant_id,
              identity_id=f'{self.provider_type}.{self.provider_identity}',
              cache_type=ToolProviderCredentialsCacheType.PROVIDER,
          )

      def encrypt(self, data: dict[str, str]) -> Mapping[str, str]:
          """
          encrypt tool credentials with tenant id

          :param data: credentials keyed by field name
          :return: a deep copy of ``data`` whose secret fields are replaced by
              ``encrypter.encrypt_token(tenant_id, value)``
          """
          data = self._deep_copy(data)
          # only secret fields that are actually present get encrypted
          for field_name in self._secret_field_names():
              if field_name in data:
                  data[field_name] = encrypter.encrypt_token(self.tenant_id, data[field_name])
          return data

      def mask_tool_credentials(self, data: dict[str, Any]) -> Mapping[str, Any]:
          """
          mask tool credentials

          :param data: credentials keyed by field name
          :return: a deep copy of ``data`` whose secret fields are masked.
              A secret field that is absent or ``None`` is left as-is rather
              than raising ``TypeError`` from ``len(None)``.
          """
          data = self._deep_copy(data)
          for field_name in self._secret_field_names():
              value = data.get(field_name)
              if value is None:
                  continue
              data[field_name] = self._mask_secret(value)
          return data

      def decrypt(self, data: dict[str, str]) -> Mapping[str, str]:
          """
          decrypt tool credentials with tenant id

          If the credentials cache holds a non-empty entry for this tenant /
          provider it is returned directly and ``data`` is ignored. Otherwise a
          deep copy of ``data`` has its secret fields decrypted, is stored in the
          cache, and is returned.

          :param data: credentials keyed by field name, secrets encrypted
          :return: credentials with decrypted secret values
          """
          cache = self._credentials_cache()
          cached_credentials = cache.get()
          if cached_credentials:
              return cached_credentials

          data = self._deep_copy(data)
          for field_name in self._secret_field_names():
              if field_name not in data:
                  continue
              try:
                  data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
              except Exception:
                  # Best-effort: a value that cannot be decrypted is kept unchanged.
                  # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
                  # still propagate.
                  pass

          cache.set(data)
          return data

      def delete_tool_credentials_cache(self):
          """Delete the cached (decrypted) credentials for this tenant / provider."""
          self._credentials_cache().delete()
  class ToolParameterConfigurationManager(BaseModel):
      """
      Tool parameter configuration manager

      Masks, encrypts and decrypts the secret form parameters of a tool
      (parameters whose form is FORM and whose type is SECRET_INPUT). The
      public methods that transform parameters operate on a deep copy of the
      input, so the caller's dict is never modified.
      """

      tenant_id: str
      tool_runtime: Tool
      provider_name: str
      provider_type: ToolProviderType
      identity_id: str

      def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
          """
          deep copy parameters
          """
          return deepcopy(parameters)

      @staticmethod
      def _is_secret_form_parameter(parameter: ToolParameter) -> bool:
          """Return True when ``parameter`` is a FORM parameter of type SECRET_INPUT."""
          return (
              parameter.form == ToolParameter.ToolParameterForm.FORM
              and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
          )

      @staticmethod
      def _mask_secret(value: str) -> str:
          """
          Mask a secret value.

          Values longer than six characters keep their first and last two
          characters with everything in between replaced by ``*``; shorter
          values are replaced entirely by ``*`` (same length as the input).
          """
          if len(value) > 6:
              return value[:2] + "*" * (len(value) - 4) + value[-2:]
          return "*" * len(value)

      def _parameters_cache(self) -> ToolParameterCache:
          """Build the parameter-cache handle for this tenant / provider / tool / identity."""
          return ToolParameterCache(
              tenant_id=self.tenant_id,
              provider=f'{self.provider_type.value}.{self.provider_name}',
              tool_name=self.tool_runtime.identity.name,
              cache_type=ToolParameterCacheType.PARAMETER,
              identity_id=self.identity_id,
          )

      def _merge_parameters(self) -> list[ToolParameter]:
          """
          merge parameters

          Start from the tool's declared parameters and overlay the runtime
          parameters: a runtime parameter replaces the first declared one with
          the same name and form; an unmatched runtime parameter is appended only
          if its form is FORM.
          """
          tool_parameters = self.tool_runtime.parameters or []
          runtime_parameters = self.tool_runtime.get_runtime_parameters() or []

          current_parameters = tool_parameters.copy()
          for runtime_parameter in runtime_parameters:
              for index, parameter in enumerate(current_parameters):
                  if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
                      current_parameters[index] = runtime_parameter
                      break
              else:
                  # no declared parameter matched (loop finished without break)
                  if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
                      current_parameters.append(runtime_parameter)
          return current_parameters

      def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
          """
          mask tool parameters

          :param parameters: parameter values keyed by parameter name
          :return: a deep copy of ``parameters`` with secret values masked.
              A secret that is absent or ``None`` is left as-is rather than
              raising ``TypeError`` from ``len(None)``.
          """
          parameters = self._deep_copy(parameters)
          for parameter in self._merge_parameters():
              if not self._is_secret_form_parameter(parameter):
                  continue
              value = parameters.get(parameter.name)
              if value is None:
                  continue
              parameters[parameter.name] = self._mask_secret(value)
          return parameters

      def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
          """
          encrypt tool parameters with tenant id

          :param parameters: parameter values keyed by parameter name
          :return: a deep copy of ``parameters`` whose secret values are replaced
              by ``encrypter.encrypt_token(tenant_id, value)``
          """
          current_parameters = self._merge_parameters()
          parameters = self._deep_copy(parameters)
          for parameter in current_parameters:
              if self._is_secret_form_parameter(parameter) and parameter.name in parameters:
                  parameters[parameter.name] = encrypter.encrypt_token(
                      self.tenant_id, parameters[parameter.name]
                  )
          return parameters

      def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
          """
          decrypt tool parameters with tenant id

          If the parameter cache holds a non-empty entry it is returned directly
          and ``parameters`` is ignored. Otherwise a deep copy of ``parameters``
          has its secret values decrypted; the copy is cached only when at least
          one secret parameter was present, and is returned.

          :param parameters: parameter values keyed by parameter name
          :return: parameters with decrypted secret values
          """
          cache = self._parameters_cache()
          cached_parameters = cache.get()
          if cached_parameters:
              return cached_parameters

          current_parameters = self._merge_parameters()
          # Work on a copy, as the docstring promises and the sibling methods do;
          # decrypting in place would write plaintext secrets into the caller's dict.
          parameters = self._deep_copy(parameters)
          has_secret_input = False

          for parameter in current_parameters:
              if not self._is_secret_form_parameter(parameter) or parameter.name not in parameters:
                  continue
              has_secret_input = True
              try:
                  parameters[parameter.name] = encrypter.decrypt_token(
                      self.tenant_id, parameters[parameter.name]
                  )
              except Exception:
                  # Best-effort: a value that cannot be decrypted is kept unchanged.
                  # Narrowed from a bare ``except`` so KeyboardInterrupt/SystemExit
                  # still propagate.
                  pass

          if has_secret_input:
              cache.set(parameters)

          return parameters

      def delete_tool_parameters_cache(self):
          """Delete the cached (decrypted) parameters for this tenant / provider / tool / identity."""
          self._parameters_cache().delete()
  203. cache.delete()