Selaa lähdekoodia

refactor: credentials schemas to array

Yeuoly 7 kuukautta sitten
vanhempi
commit
6dfc31a542

+ 3 - 0
api/core/entities/provider_entities.py

@@ -159,3 +159,6 @@ class ProviderConfig(BasicProviderConfig):
     help: Optional[I18nObject] = None
     url: Optional[str] = None
     placeholder: Optional[I18nObject] = None
+
+    def to_basic_provider_config(self) -> BasicProviderConfig:
+        return BasicProviderConfig(type=self.type, name=self.name)

+ 1 - 4
api/core/plugin/encrypt/__init__.py

@@ -1,6 +1,3 @@
-from collections.abc import Mapping
-from typing import Any
-
 from core.plugin.entities.request import RequestInvokeEncrypt
 from core.tools.utils.configuration import ProviderConfigEncrypter
 from models.account import Tenant
@@ -11,7 +8,7 @@ class PluginEncrypter:
     def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
         encrypter = ProviderConfigEncrypter(
             tenant_id=tenant.id,
-            config=payload.data,
+            config=payload.config,
             provider_type=payload.namespace,
             provider_identity=payload.identity,
         )

+ 1 - 2
api/core/plugin/entities/endpoint.py

@@ -1,4 +1,3 @@
-from collections.abc import Mapping
 from datetime import datetime
 
 from pydantic import BaseModel, Field
@@ -12,7 +11,7 @@ class EndpointDeclaration(BaseModel):
     declaration of an endpoint
     """
 
-    settings: Mapping[str, ProviderConfig] = Field(default_factory=Mapping)
+    settings: list[ProviderConfig] = Field(default_factory=list)
 
 
 class EndpointEntity(BasePluginEntity):

+ 1 - 2
api/core/plugin/entities/request.py

@@ -1,4 +1,3 @@
-from collections.abc import Mapping
 from typing import Any, Literal, Optional
 
 from pydantic import BaseModel, ConfigDict, Field, field_validator
@@ -181,4 +180,4 @@ class RequestInvokeEncrypt(BaseModel):
     namespace: Literal["endpoint"]
     identity: str
     data: dict = Field(default_factory=dict)
-    config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)
+    config: list[BasicProviderConfig] = Field(default_factory=list)

+ 7 - 3
api/core/tools/__base/tool_provider.py

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from copy import deepcopy
 from typing import Any
 
 from core.entities.provider_entities import ProviderConfig
@@ -16,13 +17,13 @@ class ToolProviderController(ABC):
     def __init__(self, entity: ToolProviderEntity) -> None:
         self.entity = entity
 
-    def get_credentials_schema(self) -> dict[str, ProviderConfig]:
+    def get_credentials_schema(self) -> list[ProviderConfig]:
         """
         returns the credentials schema of the provider
 
         :return: the credentials schema
         """
-        return self.entity.credentials_schema.copy()
+        return deepcopy(self.entity.credentials_schema)
 
     @abstractmethod
     def get_tool(self, tool_name: str) -> Tool:
@@ -48,10 +49,13 @@ class ToolProviderController(ABC):
 
         :param credentials: the credentials of the tool
         """
-        credentials_schema = self.entity.credentials_schema
+        credentials_schema = dict[str, ProviderConfig]()
         if credentials_schema is None:
             return
 
+        for credential in self.entity.credentials_schema:
+            credentials_schema[credential.name] = credential
+
         credentials_need_to_validate: dict[str, ProviderConfig] = {}
         for credential_name in credentials_schema:
             credentials_need_to_validate[credential_name] = credentials_schema[credential_name]

+ 7 - 3
api/core/tools/builtin_tool/provider.py

@@ -34,10 +34,14 @@ class BuiltinToolProviderController(ToolProviderController):
             for credential_name in provider_yaml["credentials_for_provider"]:
                 provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
 
+        credentials_schema = []
+        for credential in provider_yaml.get("credentials_for_provider", {}):
+            credentials_schema.append(credential)
+
         super().__init__(
             entity=ToolProviderEntity(
                 identity=provider_yaml["identity"],
-                credentials_schema=provider_yaml.get("credentials_for_provider", {}) or {},
+                credentials_schema=credentials_schema,
             ),
         )
 
@@ -84,14 +88,14 @@ class BuiltinToolProviderController(ToolProviderController):
         self.tools = tools
         return tools
 
-    def get_credentials_schema(self) -> dict[str, ProviderConfig]:
+    def get_credentials_schema(self) -> list[ProviderConfig]:
         """
         returns the credentials schema of the provider
 
         :return: the credentials schema
         """
         if not self.entity.credentials_schema:
-            return {}
+            return []
 
         return self.entity.credentials_schema.copy()
 

+ 0 - 1
api/core/tools/builtin_tool/providers/code/code.yaml

@@ -12,4 +12,3 @@ identity:
   icon: icon.svg
   tags:
     - productivity
-credentials_for_provider:

+ 0 - 1
api/core/tools/builtin_tool/providers/time/time.yaml

@@ -12,4 +12,3 @@ identity:
   icon: icon.svg
   tags:
     - utilities
-credentials_for_provider:

+ 9 - 9
api/core/tools/custom_tool/provider.py

@@ -28,8 +28,8 @@ class ApiToolProviderController(ToolProviderController):
 
     @classmethod
     def from_db(cls, db_provider: ApiToolProvider, auth_type: ApiProviderAuthType):
-        credentials_schema = {
-            "auth_type": ProviderConfig(
+        credentials_schema = [
+            ProviderConfig(
                 name="auth_type",
                 required=True,
                 type=ProviderConfig.Type.SELECT,
@@ -40,24 +40,24 @@ class ApiToolProviderController(ToolProviderController):
                 default="none",
                 help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
             )
-        }
+        ]
         if auth_type == ApiProviderAuthType.API_KEY:
-            credentials_schema = {
-                **credentials_schema,
-                "api_key_header": ProviderConfig(
+            credentials_schema = [
+                *credentials_schema,
+                ProviderConfig(
                     name="api_key_header",
                     required=False,
                     default="api_key",
                     type=ProviderConfig.Type.TEXT_INPUT,
                     help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
                 ),
-                "api_key_value": ProviderConfig(
+                ProviderConfig(
                     name="api_key_value",
                     required=True,
                     type=ProviderConfig.Type.SECRET_INPUT,
                     help=I18nObject(en_US="The api key", zh_Hans="api key的值"),
                 ),
-                "api_key_header_prefix": ProviderConfig(
+                ProviderConfig(
                     name="api_key_header_prefix",
                     required=False,
                     default="basic",
@@ -69,7 +69,7 @@ class ApiToolProviderController(ToolProviderController):
                         ProviderConfig.Option(value="custom", label=I18nObject(en_US="Custom", zh_Hans="Custom")),
                     ],
                 ),
-            }
+            ]
         elif auth_type == ApiProviderAuthType.NONE:
             pass
 

+ 0 - 5
api/core/tools/entities/api_entities.py

@@ -2,7 +2,6 @@ from typing import Literal, Optional
 
 from pydantic import BaseModel, Field
 
-from core.entities.provider_entities import ProviderConfig
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.__base.tool import ToolParameter
 from core.tools.entities.common_entities import I18nObject
@@ -62,7 +61,3 @@ class ToolProviderApiEntity(BaseModel):
             "tools": tools,
             "labels": self.labels,
         }
-
-
-class ToolProviderCredentialsApiEntity(BaseModel):
-    credentials: dict[str, ProviderConfig]

+ 1 - 1
api/core/tools/entities/tool_entities.py

@@ -312,7 +312,7 @@ class ToolEntity(BaseModel):
 
 class ToolProviderEntity(BaseModel):
     identity: ToolProviderIdentity
-    credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
+    credentials_schema: list[ProviderConfig] = Field(default_factory=list)
 
 
 class ToolProviderEntityWithPlugin(ToolProviderEntity):

+ 3 - 3
api/core/tools/tool_manager.py

@@ -160,7 +160,7 @@ class ToolManager:
             credentials = builtin_provider.credentials
             tool_configuration = ProviderConfigEncrypter(
                 tenant_id=tenant_id,
-                config=provider_controller.get_credentials_schema(),
+                config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
                 provider_type=provider_controller.provider_type.value,
                 provider_identity=provider_controller.entity.identity.name,
             )
@@ -186,7 +186,7 @@ class ToolManager:
             # decrypt the credentials
             tool_configuration = ProviderConfigEncrypter(
                 tenant_id=tenant_id,
-                config=api_provider.get_credentials_schema(),
+                config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
                 provider_type=api_provider.provider_type.value,
                 provider_identity=api_provider.entity.identity.name,
             )
@@ -643,7 +643,7 @@ class ToolManager:
         # init tool configuration
         tool_configuration = ProviderConfigEncrypter(
             tenant_id=tenant_id,
-            config=controller.get_credentials_schema(),
+            config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
             provider_type=controller.provider_type.value,
             provider_identity=controller.entity.identity.name,
         )

+ 13 - 5
api/core/tools/utils/configuration.py

@@ -1,4 +1,3 @@
-from collections.abc import Mapping
 from copy import deepcopy
 from typing import Any
 
@@ -17,7 +16,7 @@ from core.tools.entities.tool_entities import (
 
 class ProviderConfigEncrypter(BaseModel):
     tenant_id: str
-    config: Mapping[str, BasicProviderConfig]
+    config: list[BasicProviderConfig]
     provider_type: str
     provider_identity: str
 
@@ -36,7 +35,10 @@ class ProviderConfigEncrypter(BaseModel):
         data = self._deep_copy(data)
 
         # get fields need to be decrypted
-        fields = self.config
+        fields = dict[str, BasicProviderConfig]()
+        for credential in self.config:
+            fields[credential.name] = credential
+
         for field_name, field in fields.items():
             if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                 if field_name in data:
@@ -54,7 +56,10 @@ class ProviderConfigEncrypter(BaseModel):
         data = self._deep_copy(data)
 
         # get fields need to be decrypted
-        fields = self.config
+        fields = dict[str, BasicProviderConfig]()
+        for credential in self.config:
+            fields[credential.name] = credential
+
         for field_name, field in fields.items():
             if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                 if field_name in data:
@@ -83,7 +88,10 @@ class ProviderConfigEncrypter(BaseModel):
             return cached_credentials
         data = self._deep_copy(data)
         # get fields need to be decrypted
-        fields = self.config
+        fields = dict[str, BasicProviderConfig]()
+        for credential in self.config:
+            fields[credential.name] = credential
+
         for field_name, field in fields.items():
             if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                 if field_name in data:

+ 5 - 5
api/services/tools/builtin_tools_manage_service.py

@@ -35,7 +35,7 @@ class BuiltinToolManageService:
 
         tool_provider_configurations = ProviderConfigEncrypter(
             tenant_id=tenant_id,
-            config=provider_controller.get_credentials_schema(),
+            config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
             provider_type=provider_controller.provider_type.value,
             provider_identity=provider_controller.entity.identity.name,
         )
@@ -78,7 +78,7 @@ class BuiltinToolManageService:
         :return: the list of tool providers
         """
         provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
-        return jsonable_encoder([v for _, v in (provider.get_credentials_schema() or {}).items()])
+        return jsonable_encoder(provider.get_credentials_schema())
 
     @staticmethod
     def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
@@ -102,7 +102,7 @@ class BuiltinToolManageService:
                 raise ValueError(f"provider {provider_name} does not need credentials")
             tool_configuration = ProviderConfigEncrypter(
                 tenant_id=tenant_id,
-                config=provider_controller.get_credentials_schema(),
+                config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
                 provider_type=provider_controller.provider_type.value,
                 provider_identity=provider_controller.entity.identity.name,
             )
@@ -164,7 +164,7 @@ class BuiltinToolManageService:
         provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
         tool_configuration = ProviderConfigEncrypter(
             tenant_id=tenant_id,
-            config=provider_controller.get_credentials_schema(),
+            config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
             provider_type=provider_controller.provider_type.value,
             provider_identity=provider_controller.entity.identity.name,
         )
@@ -196,7 +196,7 @@ class BuiltinToolManageService:
         provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
         tool_configuration = ProviderConfigEncrypter(
             tenant_id=tenant_id,
-            config=provider_controller.get_credentials_schema(),
+            config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
             provider_type=provider_controller.provider_type.value,
             provider_identity=provider_controller.entity.identity.name,
         )

+ 4 - 3
api/services/tools/tools_transform_service.py

@@ -85,7 +85,8 @@ class ToolTransformService:
         )
 
         # get credentials schema
-        schema = provider_controller.get_credentials_schema()
+        schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
+
         for name, value in schema.items():
             if result.masked_credentials:
                 result.masked_credentials[name] = ""
@@ -103,7 +104,7 @@ class ToolTransformService:
                 # init tool configuration
                 tool_configuration = ProviderConfigEncrypter(
                     tenant_id=db_provider.tenant_id,
-                    config=provider_controller.get_credentials_schema(),
+                    config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
                     provider_type=provider_controller.provider_type.value,
                     provider_identity=provider_controller.entity.identity.name,
                 )
@@ -208,7 +209,7 @@ class ToolTransformService:
             # init tool configuration
             tool_configuration = ProviderConfigEncrypter(
                 tenant_id=db_provider.tenant_id,
-                config=provider_controller.get_credentials_schema(),
+                config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
                 provider_type=provider_controller.provider_type.value,
                 provider_identity=provider_controller.entity.identity.name,
             )