configuration.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. from copy import deepcopy
  2. from typing import Any
  3. from pydantic import BaseModel
  4. from core.helper import encrypter
  5. from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
  6. from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
  7. from core.tools.entities.tool_entities import (
  8. ToolParameter,
  9. ToolProviderCredentials,
  10. ToolProviderType,
  11. )
  12. from core.tools.provider.tool_provider import ToolProviderController
  13. from core.tools.tool.tool import Tool
  14. class ToolConfigurationManager(BaseModel):
  15. tenant_id: str
  16. provider_controller: ToolProviderController
  17. def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
  18. """
  19. deep copy credentials
  20. """
  21. return deepcopy(credentials)
  22. def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  23. """
  24. encrypt tool credentials with tenant id
  25. return a deep copy of credentials with encrypted values
  26. """
  27. credentials = self._deep_copy(credentials)
  28. # get fields need to be decrypted
  29. fields = self.provider_controller.get_credentials_schema()
  30. for field_name, field in fields.items():
  31. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  32. if field_name in credentials:
  33. encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
  34. credentials[field_name] = encrypted
  35. return credentials
  36. def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
  37. """
  38. mask tool credentials
  39. return a deep copy of credentials with masked values
  40. """
  41. credentials = self._deep_copy(credentials)
  42. # get fields need to be decrypted
  43. fields = self.provider_controller.get_credentials_schema()
  44. for field_name, field in fields.items():
  45. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  46. if field_name in credentials:
  47. if len(credentials[field_name]) > 6:
  48. credentials[field_name] = \
  49. credentials[field_name][:2] + \
  50. '*' * (len(credentials[field_name]) - 4) + \
  51. credentials[field_name][-2:]
  52. else:
  53. credentials[field_name] = '*' * len(credentials[field_name])
  54. return credentials
  55. def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  56. """
  57. decrypt tool credentials with tenant id
  58. return a deep copy of credentials with decrypted values
  59. """
  60. cache = ToolProviderCredentialsCache(
  61. tenant_id=self.tenant_id,
  62. identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
  63. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  64. )
  65. cached_credentials = cache.get()
  66. if cached_credentials:
  67. return cached_credentials
  68. credentials = self._deep_copy(credentials)
  69. # get fields need to be decrypted
  70. fields = self.provider_controller.get_credentials_schema()
  71. for field_name, field in fields.items():
  72. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  73. if field_name in credentials:
  74. try:
  75. credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
  76. except:
  77. pass
  78. cache.set(credentials)
  79. return credentials
  80. def delete_tool_credentials_cache(self):
  81. cache = ToolProviderCredentialsCache(
  82. tenant_id=self.tenant_id,
  83. identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
  84. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  85. )
  86. cache.delete()
  87. class ToolParameterConfigurationManager(BaseModel):
  88. """
  89. Tool parameter configuration manager
  90. """
  91. tenant_id: str
  92. tool_runtime: Tool
  93. provider_name: str
  94. provider_type: ToolProviderType
  95. identity_id: str
  96. def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
  97. """
  98. deep copy parameters
  99. """
  100. return deepcopy(parameters)
  101. def _merge_parameters(self) -> list[ToolParameter]:
  102. """
  103. merge parameters
  104. """
  105. # get tool parameters
  106. tool_parameters = self.tool_runtime.parameters or []
  107. # get tool runtime parameters
  108. runtime_parameters = self.tool_runtime.get_runtime_parameters() or []
  109. # override parameters
  110. current_parameters = tool_parameters.copy()
  111. for runtime_parameter in runtime_parameters:
  112. found = False
  113. for index, parameter in enumerate(current_parameters):
  114. if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
  115. current_parameters[index] = runtime_parameter
  116. found = True
  117. break
  118. if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
  119. current_parameters.append(runtime_parameter)
  120. return current_parameters
  121. def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  122. """
  123. mask tool parameters
  124. return a deep copy of parameters with masked values
  125. """
  126. parameters = self._deep_copy(parameters)
  127. # override parameters
  128. current_parameters = self._merge_parameters()
  129. for parameter in current_parameters:
  130. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  131. if parameter.name in parameters:
  132. if len(parameters[parameter.name]) > 6:
  133. parameters[parameter.name] = \
  134. parameters[parameter.name][:2] + \
  135. '*' * (len(parameters[parameter.name]) - 4) + \
  136. parameters[parameter.name][-2:]
  137. else:
  138. parameters[parameter.name] = '*' * len(parameters[parameter.name])
  139. return parameters
  140. def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  141. """
  142. encrypt tool parameters with tenant id
  143. return a deep copy of parameters with encrypted values
  144. """
  145. # override parameters
  146. current_parameters = self._merge_parameters()
  147. parameters = self._deep_copy(parameters)
  148. for parameter in current_parameters:
  149. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  150. if parameter.name in parameters:
  151. encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
  152. parameters[parameter.name] = encrypted
  153. return parameters
  154. def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
  155. """
  156. decrypt tool parameters with tenant id
  157. return a deep copy of parameters with decrypted values
  158. """
  159. cache = ToolParameterCache(
  160. tenant_id=self.tenant_id,
  161. provider=f'{self.provider_type.value}.{self.provider_name}',
  162. tool_name=self.tool_runtime.identity.name,
  163. cache_type=ToolParameterCacheType.PARAMETER,
  164. identity_id=self.identity_id
  165. )
  166. cached_parameters = cache.get()
  167. if cached_parameters:
  168. return cached_parameters
  169. # override parameters
  170. current_parameters = self._merge_parameters()
  171. has_secret_input = False
  172. for parameter in current_parameters:
  173. if parameter.form == ToolParameter.ToolParameterForm.FORM and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT:
  174. if parameter.name in parameters:
  175. try:
  176. has_secret_input = True
  177. parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
  178. except:
  179. pass
  180. if has_secret_input:
  181. cache.set(parameters)
  182. return parameters
  183. def delete_tool_parameters_cache(self):
  184. cache = ToolParameterCache(
  185. tenant_id=self.tenant_id,
  186. provider=f'{self.provider_type.value}.{self.provider_name}',
  187. tool_name=self.tool_runtime.identity.name,
  188. cache_type=ToolParameterCacheType.PARAMETER,
  189. identity_id=self.identity_id
  190. )
  191. cache.delete()