configuration.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233
import logging
from collections.abc import Mapping
from copy import deepcopy
from typing import Any

from pydantic import BaseModel

from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.entities.tool_entities import (
    ToolParameter,
    ToolProviderType,
)
from core.tools.tool.tool import Tool

  14. class ToolConfigurationManager(BaseModel):
  15. tenant_id: str
  16. config: Mapping[str, BasicProviderConfig]
  17. provider_type: str
  18. provider_identity: str
  19. def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
  20. """
  21. deep copy credentials
  22. """
  23. return deepcopy(credentials)
  24. def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  25. """
  26. encrypt tool credentials with tenant id
  27. return a deep copy of credentials with encrypted values
  28. """
  29. credentials = self._deep_copy(credentials)
  30. # get fields need to be decrypted
  31. fields = self.config
  32. for field_name, field in fields.items():
  33. if field.type == BasicProviderConfig.Type.SECRET_INPUT:
  34. if field_name in credentials:
  35. encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
  36. credentials[field_name] = encrypted
  37. return credentials
  38. def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
  39. """
  40. mask tool credentials
  41. return a deep copy of credentials with masked values
  42. """
  43. credentials = self._deep_copy(credentials)
  44. # get fields need to be decrypted
  45. fields = self.config
  46. for field_name, field in fields.items():
  47. if field.type == BasicProviderConfig.Type.SECRET_INPUT:
  48. if field_name in credentials:
  49. if len(credentials[field_name]) > 6:
  50. credentials[field_name] = \
  51. credentials[field_name][:2] + \
  52. '*' * (len(credentials[field_name]) - 4) + \
  53. credentials[field_name][-2:]
  54. else:
  55. credentials[field_name] = '*' * len(credentials[field_name])
  56. return credentials
  57. def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  58. """
  59. decrypt tool credentials with tenant id
  60. return a deep copy of credentials with decrypted values
  61. """
  62. cache = ToolProviderCredentialsCache(
  63. tenant_id=self.tenant_id,
  64. identity_id=f'{self.provider_type}.{self.provider_identity}',
  65. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  66. )
  67. cached_credentials = cache.get()
  68. if cached_credentials:
  69. return cached_credentials
  70. credentials = self._deep_copy(credentials)
  71. # get fields need to be decrypted
  72. fields = self.config
  73. for field_name, field in fields.items():
  74. if field.type == BasicProviderConfig.Type.SECRET_INPUT:
  75. if field_name in credentials:
  76. try:
  77. credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
  78. except:
  79. pass
  80. cache.set(credentials)
  81. return credentials
  82. def delete_tool_credentials_cache(self):
  83. cache = ToolProviderCredentialsCache(
  84. tenant_id=self.tenant_id,
  85. identity_id=f'{self.provider_type}.{self.provider_identity}',
  86. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  87. )
  88. cache.delete()
  89. class ToolParameterConfigurationManager(BaseModel):
  90. """
  91. Tool parameter configuration manager
  92. """
  93. tenant_id: str
  94. tool_runtime: Tool
  95. provider_name: str
  96. provider_type: ToolProviderType
  97. identity_id: str
  98. def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
  99. """
  100. deep copy parameters
  101. """
  102. return deepcopy(parameters)
  103. def _merge_parameters(self) -> list[ToolParameter]:
  104. """
  105. merge parameters
  106. """
  107. # get tool parameters
  108. tool_parameters = self.tool_runtime.parameters or []
  109. # get tool runtime parameters
  110. runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
  111. # override parameters
  112. current_parameters = tool_parameters.copy()
  113. for runtime_parameter in runtime_parameters:
  114. found = False
  115. for index, parameter in enumerate(current_parameters):
  116. if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
  117. current_parameters[index] = runtime_parameter
  118. found = True
  119. break
  120. if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  121. current_parameters.append(runtime_parameter)
  122. return current_parameters
  123. def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  124. """
  125. mask tool parameters
  126. return a deep copy of parameters with masked values
  127. """
  128. parameters = self._deep_copy(parameters)
  129. # override parameters
  130. current_parameters = self._merge_parameters()
  131. for parameter in current_parameters:
  132. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  133. if parameter.name in parameters:
  134. if len(parameters[parameter.name]) > 6:
  135. parameters[parameter.name] = \
  136. parameters[parameter.name][:2] + \
  137. '*' * (len(parameters[parameter.name]) - 4) + \
  138. parameters[parameter.name][-2:]
  139. else:
  140. parameters[parameter.name] = '*' * len(parameters[parameter.name])
  141. return parameters
  142. def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  143. """
  144. encrypt tool parameters with tenant id
  145. return a deep copy of parameters with encrypted values
  146. """
  147. # override parameters
  148. current_parameters = self._merge_parameters()
  149. parameters = self._deep_copy(parameters)
  150. for parameter in current_parameters:
  151. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  152. if parameter.name in parameters:
  153. encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
  154. parameters[parameter.name] = encrypted
  155. return parameters
  156. def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  157. """
  158. decrypt tool parameters with tenant id
  159. return a deep copy of parameters with decrypted values
  160. """
  161. cache = ToolParameterCache(
  162. tenant_id=self.tenant_id,
  163. provider=f'{self.provider_type.value}.{self.provider_name}',
  164. tool_name=self.tool_runtime.identity.name,
  165. cache_type=ToolParameterCacheType.PARAMETER,
  166. identity_id=self.identity_id
  167. )
  168. cached_parameters = cache.get()
  169. if cached_parameters:
  170. return cached_parameters
  171. # override parameters
  172. current_parameters = self._merge_parameters()
  173. has_secret_input = False
  174. for parameter in current_parameters:
  175. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  176. if parameter.name in parameters:
  177. try:
  178. has_secret_input = True
  179. parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
  180. except:
  181. pass
  182. if has_secret_input:
  183. cache.set(parameters)
  184. return parameters
  185. def delete_tool_parameters_cache(self):
  186. cache = ToolParameterCache(
  187. tenant_id=self.tenant_id,
  188. provider=f'{self.provider_type.value}.{self.provider_name}',
  189. tool_name=self.tool_runtime.identity.name,
  190. cache_type=ToolParameterCacheType.PARAMETER,
  191. identity_id=self.identity_id
  192. )
  193. cache.delete()