configuration.py 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from typing import Any
  2. from pydantic import BaseModel
  3. from core.helper import encrypter
  4. from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
  5. from core.tools.entities.tool_entities import ToolProviderCredentials
  6. from core.tools.provider.tool_provider import ToolProviderController
  7. class ToolConfiguration(BaseModel):
  8. tenant_id: str
  9. provider_controller: ToolProviderController
  10. def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
  11. """
  12. deep copy credentials
  13. """
  14. return {key: value for key, value in credentials.items()}
  15. def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  16. """
  17. encrypt tool credentials with tenant id
  18. return a deep copy of credentials with encrypted values
  19. """
  20. credentials = self._deep_copy(credentials)
  21. # get fields need to be decrypted
  22. fields = self.provider_controller.get_credentials_schema()
  23. for field_name, field in fields.items():
  24. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  25. if field_name in credentials:
  26. encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
  27. credentials[field_name] = encrypted
  28. return credentials
  29. def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
  30. """
  31. mask tool credentials
  32. return a deep copy of credentials with masked values
  33. """
  34. credentials = self._deep_copy(credentials)
  35. # get fields need to be decrypted
  36. fields = self.provider_controller.get_credentials_schema()
  37. for field_name, field in fields.items():
  38. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  39. if field_name in credentials:
  40. if len(credentials[field_name]) > 6:
  41. credentials[field_name] = \
  42. credentials[field_name][:2] + \
  43. '*' * (len(credentials[field_name]) - 4) +\
  44. credentials[field_name][-2:]
  45. else:
  46. credentials[field_name] = '*' * len(credentials[field_name])
  47. return credentials
  48. def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
  49. """
  50. decrypt tool credentials with tenant id
  51. return a deep copy of credentials with decrypted values
  52. """
  53. cache = ToolProviderCredentialsCache(
  54. tenant_id=self.tenant_id,
  55. identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
  56. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  57. )
  58. cached_credentials = cache.get()
  59. if cached_credentials:
  60. return cached_credentials
  61. credentials = self._deep_copy(credentials)
  62. # get fields need to be decrypted
  63. fields = self.provider_controller.get_credentials_schema()
  64. for field_name, field in fields.items():
  65. if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
  66. if field_name in credentials:
  67. try:
  68. credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
  69. except:
  70. pass
  71. cache.set(credentials)
  72. return credentials
  73. def delete_tool_credentials_cache(self):
  74. cache = ToolProviderCredentialsCache(
  75. tenant_id=self.tenant_id,
  76. identity_id=f'{self.provider_controller.app_type.value}.{self.provider_controller.identity.name}',
  77. cache_type=ToolProviderCredentialsCacheType.PROVIDER
  78. )
  79. cache.delete()