Browse Source

feat: inner api encrypt

Yeuoly 9 months ago
parent
commit
de01ca8d55

+ 6 - 1
api/controllers/inner_api/plugin/plugin.py

@@ -8,6 +8,7 @@ from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
 from controllers.inner_api.wraps import plugin_inner_api_only
 from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
 from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
+from core.plugin.encrypt import PluginEncrypter
 from core.plugin.entities.request import (
     RequestInvokeApp,
     RequestInvokeEncrypt,
@@ -139,7 +140,10 @@ class PluginInvokeEncryptApi(Resource):
     @get_tenant
     @plugin_data(payload_type=RequestInvokeEncrypt)
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeEncrypt):
-        """"""
+        """
+        encrypt or decrypt data
+        """
+        return PluginEncrypter.invoke_encrypt(tenant_model, payload)
 
 api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
 api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
@@ -150,3 +154,4 @@ api.add_resource(PluginInvokeModerationApi, '/invoke/moderation')
 api.add_resource(PluginInvokeToolApi, '/invoke/tool')
 api.add_resource(PluginInvokeNodeApi, '/invoke/node')
 api.add_resource(PluginInvokeAppApi, '/invoke/app')
+api.add_resource(PluginInvokeEncryptApi, '/invoke/encrypt')

+ 22 - 0
api/core/plugin/encrypt/__init__.py

@@ -0,0 +1,22 @@
+from collections.abc import Mapping
+from typing import Any
+
+from core.plugin.entities.request import RequestInvokeEncrypt
+from core.tools.utils.configuration import ProviderConfigEncrypter
+from models.account import Tenant
+
+
+class PluginEncrypter:
+    @classmethod
+    def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> Mapping[str, Any]:
+        encrypter = ProviderConfigEncrypter(
+            tenant_id=tenant.id,
+            config=payload.data,
+            provider_type=payload.type,
+            provider_identity=payload.identity,
+        )
+
+        if payload.opt == "encrypt":
+            return encrypter.encrypt(payload.data)
+        else:
+            return encrypter.decrypt(payload.data)

+ 3 - 1
api/core/plugin/entities/request.py

@@ -112,5 +112,7 @@ class RequestInvokeEncrypt(BaseModel):
     Request to encryption
     """
     opt: Literal["encrypt", "decrypt"]
+    type: Literal["endpoint"]
+    identity: str
     data: dict = Field(default_factory=dict)
-    config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)
+    config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)

+ 7 - 7
api/core/tools/tool_manager.py

@@ -24,7 +24,7 @@ from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
 from core.tools.tool.workflow_tool import WorkflowTool
 from core.tools.tool_label_manager import ToolLabelManager
-from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
+from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
 from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from core.workflow.nodes.tool.entities import ToolEntity
 from extensions.ext_database import db
@@ -116,14 +116,14 @@ class ToolManager:
             # decrypt the credentials
             credentials = builtin_provider.credentials
             controller = cls.get_builtin_provider(provider_id)
-            tool_configuration = ToolConfigurationManager(
+            tool_configuration = ProviderConfigEncrypter(
                 tenant_id=tenant_id, 
                 config=controller.get_credentials_schema(),
                 provider_type=controller.provider_type.value,
                 provider_identity=controller.identity.name
             )
 
-            decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
+            decrypted_credentials = tool_configuration.decrypt(credentials)
 
             return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
                 'tenant_id': tenant_id,
@@ -140,13 +140,13 @@ class ToolManager:
             api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
 
             # decrypt the credentials
-            tool_configuration = ToolConfigurationManager(
+            tool_configuration = ProviderConfigEncrypter(
                 tenant_id=tenant_id, 
                 config=api_provider.get_credentials_schema(),
                 provider_type=api_provider.provider_type.value,
                 provider_identity=api_provider.identity.name
             )
-            decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
+            decrypted_credentials = tool_configuration.decrypt(credentials)
 
             return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
                 'tenant_id': tenant_id,
@@ -523,14 +523,14 @@ class ToolManager:
             provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
         )
         # init tool configuration
-        tool_configuration = ToolConfigurationManager(
+        tool_configuration = ProviderConfigEncrypter(
             tenant_id=tenant_id,
             config=controller.get_credentials_schema(),
             provider_type=controller.provider_type.value,
             provider_identity=controller.identity.name
         )
 
-        decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
+        decrypted_credentials = tool_configuration.decrypt(credentials)
         masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
 
         try:

+ 26 - 26
api/core/tools/utils/configuration.py

@@ -15,60 +15,60 @@ from core.tools.entities.tool_entities import (
 from core.tools.tool.tool import Tool
 
 
-class ToolConfigurationManager(BaseModel):
+class ProviderConfigEncrypter(BaseModel):
     tenant_id: str
     config: Mapping[str, BasicProviderConfig]
     provider_type: str
     provider_identity: str
 
-    def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
+    def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
         """
-        deep copy credentials
+        deep copy data
         """
-        return deepcopy(credentials)
+        return deepcopy(data)
 
-    def encrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
+    def encrypt(self, data: dict[str, str]) -> Mapping[str, str]:
         """
         encrypt tool credentials with tenant id
 
         return a deep copy of credentials with encrypted values
         """
-        credentials = self._deep_copy(credentials)
+        data = self._deep_copy(data)
 
         # get fields need to be decrypted
         fields = self.config
         for field_name, field in fields.items():
             if field.type == BasicProviderConfig.Type.SECRET_INPUT:
-                if field_name in credentials:
-                    encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
-                    credentials[field_name] = encrypted
+                if field_name in data:
+                    encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name])
+                    data[field_name] = encrypted
 
-        return credentials
+        return data
 
-    def mask_tool_credentials(self, credentials: dict[str, Any]) -> dict[str, Any]:
+    def mask_tool_credentials(self, data: dict[str, Any]) -> Mapping[str, Any]:
         """
         mask tool credentials
 
         return a deep copy of credentials with masked values
         """
-        credentials = self._deep_copy(credentials)
+        data = self._deep_copy(data)
 
         # get fields need to be decrypted
         fields = self.config
         for field_name, field in fields.items():
             if field.type == BasicProviderConfig.Type.SECRET_INPUT:
-                if field_name in credentials:
-                    if len(credentials[field_name]) > 6:
-                        credentials[field_name] = \
-                            credentials[field_name][:2] + \
-                            '*' * (len(credentials[field_name]) - 4) + \
-                            credentials[field_name][-2:]
+                if field_name in data:
+                    if len(data[field_name]) > 6:
+                        data[field_name] = \
+                            data[field_name][:2] + \
+                            '*' * (len(data[field_name]) - 4) + \
+                            data[field_name][-2:]
                     else:
-                        credentials[field_name] = '*' * len(credentials[field_name])
+                        data[field_name] = '*' * len(data[field_name])
 
-        return credentials
+        return data
 
-    def decrypt_tool_credentials(self, credentials: dict[str, str]) -> dict[str, str]:
+    def decrypt(self, data: dict[str, str]) -> Mapping[str, str]:
         """
         decrypt tool credentials with tenant id
 
@@ -82,19 +82,19 @@ class ToolConfigurationManager(BaseModel):
         cached_credentials = cache.get()
         if cached_credentials:
             return cached_credentials
-        credentials = self._deep_copy(credentials)
+        data = self._deep_copy(data)
         # get fields need to be decrypted
         fields = self.config
         for field_name, field in fields.items():
             if field.type == BasicProviderConfig.Type.SECRET_INPUT:
-                if field_name in credentials:
+                if field_name in data:
                     try:
-                        credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
+                        data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
                     except:
                         pass
 
-        cache.set(credentials)
-        return credentials
+        cache.set(data)
+        return data
 
     def delete_tool_credentials_cache(self):
         cache = ToolProviderCredentialsCache(

+ 8 - 8
api/services/tools/api_tools_manage_service.py

@@ -15,7 +15,7 @@ from core.tools.entities.tool_entities import (
 from core.tools.provider.api_tool_provider import ApiToolProviderController
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.tool_manager import ToolManager
-from core.tools.utils.configuration import ToolConfigurationManager
+from core.tools.utils.configuration import ProviderConfigEncrypter
 from core.tools.utils.parser import ApiBasedToolSchemaParser
 from extensions.ext_database import db
 from models.tools import ApiToolProvider
@@ -156,14 +156,14 @@ class ApiToolManageService:
         provider_controller.load_bundled_tools(tool_bundles)
 
         # encrypt credentials
-        tool_configuration = ToolConfigurationManager(
+        tool_configuration = ProviderConfigEncrypter(
             tenant_id=tenant_id,
             config=provider_controller.get_credentials_schema(),
             provider_type=provider_controller.provider_type.value,
             provider_identity=provider_controller.identity.name
         )
 
-        encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
+        encrypted_credentials = tool_configuration.encrypt(credentials)
         db_provider.credentials_str = json.dumps(encrypted_credentials)
 
         db.session.add(db_provider)
@@ -286,21 +286,21 @@ class ApiToolManageService:
         provider_controller.load_bundled_tools(tool_bundles)
 
         # get original credentials if exists
-        tool_configuration = ToolConfigurationManager(
+        tool_configuration = ProviderConfigEncrypter(
             tenant_id=tenant_id,
             config=provider_controller.get_credentials_schema(),
             provider_type=provider_controller.provider_type.value,
             provider_identity=provider_controller.identity.name
         )
 
-        original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
+        original_credentials = tool_configuration.decrypt(provider.credentials)
         masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
         # check if the credential has changed, save the original credential
         for name, value in credentials.items():
             if name in masked_credentials and value == masked_credentials[name]:
                 credentials[name] = original_credentials[name]
 
-        credentials = tool_configuration.encrypt_tool_credentials(credentials)
+        credentials = tool_configuration.encrypt(credentials)
         provider.credentials_str = json.dumps(credentials)
 
         db.session.add(provider)
@@ -405,13 +405,13 @@ class ApiToolManageService:
 
         # decrypt credentials
         if db_provider.id:
-            tool_configuration = ToolConfigurationManager(
+            tool_configuration = ProviderConfigEncrypter(
                 tenant_id=tenant_id,
                 config=provider_controller.get_credentials_schema(),
                 provider_type=provider_controller.provider_type.value,
                 provider_identity=provider_controller.identity.name
             )
-            decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
+            decrypted_credentials = tool_configuration.decrypt(credentials)
             # check if the credential has changed, save the original credential
             masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
             for name, value in credentials.items():

+ 9 - 9
api/services/tools/builtin_tools_manage_service.py

@@ -10,7 +10,7 @@ from core.tools.provider.builtin._positions import BuiltinToolProviderSort
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.tool_manager import ToolManager
-from core.tools.utils.configuration import ToolConfigurationManager
+from core.tools.utils.configuration import ProviderConfigEncrypter
 from extensions.ext_database import db
 from models.tools import BuiltinToolProvider
 from services.tools.tools_transform_service import ToolTransformService
@@ -27,7 +27,7 @@ class BuiltinToolManageService:
         provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
         tools = provider_controller.get_tools()
 
-        tool_provider_configurations = ToolConfigurationManager(
+        tool_provider_configurations = ProviderConfigEncrypter(
             tenant_id=tenant_id,
             config=provider_controller.get_credentials_schema(),
             provider_type=provider_controller.provider_type.value,
@@ -47,7 +47,7 @@ class BuiltinToolManageService:
         if builtin_provider is not None:
             # get credentials
             credentials = builtin_provider.credentials
-            credentials = tool_provider_configurations.decrypt_tool_credentials(credentials)
+            credentials = tool_provider_configurations.decrypt(credentials)
 
         result = []
         for tool in tools:
@@ -92,7 +92,7 @@ class BuiltinToolManageService:
             provider_controller = ToolManager.get_builtin_provider(provider_name)
             if not provider_controller.need_credentials:
                 raise ValueError(f"provider {provider_name} does not need credentials")
-            tool_configuration = ToolConfigurationManager(
+            tool_configuration = ProviderConfigEncrypter(
                 tenant_id=tenant_id,
                 config=provider_controller.get_credentials_schema(),
                 provider_type=provider_controller.provider_type.value,
@@ -101,7 +101,7 @@ class BuiltinToolManageService:
 
             # get original credentials if exists
             if provider is not None:
-                original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
+                original_credentials = tool_configuration.decrypt(provider.credentials)
                 masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
                 # check if the credential has changed, save the original credential
                 for name, value in credentials.items():
@@ -110,7 +110,7 @@ class BuiltinToolManageService:
             # validate credentials
             provider_controller.validate_credentials(credentials)
             # encrypt credentials
-            credentials = tool_configuration.encrypt_tool_credentials(credentials)
+            credentials = tool_configuration.encrypt(credentials)
         except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
             raise ValueError(str(e))
 
@@ -154,13 +154,13 @@ class BuiltinToolManageService:
             return {}
 
         provider_controller = ToolManager.get_builtin_provider(provider_obj.provider)
-        tool_configuration = ToolConfigurationManager(
+        tool_configuration = ProviderConfigEncrypter(
             tenant_id=tenant_id,
             config=provider_controller.get_credentials_schema(),
             provider_type=provider_controller.provider_type.value,
             provider_identity=provider_controller.identity.name,
         )
-        credentials = tool_configuration.decrypt_tool_credentials(provider_obj.credentials)
+        credentials = tool_configuration.decrypt(provider_obj.credentials)
         credentials = tool_configuration.mask_tool_credentials(credentials)
         return credentials
 
@@ -186,7 +186,7 @@ class BuiltinToolManageService:
 
         # delete cache
         provider_controller = ToolManager.get_builtin_provider(provider_name)
-        tool_configuration = ToolConfigurationManager(
+        tool_configuration = ProviderConfigEncrypter(
             tenant_id=tenant_id,
             config=provider_controller.get_credentials_schema(),
             provider_type=provider_controller.provider_type.value,

+ 7 - 7
api/services/tools/tools_transform_service.py

@@ -16,7 +16,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 from core.tools.provider.workflow_tool_provider import WorkflowToolProviderController
 from core.tools.tool.tool import Tool
 from core.tools.tool.workflow_tool import WorkflowTool
-from core.tools.utils.configuration import ToolConfigurationManager
+from core.tools.utils.configuration import ProviderConfigEncrypter
 from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
 
 logger = logging.getLogger(__name__)
@@ -107,15 +107,15 @@ class ToolTransformService:
                 credentials = db_provider.credentials
 
                 # init tool configuration
-                tool_configuration = ToolConfigurationManager(
+                tool_configuration = ProviderConfigEncrypter(
                     tenant_id=db_provider.tenant_id,
                     config=provider_controller.get_credentials_schema(),
                     provider_type=provider_controller.provider_type.value,
                     provider_identity=provider_controller.identity.name
                 )
                 # decrypt the credentials and mask the credentials
-                decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
-                masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
+                decrypted_credentials = tool_configuration.decrypt(data=credentials)
+                masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
 
                 result.masked_credentials = masked_credentials
                 result.original_credentials = decrypted_credentials
@@ -218,7 +218,7 @@ class ToolTransformService:
 
         if decrypt_credentials:
             # init tool configuration
-            tool_configuration = ToolConfigurationManager(
+            tool_configuration = ProviderConfigEncrypter(
                 tenant_id=db_provider.tenant_id,
                 config=provider_controller.get_credentials_schema(),
                 provider_type=provider_controller.provider_type.value,
@@ -226,8 +226,8 @@ class ToolTransformService:
             )
 
             # decrypt the credentials and mask the credentials
-            decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
-            masked_credentials = tool_configuration.mask_tool_credentials(credentials=decrypted_credentials)
+            decrypted_credentials = tool_configuration.decrypt(data=credentials)
+            masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
 
             result.masked_credentials = masked_credentials