Browse Source

fix: invoke tool streamingly

Yeuoly 9 months ago
parent
commit
886a160115

+ 2 - 2
api/core/entities/provider_entities.py

@@ -4,8 +4,8 @@ from typing import Optional, Union
 from pydantic import BaseModel, ConfigDict, Field
 
 from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
-from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import ModelType
+from core.tools.entities.common_entities import I18nObject
 from models.provider import ProviderQuotaType
 
 
@@ -143,7 +143,7 @@ class ProviderConfig(BasicProviderConfig):
         value: str = Field(..., description="The value of the option")
         label: I18nObject = Field(..., description="The label of the option")
 
-    scope: AppSelectorScope | ModelConfigScope | None
+    scope: AppSelectorScope | ModelConfigScope | None = None
     required: bool = False
     default: Optional[Union[int, str]] = None
     options: Optional[list[Option]] = None

+ 1 - 0
api/core/helper/tool_provider_cache.py

@@ -8,6 +8,7 @@ from extensions.ext_redis import redis_client
 
 class ToolProviderCredentialsCacheType(Enum):
     PROVIDER = "tool_provider"
+    ENDPOINT = "endpoint"
 
 class ToolProviderCredentialsCache:
     def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):

+ 6 - 5
api/core/tools/entities/api_entities.py

@@ -1,10 +1,11 @@
 from typing import Literal, Optional
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
+from core.entities.provider_entities import ProviderConfig
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType
+from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.tool.tool import ToolParameter
 
 
@@ -14,7 +15,7 @@ class UserTool(BaseModel):
     label: I18nObject # label
     description: I18nObject
     parameters: Optional[list[ToolParameter]] = None
-    labels: list[str] = None
+    labels: list[str] = Field(default_factory=list)
 
 UserToolProviderTypeLiteral = Optional[Literal[
     'builtin', 'api', 'workflow'
@@ -32,8 +33,8 @@ class UserToolProvider(BaseModel):
     original_credentials: Optional[dict] = None
     is_team_authorization: bool = False
     allow_delete: bool = True
-    tools: list[UserTool] = None
-    labels: list[str] = None
+    tools: list[UserTool] = Field(default_factory=list)
+    labels: list[str] = Field(default_factory=list)
 
     def to_dict(self) -> dict:
         # -------------

+ 2 - 2
api/core/tools/entities/tool_entities.py

@@ -25,7 +25,7 @@ class ToolLabelEnum(Enum):
     UTILITIES = 'utilities'
     OTHER = 'other'
 
-class ToolProviderType(Enum):
+class ToolProviderType(str, Enum):
     """
         Enum class for tool provider
     """
@@ -181,7 +181,7 @@ class ToolParameter(BaseModel):
         if options:
             option_objs = [ToolParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) for option in options]
         else:
-            option_objs = None
+            option_objs = []
         return cls(
             name=name,
             label=I18nObject(en_US='', zh_Hans=''),

+ 14 - 11
api/core/tools/provider/api_tool_provider.py

@@ -1,21 +1,23 @@
 
+from pydantic import Field
+
+from core.entities.provider_entities import ProviderConfig
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
-    ProviderConfig,
-    ToolCredentialsOption,
     ToolProviderType,
 )
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.api_tool import ApiTool
-from core.tools.tool.tool import Tool
 from extensions.ext_database import db
 from models.tools import ApiToolProvider
 
 
 class ApiToolProviderController(ToolProviderController):
     provider_id: str
+    tenant_id: str
+    tools: list[ApiTool] = Field(default_factory=list)
 
     @staticmethod
     def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
@@ -25,8 +27,8 @@ class ApiToolProviderController(ToolProviderController):
                 required=True,
                 type=ProviderConfig.Type.SELECT,
                 options=[
-                    ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
-                    ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
+                    ProviderConfig.Option(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
+                    ProviderConfig.Option(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
                 ],
                 default='none',
                 help=I18nObject(
@@ -67,9 +69,9 @@ class ApiToolProviderController(ToolProviderController):
                         zh_Hans='api key header 的前缀'
                     ),
                     options=[
-                        ToolCredentialsOption(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
-                        ToolCredentialsOption(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
-                        ToolCredentialsOption(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
+                        ProviderConfig.Option(value='basic', label=I18nObject(en_US='Basic', zh_Hans='Basic')),
+                        ProviderConfig.Option(value='bearer', label=I18nObject(en_US='Bearer', zh_Hans='Bearer')),
+                        ProviderConfig.Option(value='custom', label=I18nObject(en_US='Custom', zh_Hans='Custom'))
                     ]
                 )
             }
@@ -96,6 +98,7 @@ class ApiToolProviderController(ToolProviderController):
             },
             'credentials_schema': credentials_schema,
             'provider_id': db_provider.id or '',
+            'tenant_id': db_provider.tenant_id or '',
         })
 
     @property
@@ -142,7 +145,7 @@ class ApiToolProviderController(ToolProviderController):
 
         return self.tools
 
-    def get_tools(self, user_id: str, tenant_id: str) -> list[ApiTool]:
+    def get_tools(self, tenant_id: str) -> list[ApiTool]:
         """
             fetch tools from database
 
@@ -153,7 +156,7 @@ class ApiToolProviderController(ToolProviderController):
         if self.tools is not None:
             return self.tools
         
-        tools: list[Tool] = []
+        tools: list[ApiTool] = []
 
         # get tenant api providers
         db_providers: list[ApiToolProvider] = db.session.query(ApiToolProvider).filter(
@@ -179,7 +182,7 @@ class ApiToolProviderController(ToolProviderController):
             :return: the tool
         """
         if self.tools is None:
-            self.get_tools()
+            self.get_tools(self.tenant_id)
 
         for tool in self.tools:
             if tool.identity.name == tool_name:

+ 1 - 1
api/core/tools/provider/builtin_tool_provider.py

@@ -39,7 +39,7 @@ class BuiltinToolProviderController(ToolProviderController):
 
         super().__init__(**{
             'identity': provider_yaml['identity'],
-            'credentials_schema': provider_yaml.get('credentials_for_provider', None),
+            'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {},
         })
 
     def _get_builtin_tools(self) -> list[BuiltinTool]:

+ 3 - 1
api/core/tools/provider/tool_provider.py

@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from typing import Any
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, ConfigDict, Field
 
 from core.entities.provider_entities import ProviderConfig
 from core.tools.entities.tool_entities import (
@@ -17,6 +17,8 @@ class ToolProviderController(BaseModel, ABC):
     tools: list[Tool] = Field(default_factory=list)
     credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
 
+    model_config = ConfigDict(validate_assignment=True)
+
     def get_credentials_schema(self) -> dict[str, ProviderConfig]:
         """
             returns the credentials schema of the provider

+ 11 - 2
api/core/tools/tool/tool.py

@@ -206,7 +206,16 @@ class Tool(BaseModel, ABC):
             tool_parameters=tool_parameters,
         )
 
-        return result
+        if isinstance(result, ToolInvokeMessage):
+            def single_generator():
+                yield result
+            return single_generator()
+        elif isinstance(result, list):
+            def generator():
+                yield from result
+            return generator()
+        else:
+            return result
 
     def _transform_tool_parameters_type(self, tool_parameters: dict[str, Any]) -> dict[str, Any]:
         """
@@ -223,7 +232,7 @@ class Tool(BaseModel, ABC):
         return result
 
     @abstractmethod
-    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessage, None, None]:
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage] | Generator[ToolInvokeMessage, None, None]:
         pass
 
     def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:

+ 18 - 3
api/core/tools/tool_manager.py

@@ -116,7 +116,12 @@ class ToolManager:
             # decrypt the credentials
             credentials = builtin_provider.credentials
             controller = cls.get_builtin_provider(provider_id)
-            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
+            tool_configuration = ToolConfigurationManager(
+                tenant_id=tenant_id, 
+                config=controller.get_credentials_schema(),
+                provider_type=controller.provider_type.value,
+                provider_identity=controller.identity.name
+            )
 
             decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
 
@@ -135,7 +140,12 @@ class ToolManager:
             api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
 
             # decrypt the credentials
-            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
+            tool_configuration = ToolConfigurationManager(
+                tenant_id=tenant_id, 
+                config=api_provider.get_credentials_schema(),
+                provider_type=api_provider.provider_type.value,
+                provider_identity=api_provider.identity.name
+            )
             decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
 
             return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
@@ -513,7 +523,12 @@ class ToolManager:
             provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
         )
         # init tool configuration
-        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
+        tool_configuration = ToolConfigurationManager(
+            tenant_id=tenant_id,
+            config=controller.get_credentials_schema(),
+            provider_type=controller.provider_type.value,
+            provider_identity=controller.identity.name
+        )
 
         decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
         masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)

+ 13 - 11
api/core/tools/utils/configuration.py

@@ -1,23 +1,25 @@
+from collections.abc import Mapping
 from copy import deepcopy
 from typing import Any
 
 from pydantic import BaseModel
 
+from core.entities.provider_entities import BasicProviderConfig
 from core.helper import encrypter
 from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
 from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
 from core.tools.entities.tool_entities import (
-    ProviderConfig,
     ToolParameter,
     ToolProviderType,
 )
-from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.tool import Tool
 
 
 class ToolConfigurationManager(BaseModel):
     tenant_id: str
-    provider_controller: ToolProviderController
+    config: Mapping[str, BasicProviderConfig]
+    provider_type: str
+    provider_identity: str
 
     def _deep_copy(self, credentials: dict[str, str]) -> dict[str, str]:
         """
@@ -34,9 +36,9 @@ class ToolConfigurationManager(BaseModel):
         credentials = self._deep_copy(credentials)
 
         # get fields need to be decrypted
-        fields = self.provider_controller.get_credentials_schema()
+        fields = self.config
         for field_name, field in fields.items():
-            if field.type == ProviderConfig.Type.SECRET_INPUT:
+            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                 if field_name in credentials:
                     encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
                     credentials[field_name] = encrypted
@@ -52,9 +54,9 @@ class ToolConfigurationManager(BaseModel):
         credentials = self._deep_copy(credentials)
 
         # get fields need to be decrypted
-        fields = self.provider_controller.get_credentials_schema()
+        fields = self.config
         for field_name, field in fields.items():
-            if field.type == ProviderConfig.Type.SECRET_INPUT:
+            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                 if field_name in credentials:
                     if len(credentials[field_name]) > 6:
                         credentials[field_name] = \
@@ -74,7 +76,7 @@ class ToolConfigurationManager(BaseModel):
         """
         cache = ToolProviderCredentialsCache(
             tenant_id=self.tenant_id, 
-            identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
+            identity_id=f'{self.provider_type}.{self.provider_identity}',
             cache_type=ToolProviderCredentialsCacheType.PROVIDER
         )
         cached_credentials = cache.get()
@@ -82,9 +84,9 @@ class ToolConfigurationManager(BaseModel):
             return cached_credentials
         credentials = self._deep_copy(credentials)
         # get fields need to be decrypted
-        fields = self.provider_controller.get_credentials_schema()
+        fields = self.config
         for field_name, field in fields.items():
-            if field.type == ProviderConfig.Type.SECRET_INPUT:
+            if field.type == BasicProviderConfig.Type.SECRET_INPUT:
                 if field_name in credentials:
                     try:
                         credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
@@ -97,7 +99,7 @@ class ToolConfigurationManager(BaseModel):
     def delete_tool_credentials_cache(self):
         cache = ToolProviderCredentialsCache(
             tenant_id=self.tenant_id, 
-            identity_id=f'{self.provider_controller.provider_type.value}.{self.provider_controller.identity.name}',
+            identity_id=f'{self.provider_type}.{self.provider_identity}',
             cache_type=ToolProviderCredentialsCacheType.PROVIDER
         )
         cache.delete()

+ 6 - 5
api/core/tools/utils/parser.py

@@ -16,7 +16,7 @@ from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolPro
 
 class ApiBasedToolSchemaParser:
     @staticmethod
-    def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
+    def parse_openapi_to_tool_bundle(openapi: dict, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
         warning = warning if warning is not None else {}
         extra_info = extra_info if extra_info is not None else {}
 
@@ -173,7 +173,7 @@ class ApiBasedToolSchemaParser:
             return ToolParameter.ToolParameterType.STRING
 
     @staticmethod
-    def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
+    def parse_openapi_yaml_to_tool_bundle(yaml: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
         """
             parse openapi yaml to tool bundle
 
@@ -189,7 +189,8 @@ class ApiBasedToolSchemaParser:
         return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
     
     @staticmethod
-    def parse_swagger_to_openapi(swagger: dict, extra_info: dict = None, warning: dict = None) -> dict:
+    def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
+        warning = warning or {}
         """
             parse swagger to openapi
 
@@ -255,7 +256,7 @@ class ApiBasedToolSchemaParser:
         return openapi
 
     @staticmethod
-    def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict = None, warning: dict = None) -> list[ApiToolBundle]:
+    def parse_openai_plugin_json_to_tool_bundle(json: str, extra_info: dict | None = None, warning: dict | None = None) -> list[ApiToolBundle]:
         """
             parse openapi plugin yaml to tool bundle
 
@@ -287,7 +288,7 @@ class ApiBasedToolSchemaParser:
         return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(response.text, extra_info=extra_info, warning=warning)
     
     @staticmethod
-    def auto_parse_to_tool_bundle(content: str, extra_info: dict = None, warning: dict = None) -> tuple[list[ApiToolBundle], str]:
+    def auto_parse_to_tool_bundle(content: str, extra_info: dict | None = None, warning: dict | None = None) -> tuple[list[ApiToolBundle], str]:
         """
             auto parse to tool bundle
 

+ 10 - 7
api/core/workflow/nodes/tool/tool_node.py

@@ -1,6 +1,6 @@
 from collections.abc import Generator, Sequence
 from os import path
-from typing import Any, cast
+from typing import Any, Iterable, cast
 
 from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@@ -158,14 +158,17 @@ class ToolNode(BaseNode):
             tenant_id=self.tenant_id,
             conversation_id=None,
         )
+
+        result = list(messages)
+
         # extract plain text and files
-        files = self._extract_tool_response_binary(messages)
-        plain_text = self._extract_tool_response_text(messages)
-        json = self._extract_tool_response_json(messages)
+        files = self._extract_tool_response_binary(result)
+        plain_text = self._extract_tool_response_text(result)
+        json = self._extract_tool_response_json(result)
 
         return plain_text, files, json
 
-    def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]:
+    def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]:
         """
         Extract tool response binary
         """
@@ -215,7 +218,7 @@ class ToolNode(BaseNode):
 
         return result
 
-    def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str:
+    def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str:
         """
         Extract tool response text
         """
@@ -230,7 +233,7 @@ class ToolNode(BaseNode):
 
         return '\n'.join(result)
 
-    def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]:
+    def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]:
         result: list[dict] = []
         for message in tool_response:
             if message.type == ToolInvokeMessage.MessageType.JSON:

+ 12 - 12
api/models/model.py

@@ -7,7 +7,7 @@ from typing import Optional
 from flask import request
 from flask_login import UserMixin
 from sqlalchemy import Float, func, text
-from sqlalchemy.orm import Mapped, mapped_column
+from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from configs import dify_config
 from core.file.tool_file_parser import ToolFileParser
@@ -495,14 +495,14 @@ class InstalledApp(db.Model):
         return tenant
 
 
-class Conversation(db.Model):
+class Conversation(Base):
     __tablename__ = 'conversations'
     __table_args__ = (
         db.PrimaryKeyConstraint('id', name='conversation_pkey'),
         db.Index('conversation_app_from_user_idx', 'app_id', 'from_source', 'from_end_user_id')
     )
 
-    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     app_id = db.Column(StringUUID, nullable=False)
     app_model_config_id = db.Column(StringUUID, nullable=True)
     model_provider = db.Column(db.String(255), nullable=True)
@@ -526,8 +526,8 @@ class Conversation(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
-    messages = db.relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
-    message_annotations = db.relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
+    messages: Mapped[list["Message"]] = relationship("Message", backref="conversation", lazy='select', passive_deletes="all")
+    message_annotations: Mapped[list["MessageAnnotation"]] = relationship("MessageAnnotation", backref="conversation", lazy='select', passive_deletes="all")
 
     is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
 
@@ -660,10 +660,10 @@ class Message(Base):
     model_provider = db.Column(db.String(255), nullable=True)
     model_id = db.Column(db.String(255), nullable=True)
     override_model_configs = db.Column(db.Text)
-    conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
-    inputs = db.Column(db.JSON)
-    query = db.Column(db.Text, nullable=False)
-    message = db.Column(db.JSON, nullable=False)
+    conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=False)
+    inputs: Mapped[str] = mapped_column(db.JSON)
+    query: Mapped[str] = mapped_column(db.Text, nullable=False)
+    message: Mapped[str] = mapped_column(db.JSON, nullable=False)
     message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text('0'))
     message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
     message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text('0.001'))
@@ -944,7 +944,7 @@ class MessageFile(Base):
         db.Index('message_file_created_by_idx', 'created_by')
     )
 
-    id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     type: Mapped[str] = mapped_column(db.String(255), nullable=False)
     transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
@@ -956,7 +956,7 @@ class MessageFile(Base):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
 
 
-class MessageAnnotation(db.Model):
+class MessageAnnotation(Base):
     __tablename__ = 'message_annotations'
     __table_args__ = (
         db.PrimaryKeyConstraint('id', name='message_annotation_pkey'),
@@ -967,7 +967,7 @@ class MessageAnnotation(db.Model):
 
     id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     app_id = db.Column(StringUUID, nullable=False)
-    conversation_id = db.Column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
+    conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey('conversations.id'), nullable=True)
     message_id = db.Column(StringUUID, nullable=True)
     question = db.Column(db.Text, nullable=True)
     content = db.Column(db.Text, nullable=False)

+ 3 - 3
api/models/tools.py

@@ -77,10 +77,10 @@ class PublishedAppTool(db.Model):
         return I18nObject(**json.loads(self.description))
     
     @property
-    def app(self) -> App:
+    def app(self) -> App | None:
         return db.session.query(App).filter(App.id == self.app_id).first()
 
-class ApiToolProvider(db.Model):
+class ApiToolProvider(Base):
     """
     The table stores the api providers.
     """
@@ -290,7 +290,7 @@ class ToolFile(Base):
         db.Index('tool_file_conversation_id_idx', 'conversation_id'),
     )
 
-    id: Mapped[str] = mapped_column(StringUUID, default=db.text('uuid_generate_v4()'))
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text('uuid_generate_v4()'))
     # conversation user id
     user_id: Mapped[str] = mapped_column(StringUUID)
     # tenant id

+ 33 - 19
api/services/tools/api_tools_manage_service.py

@@ -3,6 +3,7 @@ import logging
 
 from httpx import get
 
+from core.entities.provider_entities import ProviderConfig
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.api_entities import UserTool, UserToolProvider
 from core.tools.entities.common_entities import I18nObject
@@ -10,8 +11,6 @@ from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ApiProviderSchemaType,
-    ProviderConfig,
-    ToolCredentialsOption,
 )
 from core.tools.provider.api_tool_provider import ApiToolProviderController
 from core.tools.tool_label_manager import ToolLabelManager
@@ -45,8 +44,8 @@ class ApiToolManageService:
                     required=True,
                     default="none",
                     options=[
-                        ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
-                        ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
+                        ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
+                        ProviderConfig.Option(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")),
                     ],
                     placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
                 ),
@@ -79,15 +78,14 @@ class ApiToolManageService:
             raise ValueError(f"invalid schema: {str(e)}")
 
     @staticmethod
-    def convert_schema_to_tool_bundles(schema: str, extra_info: dict = None) -> list[ApiToolBundle]:
+    def convert_schema_to_tool_bundles(schema: str, extra_info: dict | None = None) -> tuple[list[ApiToolBundle], str]:
         """
         convert schema to tool bundles
 
         :return: the list of tool bundles, description
         """
         try:
-            tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
-            return tool_bundles
+            return ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info)
         except Exception as e:
             raise ValueError(f"invalid schema: {str(e)}")
 
@@ -111,7 +109,7 @@ class ApiToolManageService:
             raise ValueError(f"invalid schema type {schema}")
 
         # check if the provider exists
-        provider: ApiToolProvider = (
+        provider: ApiToolProvider | None = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
@@ -158,7 +156,13 @@ class ApiToolManageService:
         provider_controller.load_bundled_tools(tool_bundles)
 
         # encrypt credentials
-        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
+        tool_configuration = ToolConfigurationManager(
+            tenant_id=tenant_id,
+            config=provider_controller.get_credentials_schema(),
+            provider_type=provider_controller.provider_type.value,
+            provider_identity=provider_controller.identity.name
+        )
+
         encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials)
         db_provider.credentials_str = json.dumps(encrypted_credentials)
 
@@ -195,21 +199,21 @@ class ApiToolManageService:
         return {"schema": schema}
 
     @staticmethod
-    def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
+    def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
         """
         list api tool provider tools
         """
-        provider: ApiToolProvider = (
+        provider: ApiToolProvider | None = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
-                ApiToolProvider.name == provider,
+                ApiToolProvider.name == provider_name,
             )
             .first()
         )
 
         if provider is None:
-            raise ValueError(f"you have not added provider {provider}")
+            raise ValueError(f"you have not added provider {provider_name}")
 
         controller = ToolTransformService.api_provider_to_controller(db_provider=provider)
         labels = ToolLabelManager.get_tool_labels(controller)
@@ -243,7 +247,7 @@ class ApiToolManageService:
             raise ValueError(f"invalid schema type {schema}")
 
         # check if the provider exists
-        provider: ApiToolProvider = (
+        provider: ApiToolProvider | None = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
@@ -282,7 +286,12 @@ class ApiToolManageService:
         provider_controller.load_bundled_tools(tool_bundles)
 
         # get original credentials if exists
-        tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
+        tool_configuration = ToolConfigurationManager(
+            tenant_id=tenant_id,
+            config=provider_controller.get_credentials_schema(),
+            provider_type=provider_controller.provider_type.value,
+            provider_identity=provider_controller.identity.name
+        )
 
         original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials)
         masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
@@ -310,7 +319,7 @@ class ApiToolManageService:
         """
         delete tool provider
         """
-        provider: ApiToolProvider = (
+        provider: ApiToolProvider | None = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
@@ -360,7 +369,7 @@ class ApiToolManageService:
         if tool_bundle is None:
             raise ValueError(f"invalid tool name {tool_name}")
 
-        db_provider: ApiToolProvider = (
+        db_provider: ApiToolProvider | None = (
             db.session.query(ApiToolProvider)
             .filter(
                 ApiToolProvider.tenant_id == tenant_id,
@@ -396,7 +405,12 @@ class ApiToolManageService:
 
         # decrypt credentials
         if db_provider.id:
-            tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller)
+            tool_configuration = ToolConfigurationManager(
+                tenant_id=tenant_id,
+                config=provider_controller.get_credentials_schema(),
+                provider_type=provider_controller.provider_type.value,
+                provider_identity=provider_controller.identity.name
+            )
             decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
             # check if the credential has changed, save the original credential
             masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
@@ -444,7 +458,7 @@ class ApiToolManageService:
             # add icon
             ToolTransformService.repack_provider(user_provider)
 
-            tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id)
+            tools = provider_controller.get_tools(tenant_id=tenant_id)
 
             for tool in tools:
                 user_provider.tools.append(

+ 14 - 8
api/services/tools/tools_transform_service.py

@@ -3,12 +3,12 @@ import logging
 from typing import Optional, Union
 
 from configs import dify_config
+from core.entities.provider_entities import ProviderConfig
 from core.tools.entities.api_entities import UserTool, UserToolProvider
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
-    ProviderConfig,
     ToolParameter,
     ToolProviderType,
 )
@@ -106,7 +106,10 @@ class ToolTransformService:
 
                 # init tool configuration
                 tool_configuration = ToolConfigurationManager(
-                    tenant_id=db_provider.tenant_id, provider_controller=provider_controller
+                    tenant_id=db_provider.tenant_id,
+                    config=provider_controller.get_credentials_schema(),
+                    provider_type=provider_controller.provider_type.value,
+                    provider_identity=provider_controller.identity.name
                 )
                 # decrypt the credentials and mask the credentials
                 decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials=credentials)
@@ -143,7 +146,7 @@ class ToolTransformService:
 
     @staticmethod
     def workflow_provider_to_user_provider(
-        provider_controller: WorkflowToolProviderController, labels: list[str] = None
+        provider_controller: WorkflowToolProviderController, labels: list[str] | None = None
     ):
         """
         convert provider controller to user provider
@@ -174,7 +177,7 @@ class ToolTransformService:
         provider_controller: ApiToolProviderController,
         db_provider: ApiToolProvider,
         decrypt_credentials: bool = True,
-        labels: list[str] = None,
+        labels: list[str] | None = None,
     ) -> UserToolProvider:
         """
         convert provider controller to user provider
@@ -209,7 +212,10 @@ class ToolTransformService:
         if decrypt_credentials:
             # init tool configuration
             tool_configuration = ToolConfigurationManager(
-                tenant_id=db_provider.tenant_id, provider_controller=provider_controller
+                tenant_id=db_provider.tenant_id,
+                config=provider_controller.get_credentials_schema(),
+                provider_type=provider_controller.provider_type.value,
+                provider_identity=provider_controller.identity.name
             )
 
             # decrypt the credentials and mask the credentials
@@ -223,9 +229,9 @@ class ToolTransformService:
     @staticmethod
     def tool_to_user_tool(
         tool: Union[ApiToolBundle, WorkflowTool, Tool],
-        credentials: dict = None,
-        tenant_id: str = None,
-        labels: list[str] = None,
+        credentials: dict | None = None,
+        tenant_id: str | None = None,
+        labels: list[str] | None = None,
     ) -> UserTool:
         """
         convert tool to user tool