Browse Source

refactor tools

Yeuoly 1 year ago
parent
commit
1fa3b9cfd8

+ 9 - 0
api/controllers/inner_api/plugin/plugin.py

@@ -10,6 +10,7 @@ from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
 from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
 from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
 from core.plugin.entities.request import (
 from core.plugin.entities.request import (
     RequestInvokeApp,
     RequestInvokeApp,
+    RequestInvokeEncrypt,
     RequestInvokeLLM,
     RequestInvokeLLM,
     RequestInvokeModeration,
     RequestInvokeModeration,
     RequestInvokeNode,
     RequestInvokeNode,
@@ -132,6 +133,14 @@ class PluginInvokeAppApi(Resource):
             PluginAppBackwardsInvocation.convert_to_event_stream(response)
             PluginAppBackwardsInvocation.convert_to_event_stream(response)
         )
         )
 
 
+class PluginInvokeEncryptApi(Resource):
+    @setup_required
+    @plugin_inner_api_only
+    @get_tenant
+    @plugin_data(payload_type=RequestInvokeEncrypt)
+    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeEncrypt):
+        """"""
+
 api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
 api.add_resource(PluginInvokeLLMApi, '/invoke/llm')
 api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
 api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding')
 api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')
 api.add_resource(PluginInvokeRerankApi, '/invoke/rerank')

+ 2 - 0
api/controllers/inner_api/wraps.py

@@ -46,6 +46,8 @@ def enterprise_inner_api_user_auth(view):
             user_id = user_id.split(" ")[1]
             user_id = user_id.split(" ")[1]
 
 
         inner_api_key = request.headers.get("X-Inner-Api-Key")
         inner_api_key = request.headers.get("X-Inner-Api-Key")
+        if not inner_api_key:
+            raise ValueError("inner api key not found")
 
 
         data_to_sign = f"DIFY {user_id}"
         data_to_sign = f"DIFY {user_id}"
 
 

+ 1 - 1
api/core/app/entities/queue_entities.py

@@ -60,7 +60,7 @@ class QueueIterationStartEvent(AppQueueEvent):
     node_data: BaseNodeData
     node_data: BaseNodeData
 
 
     node_run_index: int
     node_run_index: int
-    inputs: dict = None
+    inputs: Optional[dict] = None
     predecessor_node_id: Optional[str] = None
     predecessor_node_id: Optional[str] = None
     metadata: Optional[dict] = None
     metadata: Optional[dict] = None
 
 

+ 30 - 0
api/core/entities/parameter_entities.py

@@ -0,0 +1,30 @@
+from enum import Enum
+
+
+class CommonParameterType(Enum):
+    SECRET_INPUT = "secret-input"
+    TEXT_INPUT = "text-input"
+    SELECT = "select"
+    STRING = "string"
+    NUMBER = "number"
+    FILE = "file"
+    BOOLEAN = "boolean"
+    APP_SELECTOR = "app-selector"
+    MODEL_CONFIG = "model-config"
+
+
+class AppSelectorScope(Enum):
+    ALL = "all"
+    CHAT = "chat"
+    WORKFLOW = "workflow"
+    COMPLETION = "completion"
+
+
+class ModelConfigScope(Enum):
+    LLM = "llm"
+    TEXT_EMBEDDING = "text-embedding"
+    RERANK = "rerank"
+    TTS = "tts"
+    SPEECH2TEXT = "speech2text"
+    MODERATION = "moderation"
+    VISION = "vision"

+ 53 - 2
api/core/entities/provider_entities.py

@@ -1,8 +1,10 @@
 from enum import Enum
 from enum import Enum
-from typing import Optional
+from typing import Optional, Union
 
 
-from pydantic import BaseModel, ConfigDict
+from pydantic import BaseModel, ConfigDict, Field
 
 
+from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
+from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.entities.model_entities import ModelType
 from models.provider import ProviderQuotaType
 from models.provider import ProviderQuotaType
 
 
@@ -100,3 +102,52 @@ class ModelSettings(BaseModel):
 
 
     # pydantic configs
     # pydantic configs
     model_config = ConfigDict(protected_namespaces=())
     model_config = ConfigDict(protected_namespaces=())
+
+class BasicProviderConfig(BaseModel):
+    """
+    Base model class for common provider settings like credentials
+    """
+    class Type(Enum):
+        SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
+        TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
+        SELECT = CommonParameterType.SELECT.value
+        BOOLEAN = CommonParameterType.BOOLEAN.value
+        APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
+        MODEL_CONFIG = CommonParameterType.MODEL_CONFIG.value
+
+        @classmethod
+        def value_of(cls, value: str) -> "ProviderConfig.Type":
+            """
+            Get value of given mode.
+
+            :param value: mode value
+            :return: mode
+            """
+            for mode in cls:
+                if mode.value == value:
+                    return mode
+            raise ValueError(f'invalid mode value {value}')
+
+        @staticmethod
+        def default(value: str) -> str:
+            return ""
+    
+    type: Type = Field(..., description="The type of the credentials")
+    name: str = Field(..., description="The name of the credentials")
+
+class ProviderConfig(BasicProviderConfig):
+    """
+    Model class for common provider settings like credentials
+    """
+    class Option(BaseModel):
+        value: str = Field(..., description="The value of the option")
+        label: I18nObject = Field(..., description="The label of the option")
+
+    scope: AppSelectorScope | ModelConfigScope | None
+    required: bool = False
+    default: Optional[Union[int, str]] = None
+    options: Optional[list[Option]] = None
+    label: Optional[I18nObject] = None
+    help: Optional[I18nObject] = None
+    url: Optional[str] = None
+    placeholder: Optional[I18nObject] = None

+ 6 - 1
api/core/file/tool_file_parser.py

@@ -1,4 +1,9 @@
-tool_file_manager = {
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from core.tools.tool_file_manager import ToolFileManager
+
+tool_file_manager: dict[str, Any] = {
     'manager': None
     'manager': None
 }
 }
 
 

+ 11 - 3
api/core/plugin/entities/request.py

@@ -1,7 +1,9 @@
+from collections.abc import Mapping
 from typing import Any, Literal, Optional
 from typing import Any, Literal, Optional
 
 
 from pydantic import BaseModel, Field, field_validator
 from pydantic import BaseModel, Field, field_validator
 
 
+from core.entities.provider_entities import BasicProviderConfig
 from core.model_runtime.entities.message_entities import (
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
     AssistantPromptMessage,
     PromptMessage,
     PromptMessage,
@@ -30,11 +32,10 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
     """
     """
     Request to invoke LLM
     Request to invoke LLM
     """
     """
-
     model_type: ModelType = ModelType.LLM
     model_type: ModelType = ModelType.LLM
     mode: str
     mode: str
     model_parameters: dict[str, Any] = Field(default_factory=dict)
     model_parameters: dict[str, Any] = Field(default_factory=dict)
-    prompt_messages: list[PromptMessage]
+    prompt_messages: list[PromptMessage] = Field(default_factory=list)
     tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
     tools: Optional[list[PromptMessageTool]] = Field(default_factory=list)
     stop: Optional[list[str]] = Field(default_factory=list)
     stop: Optional[list[str]] = Field(default_factory=list)
     stream: Optional[bool] = False
     stream: Optional[bool] = False
@@ -105,4 +106,11 @@ class RequestInvokeApp(BaseModel):
     conversation_id: Optional[str] = None
     conversation_id: Optional[str] = None
     user: Optional[str] = None
     user: Optional[str] = None
     files: list[dict] = Field(default_factory=list)
     files: list[dict] = Field(default_factory=list)
-    
+
+class RequestInvokeEncrypt(BaseModel):
+    """
+    Request to encryption
+    """
+    opt: Literal["encrypt", "decrypt"]
+    data: dict = Field(default_factory=dict)
+    config: Mapping[str, BasicProviderConfig] = Field(default_factory=Mapping)

+ 2 - 2
api/core/tools/entities/api_entities.py

@@ -4,7 +4,7 @@ from pydantic import BaseModel
 
 
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.model_runtime.utils.encoders import jsonable_encoder
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ToolProviderCredentials, ToolProviderType
+from core.tools.entities.tool_entities import ProviderConfig, ToolProviderType
 from core.tools.tool.tool import ToolParameter
 from core.tools.tool.tool import ToolParameter
 
 
 
 
@@ -62,4 +62,4 @@ class UserToolProvider(BaseModel):
         }
         }
 
 
 class UserToolProviderCredentials(BaseModel):
 class UserToolProviderCredentials(BaseModel):
-    credentials: dict[str, ToolProviderCredentials]
+    credentials: dict[str, ProviderConfig]

+ 11 - 60
api/core/tools/entities/tool_entities.py

@@ -3,6 +3,7 @@ from typing import Any, Optional, Union, cast
 
 
 from pydantic import BaseModel, Field, field_validator
 from pydantic import BaseModel, Field, field_validator
 
 
+from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.common_entities import I18nObject
 
 
 
 
@@ -137,12 +138,12 @@ class ToolParameterOption(BaseModel):
 
 
 class ToolParameter(BaseModel):
 class ToolParameter(BaseModel):
     class ToolParameterType(str, Enum):
     class ToolParameterType(str, Enum):
-        STRING = "string"
-        NUMBER = "number"
-        BOOLEAN = "boolean"
-        SELECT = "select"
-        SECRET_INPUT = "secret-input"
-        FILE = "file"
+        STRING = CommonParameterType.STRING.value
+        NUMBER = CommonParameterType.NUMBER.value
+        BOOLEAN = CommonParameterType.BOOLEAN.value
+        SELECT = CommonParameterType.SELECT.value
+        SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
+        FILE = CommonParameterType.FILE.value
 
 
     class ToolParameterForm(Enum):
     class ToolParameterForm(Enum):
         SCHEMA = "schema" # should be set while adding tool
         SCHEMA = "schema" # should be set while adding tool
@@ -151,16 +152,17 @@ class ToolParameter(BaseModel):
 
 
     name: str = Field(..., description="The name of the parameter")
     name: str = Field(..., description="The name of the parameter")
     label: I18nObject = Field(..., description="The label presented to the user")
     label: I18nObject = Field(..., description="The label presented to the user")
-    human_description: Optional[I18nObject] = Field(None, description="The description presented to the user")
-    placeholder: Optional[I18nObject] = Field(None, description="The placeholder presented to the user")
+    human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
+    placeholder: Optional[I18nObject] = Field(default=None, description="The placeholder presented to the user")
     type: ToolParameterType = Field(..., description="The type of the parameter")
     type: ToolParameterType = Field(..., description="The type of the parameter")
+    scope: AppSelectorScope | ModelConfigScope | None = None
     form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
     form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
     llm_description: Optional[str] = None
     llm_description: Optional[str] = None
     required: Optional[bool] = False
     required: Optional[bool] = False
     default: Optional[Union[float, int, str]] = None
     default: Optional[Union[float, int, str]] = None
     min: Optional[Union[float, int]] = None
     min: Optional[Union[float, int]] = None
     max: Optional[Union[float, int]] = None
     max: Optional[Union[float, int]] = None
-    options: Optional[list[ToolParameterOption]] = None
+    options: list[ToolParameterOption] = Field(default_factory=list)
 
 
     @classmethod
     @classmethod
     def get_simple_instance(cls,
     def get_simple_instance(cls,
@@ -211,57 +213,6 @@ class ToolIdentity(BaseModel):
     provider: str = Field(..., description="The provider of the tool")
     provider: str = Field(..., description="The provider of the tool")
     icon: Optional[str] = None
     icon: Optional[str] = None
 
 
-class ToolCredentialsOption(BaseModel):
-    value: str = Field(..., description="The value of the option")
-    label: I18nObject = Field(..., description="The label of the option")
-
-class ToolProviderCredentials(BaseModel):
-    class CredentialsType(Enum):
-        SECRET_INPUT = "secret-input"
-        TEXT_INPUT = "text-input"
-        SELECT = "select"
-        BOOLEAN = "boolean"
-
-        @classmethod
-        def value_of(cls, value: str) -> "ToolProviderCredentials.CredentialsType":
-            """
-            Get value of given mode.
-
-            :param value: mode value
-            :return: mode
-            """
-            for mode in cls:
-                if mode.value == value:
-                    return mode
-            raise ValueError(f'invalid mode value {value}')
-
-        @staticmethod
-        def default(value: str) -> str:
-            return ""
-
-    name: str = Field(..., description="The name of the credentials")
-    type: CredentialsType = Field(..., description="The type of the credentials")
-    required: bool = False
-    default: Optional[Union[int, str]] = None
-    options: Optional[list[ToolCredentialsOption]] = None
-    label: Optional[I18nObject] = None
-    help: Optional[I18nObject] = None
-    url: Optional[str] = None
-    placeholder: Optional[I18nObject] = None
-
-    def to_dict(self) -> dict:
-        return {
-            'name': self.name,
-            'type': self.type.value,
-            'required': self.required,
-            'default': self.default,
-            'options': self.options,
-            'help': self.help.to_dict() if self.help else None,
-            'label': self.label.to_dict() if self.label else None,
-            'url': self.url,
-            'placeholder': self.placeholder.to_dict() if self.placeholder else None,
-        }
-
 class ToolRuntimeVariableType(Enum):
 class ToolRuntimeVariableType(Enum):
     TEXT = "text"
     TEXT = "text"
     IMAGE = "image"
     IMAGE = "image"

+ 9 - 9
api/core/tools/provider/api_tool_provider.py

@@ -3,8 +3,8 @@ from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ApiProviderAuthType,
+    ProviderConfig,
     ToolCredentialsOption,
     ToolCredentialsOption,
-    ToolProviderCredentials,
     ToolProviderType,
     ToolProviderType,
 )
 )
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.provider.tool_provider import ToolProviderController
@@ -20,10 +20,10 @@ class ApiToolProviderController(ToolProviderController):
     @staticmethod
     @staticmethod
     def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
     def from_db(db_provider: ApiToolProvider, auth_type: ApiProviderAuthType) -> 'ApiToolProviderController':
         credentials_schema = {
         credentials_schema = {
-            'auth_type': ToolProviderCredentials(
+            'auth_type': ProviderConfig(
                 name='auth_type',
                 name='auth_type',
                 required=True,
                 required=True,
-                type=ToolProviderCredentials.CredentialsType.SELECT,
+                type=ProviderConfig.Type.SELECT,
                 options=[
                 options=[
                     ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
                     ToolCredentialsOption(value='none', label=I18nObject(en_US='None', zh_Hans='无')),
                     ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
                     ToolCredentialsOption(value='api_key', label=I18nObject(en_US='api_key', zh_Hans='api_key'))
@@ -38,30 +38,30 @@ class ApiToolProviderController(ToolProviderController):
         if auth_type == ApiProviderAuthType.API_KEY:
         if auth_type == ApiProviderAuthType.API_KEY:
             credentials_schema = {
             credentials_schema = {
                 **credentials_schema,
                 **credentials_schema,
-                'api_key_header': ToolProviderCredentials(
+                'api_key_header': ProviderConfig(
                     name='api_key_header',
                     name='api_key_header',
                     required=False,
                     required=False,
                     default='api_key',
                     default='api_key',
-                    type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
+                    type=ProviderConfig.Type.TEXT_INPUT,
                     help=I18nObject(
                     help=I18nObject(
                         en_US='The header name of the api key',
                         en_US='The header name of the api key',
                         zh_Hans='携带 api key 的 header 名称'
                         zh_Hans='携带 api key 的 header 名称'
                     )
                     )
                 ),
                 ),
-                'api_key_value': ToolProviderCredentials(
+                'api_key_value': ProviderConfig(
                     name='api_key_value',
                     name='api_key_value',
                     required=True,
                     required=True,
-                    type=ToolProviderCredentials.CredentialsType.SECRET_INPUT,
+                    type=ProviderConfig.Type.SECRET_INPUT,
                     help=I18nObject(
                     help=I18nObject(
                         en_US='The api key',
                         en_US='The api key',
                         zh_Hans='api key的值'
                         zh_Hans='api key的值'
                     )
                     )
                 ),
                 ),
-                'api_key_header_prefix': ToolProviderCredentials(
+                'api_key_header_prefix': ProviderConfig(
                     name='api_key_header_prefix',
                     name='api_key_header_prefix',
                     required=False,
                     required=False,
                     default='basic',
                     default='basic',
-                    type=ToolProviderCredentials.CredentialsType.SELECT,
+                    type=ProviderConfig.Type.SELECT,
                     help=I18nObject(
                     help=I18nObject(
                         en_US='The prefix of the api key header',
                         en_US='The prefix of the api key header',
                         zh_Hans='api key header 的前缀'
                         zh_Hans='api key header 的前缀'

+ 0 - 115
api/core/tools/provider/app_tool_provider.py

@@ -1,115 +0,0 @@
-import logging
-from typing import Any
-
-from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ToolParameter, ToolParameterOption, ToolProviderType
-from core.tools.provider.tool_provider import ToolProviderController
-from core.tools.tool.tool import Tool
-from extensions.ext_database import db
-from models.model import App, AppModelConfig
-from models.tools import PublishedAppTool
-
-logger = logging.getLogger(__name__)
-
-class AppToolProviderEntity(ToolProviderController):
-    @property
-    def provider_type(self) -> ToolProviderType:
-        return ToolProviderType.APP
-    
-    def _validate_credentials(self, tool_name: str, credentials: dict[str, Any]) -> None:
-        pass
-
-    def validate_parameters(self, tool_name: str, tool_parameters: dict[str, Any]) -> None:
-        pass
-
-    def get_tools(self, user_id: str) -> list[Tool]:
-        db_tools: list[PublishedAppTool] = db.session.query(PublishedAppTool).filter(
-            PublishedAppTool.user_id == user_id,
-        ).all()
-
-        if not db_tools or len(db_tools) == 0:
-            return []
-
-        tools: list[Tool] = []
-
-        for db_tool in db_tools:
-            tool = {
-                'identity': {
-                    'author': db_tool.author,
-                    'name': db_tool.tool_name,
-                    'label': {
-                        'en_US': db_tool.tool_name,
-                        'zh_Hans': db_tool.tool_name
-                    },
-                    'icon': ''
-                },
-                'description': {
-                    'human': {
-                        'en_US': db_tool.description_i18n.en_US,
-                        'zh_Hans': db_tool.description_i18n.zh_Hans
-                    },
-                    'llm': db_tool.llm_description
-                },
-                'parameters': []
-            }
-            # get app from db
-            app: App = db_tool.app
-
-            if not app:
-                logger.error(f"app {db_tool.app_id} not found")
-                continue
-
-            app_model_config: AppModelConfig = app.app_model_config
-            user_input_form_list = app_model_config.user_input_form_list
-            for input_form in user_input_form_list:
-                # get type
-                form_type = input_form.keys()[0]
-                default = input_form[form_type]['default']
-                required = input_form[form_type]['required']
-                label = input_form[form_type]['label']
-                variable_name = input_form[form_type]['variable_name']
-                options = input_form[form_type].get('options', [])
-                if form_type == 'paragraph' or form_type == 'text-input':
-                    tool['parameters'].append(ToolParameter(
-                        name=variable_name,
-                        label=I18nObject(
-                            en_US=label,
-                            zh_Hans=label
-                        ),
-                        human_description=I18nObject(
-                            en_US=label,
-                            zh_Hans=label
-                        ),
-                        llm_description=label,
-                        form=ToolParameter.ToolParameterForm.FORM,
-                        type=ToolParameter.ToolParameterType.STRING,
-                        required=required,
-                        default=default
-                    ))
-                elif form_type == 'select':
-                    tool['parameters'].append(ToolParameter(
-                        name=variable_name,
-                        label=I18nObject(
-                            en_US=label,
-                            zh_Hans=label
-                        ),
-                        human_description=I18nObject(
-                            en_US=label,
-                            zh_Hans=label
-                        ),
-                        llm_description=label,
-                        form=ToolParameter.ToolParameterForm.FORM,
-                        type=ToolParameter.ToolParameterType.SELECT,
-                        required=required,
-                        default=default,
-                        options=[ToolParameterOption(
-                            value=option,
-                            label=I18nObject(
-                                en_US=option,
-                                zh_Hans=option
-                            )
-                        ) for option in options]
-                    ))
-
-            tools.append(Tool(**tool))
-        return tools

+ 10 - 82
api/core/tools/provider/builtin_tool_provider.py

@@ -2,22 +2,23 @@ from abc import abstractmethod
 from os import listdir, path
 from os import listdir, path
 from typing import Any
 from typing import Any
 
 
+from pydantic import Field
+
+from core.entities.provider_entities import ProviderConfig
 from core.helper.module_import_helper import load_single_subclass_from_source
 from core.helper.module_import_helper import load_single_subclass_from_source
-from core.tools.entities.tool_entities import ToolParameter, ToolProviderCredentials, ToolProviderType
+from core.tools.entities.tool_entities import ToolProviderType
 from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
 from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
 from core.tools.errors import (
 from core.tools.errors import (
-    ToolNotFoundError,
-    ToolParameterValidationError,
     ToolProviderNotFoundError,
     ToolProviderNotFoundError,
 )
 )
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.builtin_tool import BuiltinTool
-from core.tools.tool.tool import Tool
-from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from core.tools.utils.yaml_utils import load_yaml_file
 from core.tools.utils.yaml_utils import load_yaml_file
 
 
 
 
 class BuiltinToolProviderController(ToolProviderController):
 class BuiltinToolProviderController(ToolProviderController):
+    tools: list[BuiltinTool] = Field(default_factory=list)
+
     def __init__(self, **data: Any) -> None:
     def __init__(self, **data: Any) -> None:
         if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
         if self.provider_type == ToolProviderType.API or self.provider_type == ToolProviderType.APP:
             super().__init__(**data)
             super().__init__(**data)
@@ -41,7 +42,7 @@ class BuiltinToolProviderController(ToolProviderController):
             'credentials_schema': provider_yaml.get('credentials_for_provider', None),
             'credentials_schema': provider_yaml.get('credentials_for_provider', None),
         })
         })
 
 
-    def _get_builtin_tools(self) -> list[Tool]:
+    def _get_builtin_tools(self) -> list[BuiltinTool]:
         """
         """
             returns a list of tools that the provider can provide
             returns a list of tools that the provider can provide
 
 
@@ -72,7 +73,7 @@ class BuiltinToolProviderController(ToolProviderController):
         self.tools = tools
         self.tools = tools
         return tools
         return tools
     
     
-    def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
+    def get_credentials_schema(self) -> dict[str, ProviderConfig]:
         """
         """
             returns the credentials schema of the provider
             returns the credentials schema of the provider
 
 
@@ -83,7 +84,7 @@ class BuiltinToolProviderController(ToolProviderController):
         
         
         return self.credentials_schema.copy()
         return self.credentials_schema.copy()
 
 
-    def get_tools(self) -> list[Tool]:
+    def get_tools(self) -> list[BuiltinTool]:
         """
         """
             returns a list of tools that the provider can provide
             returns a list of tools that the provider can provide
 
 
@@ -91,24 +92,12 @@ class BuiltinToolProviderController(ToolProviderController):
         """
         """
         return self._get_builtin_tools()
         return self._get_builtin_tools()
     
     
-    def get_tool(self, tool_name: str) -> Tool:
+    def get_tool(self, tool_name: str) -> BuiltinTool | None:
         """
         """
             returns the tool that the provider can provide
             returns the tool that the provider can provide
         """
         """
         return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
         return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
 
 
-    def get_parameters(self, tool_name: str) -> list[ToolParameter]:
-        """
-            returns the parameters of the tool
-
-            :param tool_name: the name of the tool, defined in `get_tools`
-            :return: list of parameters
-        """
-        tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
-        if tool is None:
-            raise ToolNotFoundError(f'tool {tool_name} not found')
-        return tool.parameters
-
     @property
     @property
     def need_credentials(self) -> bool:
     def need_credentials(self) -> bool:
         """
         """
@@ -143,67 +132,6 @@ class BuiltinToolProviderController(ToolProviderController):
             returns the labels of the provider
             returns the labels of the provider
         """
         """
         return self.identity.tags or []
         return self.identity.tags or []
-
-    def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
-        """
-            validate the parameters of the tool and set the default value if needed
-
-            :param tool_name: the name of the tool, defined in `get_tools`
-            :param tool_parameters: the parameters of the tool
-        """
-        tool_parameters_schema = self.get_parameters(tool_name)
-        
-        tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
-        for parameter in tool_parameters_schema:
-            tool_parameters_need_to_validate[parameter.name] = parameter
-
-        for parameter in tool_parameters:
-            if parameter not in tool_parameters_need_to_validate:
-                raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
-            
-            # check type
-            parameter_schema = tool_parameters_need_to_validate[parameter]
-            if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
-                if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be string')
-            
-            elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
-                if not isinstance(tool_parameters[parameter], int | float):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be number')
-                
-                if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
-                
-                if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
-                
-            elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
-                if not isinstance(tool_parameters[parameter], bool):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
-                
-            elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
-                if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be string')
-                
-                options = parameter_schema.options
-                if not isinstance(options, list):
-                    raise ToolParameterValidationError(f'parameter {parameter} options should be list')
-                
-                if tool_parameters[parameter] not in [x.value for x in options]:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
-                
-            tool_parameters_need_to_validate.pop(parameter)
-
-        for parameter in tool_parameters_need_to_validate:
-            parameter_schema = tool_parameters_need_to_validate[parameter]
-            if parameter_schema.required:
-                raise ToolParameterValidationError(f'parameter {parameter} is required')
-            
-            # the parameter is not set currently, set the default value if needed
-            if parameter_schema.default is not None:
-                default_value = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
-                                                                              parameter_schema.type)
-                tool_parameters[parameter] = default_value
     
     
     def validate_credentials(self, credentials: dict[str, Any]) -> None:
     def validate_credentials(self, credentials: dict[str, Any]) -> None:
         """
         """

+ 15 - 98
api/core/tools/provider/tool_provider.py

@@ -1,25 +1,23 @@
 from abc import ABC, abstractmethod
 from abc import ABC, abstractmethod
-from typing import Any, Optional
+from typing import Any
 
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 
+from core.entities.provider_entities import ProviderConfig
 from core.tools.entities.tool_entities import (
 from core.tools.entities.tool_entities import (
-    ToolParameter,
-    ToolProviderCredentials,
     ToolProviderIdentity,
     ToolProviderIdentity,
     ToolProviderType,
     ToolProviderType,
 )
 )
-from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
+from core.tools.errors import ToolProviderCredentialValidationError
 from core.tools.tool.tool import Tool
 from core.tools.tool.tool import Tool
-from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 
 
 
 
 class ToolProviderController(BaseModel, ABC):
 class ToolProviderController(BaseModel, ABC):
-    identity: Optional[ToolProviderIdentity] = None
-    tools: Optional[list[Tool]] = None
-    credentials_schema: Optional[dict[str, ToolProviderCredentials]] = None
+    identity: ToolProviderIdentity
+    tools: list[Tool] = Field(default_factory=list)
+    credentials_schema: dict[str, ProviderConfig] = Field(default_factory=dict)
 
 
-    def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
+    def get_credentials_schema(self) -> dict[str, ProviderConfig]:
         """
         """
             returns the credentials schema of the provider
             returns the credentials schema of the provider
 
 
@@ -28,15 +26,6 @@ class ToolProviderController(BaseModel, ABC):
         return self.credentials_schema.copy()
         return self.credentials_schema.copy()
     
     
     @abstractmethod
     @abstractmethod
-    def get_tools(self) -> list[Tool]:
-        """
-            returns a list of tools that the provider can provide
-
-            :return: list of tools
-        """
-        pass
-
-    @abstractmethod
     def get_tool(self, tool_name: str) -> Tool:
     def get_tool(self, tool_name: str) -> Tool:
         """
         """
             returns a tool that the provider can provide
             returns a tool that the provider can provide
@@ -45,18 +34,6 @@ class ToolProviderController(BaseModel, ABC):
         """
         """
         pass
         pass
 
 
-    def get_parameters(self, tool_name: str) -> list[ToolParameter]:
-        """
-            returns the parameters of the tool
-
-            :param tool_name: the name of the tool, defined in `get_tools`
-            :return: list of parameters
-        """
-        tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
-        if tool is None:
-            raise ToolNotFoundError(f'tool {tool_name} not found')
-        return tool.parameters
-
     @property
     @property
     def provider_type(self) -> ToolProviderType:
     def provider_type(self) -> ToolProviderType:
         """
         """
@@ -66,66 +43,6 @@ class ToolProviderController(BaseModel, ABC):
         """
         """
         return ToolProviderType.BUILT_IN
         return ToolProviderType.BUILT_IN
 
 
-    def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
-        """
-            validate the parameters of the tool and set the default value if needed
-
-            :param tool_name: the name of the tool, defined in `get_tools`
-            :param tool_parameters: the parameters of the tool
-        """
-        tool_parameters_schema = self.get_parameters(tool_name)
-        
-        tool_parameters_need_to_validate: dict[str, ToolParameter] = {}
-        for parameter in tool_parameters_schema:
-            tool_parameters_need_to_validate[parameter.name] = parameter
-
-        for parameter in tool_parameters:
-            if parameter not in tool_parameters_need_to_validate:
-                raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')
-            
-            # check type
-            parameter_schema = tool_parameters_need_to_validate[parameter]
-            if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
-                if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be string')
-            
-            elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
-                if not isinstance(tool_parameters[parameter], int | float):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be number')
-                
-                if parameter_schema.min is not None and tool_parameters[parameter] < parameter_schema.min:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')
-                
-                if parameter_schema.max is not None and tool_parameters[parameter] > parameter_schema.max:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')
-                
-            elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
-                if not isinstance(tool_parameters[parameter], bool):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be boolean')
-                
-            elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
-                if not isinstance(tool_parameters[parameter], str):
-                    raise ToolParameterValidationError(f'parameter {parameter} should be string')
-                
-                options = parameter_schema.options
-                if not isinstance(options, list):
-                    raise ToolParameterValidationError(f'parameter {parameter} options should be list')
-                
-                if tool_parameters[parameter] not in [x.value for x in options]:
-                    raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')
-                
-            tool_parameters_need_to_validate.pop(parameter)
-
-        for parameter in tool_parameters_need_to_validate:
-            parameter_schema = tool_parameters_need_to_validate[parameter]
-            if parameter_schema.required:
-                raise ToolParameterValidationError(f'parameter {parameter} is required')
-            
-            # the parameter is not set currently, set the default value if needed
-            if parameter_schema.default is not None:
-                tool_parameters[parameter] = ToolParameterConverter.cast_parameter_by_type(parameter_schema.default,
-                                                                                           parameter_schema.type)
-
     def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
     def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
         """
         """
             validate the format of the credentials of the provider and set the default value if needed
             validate the format of the credentials of the provider and set the default value if needed
@@ -136,7 +53,7 @@ class ToolProviderController(BaseModel, ABC):
         if credentials_schema is None:
         if credentials_schema is None:
             return
             return
         
         
-        credentials_need_to_validate: dict[str, ToolProviderCredentials] = {}
+        credentials_need_to_validate: dict[str, ProviderConfig] = {}
         for credential_name in credentials_schema:
         for credential_name in credentials_schema:
             credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
             credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
 
 
@@ -146,12 +63,12 @@ class ToolProviderController(BaseModel, ABC):
             
             
             # check type
             # check type
             credential_schema = credentials_need_to_validate[credential_name]
             credential_schema = credentials_need_to_validate[credential_name]
-            if credential_schema == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
-                credential_schema == ToolProviderCredentials.CredentialsType.TEXT_INPUT:
+            if credential_schema == ProviderConfig.Type.SECRET_INPUT or \
+                credential_schema == ProviderConfig.Type.TEXT_INPUT:
                 if not isinstance(credentials[credential_name], str):
                 if not isinstance(credentials[credential_name], str):
                     raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
                     raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
             
             
-            elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
+            elif credential_schema.type == ProviderConfig.Type.SELECT:
                 if not isinstance(credentials[credential_name], str):
                 if not isinstance(credentials[credential_name], str):
                     raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
                     raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')
                 
                 
@@ -173,9 +90,9 @@ class ToolProviderController(BaseModel, ABC):
             if credential_schema.default is not None:
             if credential_schema.default is not None:
                 default_value = credential_schema.default
                 default_value = credential_schema.default
                 # parse default value into the correct type
                 # parse default value into the correct type
-                if credential_schema.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT or \
-                    credential_schema.type == ToolProviderCredentials.CredentialsType.TEXT_INPUT or \
-                    credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
+                if credential_schema.type == ProviderConfig.Type.SECRET_INPUT or \
+                    credential_schema.type == ProviderConfig.Type.TEXT_INPUT or \
+                    credential_schema.type == ProviderConfig.Type.SELECT:
                     default_value = str(default_value)
                     default_value = str(default_value)
 
 
                 credentials[credential_name] = default_value
                 credentials[credential_name] = default_value

+ 12 - 7
api/core/tools/provider/workflow_tool_provider.py

@@ -1,5 +1,8 @@
+from collections.abc import Mapping
 from typing import Optional
 from typing import Optional
 
 
+from pydantic import Field
+
 from core.app.app_config.entities import VariableEntity, VariableEntityType
 from core.app.app_config.entities import VariableEntity, VariableEntityType
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.common_entities import I18nObject
@@ -28,6 +31,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = {
 
 
 class WorkflowToolProviderController(ToolProviderController):
 class WorkflowToolProviderController(ToolProviderController):
     provider_id: str
     provider_id: str
+    tools: list[WorkflowTool] = Field(default_factory=list)
 
 
     @classmethod
     @classmethod
     def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
     def from_db(cls, db_provider: WorkflowToolProvider) -> 'WorkflowToolProviderController':
@@ -71,16 +75,17 @@ class WorkflowToolProviderController(ToolProviderController):
             :param app: the app
             :param app: the app
             :return: the tool
             :return: the tool
         """
         """
-        workflow: Workflow = db.session.query(Workflow).filter(
+        workflow: Workflow | None = db.session.query(Workflow).filter(
             Workflow.app_id == db_provider.app_id,
             Workflow.app_id == db_provider.app_id,
             Workflow.version == db_provider.version
             Workflow.version == db_provider.version
         ).first()
         ).first()
+
         if not workflow:
         if not workflow:
             raise ValueError('workflow not found')
             raise ValueError('workflow not found')
 
 
         # fetch start node
         # fetch start node
-        graph: dict = workflow.graph_dict
-        features_dict: dict = workflow.features_dict
+        graph: Mapping = workflow.graph_dict
+        features_dict: Mapping = workflow.features_dict
         features = WorkflowAppConfigManager.convert_features(
         features = WorkflowAppConfigManager.convert_features(
             config_dict=features_dict,
             config_dict=features_dict,
             app_mode=AppMode.WORKFLOW
             app_mode=AppMode.WORKFLOW
@@ -89,7 +94,7 @@ class WorkflowToolProviderController(ToolProviderController):
         parameters = db_provider.parameter_configurations
         parameters = db_provider.parameter_configurations
         variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
         variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph)
 
 
-        def fetch_workflow_variable(variable_name: str) -> VariableEntity:
+        def fetch_workflow_variable(variable_name: str) -> VariableEntity | None:
             return next(filter(lambda x: x.variable == variable_name, variables), None)
             return next(filter(lambda x: x.variable == variable_name, variables), None)
 
 
         user = db_provider.user
         user = db_provider.user
@@ -99,7 +104,7 @@ class WorkflowToolProviderController(ToolProviderController):
             variable = fetch_workflow_variable(parameter.name)
             variable = fetch_workflow_variable(parameter.name)
             if variable:
             if variable:
                 parameter_type = None
                 parameter_type = None
-                options = None
+                options = []
                 if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
                 if variable.type not in VARIABLE_TO_PARAMETER_TYPE_MAPPING:
                     raise ValueError(f'unsupported variable type {variable.type}')
                     raise ValueError(f'unsupported variable type {variable.type}')
                 parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
                 parameter_type = VARIABLE_TO_PARAMETER_TYPE_MAPPING[variable.type]
@@ -185,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
             label=db_provider.label
             label=db_provider.label
         )
         )
 
 
-    def get_tools(self, user_id: str, tenant_id: str) -> list[WorkflowTool]:
+    def get_tools(self, tenant_id: str) -> list[WorkflowTool]:
         """
         """
             fetch tools from database
             fetch tools from database
 
 
@@ -196,7 +201,7 @@ class WorkflowToolProviderController(ToolProviderController):
         if self.tools is not None:
         if self.tools is not None:
             return self.tools
             return self.tools
 
 
-        db_providers: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+        db_providers: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
             WorkflowToolProvider.tenant_id == tenant_id,
             WorkflowToolProvider.tenant_id == tenant_id,
             WorkflowToolProvider.app_id == self.provider_id,
             WorkflowToolProvider.app_id == self.provider_id,
         ).first()
         ).first()

+ 1 - 1
api/core/tools/tool/tool.py

@@ -55,7 +55,7 @@ class Tool(BaseModel, ABC):
         invoke_from: Optional[InvokeFrom] = None
         invoke_from: Optional[InvokeFrom] = None
         tool_invoke_from: Optional[ToolInvokeFrom] = None
         tool_invoke_from: Optional[ToolInvokeFrom] = None
         credentials: Optional[dict[str, Any]] = None
         credentials: Optional[dict[str, Any]] = None
-        runtime_parameters: Optional[dict[str, Any]] = None
+        runtime_parameters: dict[str, Any] = Field(default_factory=dict)
 
 
     runtime: Optional[Runtime] = None
     runtime: Optional[Runtime] = None
     variables: Optional[ToolRuntimeVariablePool] = None
     variables: Optional[ToolRuntimeVariablePool] = None

+ 51 - 38
api/core/tools/tool_manager.py

@@ -4,7 +4,7 @@ import mimetypes
 from collections.abc import Generator
 from collections.abc import Generator
 from os import listdir, path
 from os import listdir, path
 from threading import Lock
 from threading import Lock
-from typing import Any, Union
+from typing import Any, Union, cast
 
 
 from configs import dify_config
 from configs import dify_config
 from core.agent.entities import AgentToolEntity
 from core.agent.entities import AgentToolEntity
@@ -22,6 +22,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl
 from core.tools.tool.api_tool import ApiTool
 from core.tools.tool.api_tool import ApiTool
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.tool import Tool
 from core.tools.tool.tool import Tool
+from core.tools.tool.workflow_tool import WorkflowTool
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
 from core.tools.utils.configuration import ToolConfigurationManager, ToolParameterConfigurationManager
 from core.tools.utils.tool_parameter_converter import ToolParameterConverter
 from core.tools.utils.tool_parameter_converter import ToolParameterConverter
@@ -57,7 +58,7 @@ class ToolManager:
         return cls._builtin_providers[provider]
         return cls._builtin_providers[provider]
 
 
     @classmethod
     @classmethod
-    def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool:
+    def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None:
         """
         """
             get the builtin tool
             get the builtin tool
 
 
@@ -78,7 +79,7 @@ class ToolManager:
                          tenant_id: str,
                          tenant_id: str,
                          invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
                          invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
                          tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
                          tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT) \
-            -> Union[BuiltinTool, ApiTool]:
+            -> Union[BuiltinTool, ApiTool, WorkflowTool]:
         """
         """
             get the tool runtime
             get the tool runtime
 
 
@@ -90,19 +91,21 @@ class ToolManager:
         """
         """
         if provider_type == ToolProviderType.BUILT_IN:
         if provider_type == ToolProviderType.BUILT_IN:
             builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
             builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
+            if not builtin_tool:
+                raise ValueError(f"tool {tool_name} not found")
 
 
             # check if the builtin tool need credentials
             # check if the builtin tool need credentials
             provider_controller = cls.get_builtin_provider(provider_id)
             provider_controller = cls.get_builtin_provider(provider_id)
             if not provider_controller.need_credentials:
             if not provider_controller.need_credentials:
-                return builtin_tool.fork_tool_runtime(runtime={
+                return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
                     'tenant_id': tenant_id,
                     'tenant_id': tenant_id,
                     'credentials': {},
                     'credentials': {},
                     'invoke_from': invoke_from,
                     'invoke_from': invoke_from,
                     'tool_invoke_from': tool_invoke_from,
                     'tool_invoke_from': tool_invoke_from,
-                })
+                }))
 
 
             # get credentials
             # get credentials
-            builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
+            builtin_provider: BuiltinToolProvider | None = db.session.query(BuiltinToolProvider).filter(
                 BuiltinToolProvider.tenant_id == tenant_id,
                 BuiltinToolProvider.tenant_id == tenant_id,
                 BuiltinToolProvider.provider == provider_id,
                 BuiltinToolProvider.provider == provider_id,
             ).first()
             ).first()
@@ -117,13 +120,13 @@ class ToolManager:
 
 
             decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
             decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
 
 
-            return builtin_tool.fork_tool_runtime(runtime={
+            return cast(BuiltinTool, builtin_tool.fork_tool_runtime(runtime={
                 'tenant_id': tenant_id,
                 'tenant_id': tenant_id,
                 'credentials': decrypted_credentials,
                 'credentials': decrypted_credentials,
                 'runtime_parameters': {},
                 'runtime_parameters': {},
                 'invoke_from': invoke_from,
                 'invoke_from': invoke_from,
                 'tool_invoke_from': tool_invoke_from,
                 'tool_invoke_from': tool_invoke_from,
-            })
+            }))
 
 
         elif provider_type == ToolProviderType.API:
         elif provider_type == ToolProviderType.API:
             if tenant_id is None:
             if tenant_id is None:
@@ -135,12 +138,12 @@ class ToolManager:
             tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
             tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=api_provider)
             decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
             decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials)
 
 
-            return api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
+            return cast(ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime(runtime={
                 'tenant_id': tenant_id,
                 'tenant_id': tenant_id,
                 'credentials': decrypted_credentials,
                 'credentials': decrypted_credentials,
                 'invoke_from': invoke_from,
                 'invoke_from': invoke_from,
                 'tool_invoke_from': tool_invoke_from,
                 'tool_invoke_from': tool_invoke_from,
-            })
+            }))
         elif provider_type == ToolProviderType.WORKFLOW:
         elif provider_type == ToolProviderType.WORKFLOW:
             workflow_provider = db.session.query(WorkflowToolProvider).filter(
             workflow_provider = db.session.query(WorkflowToolProvider).filter(
                 WorkflowToolProvider.tenant_id == tenant_id,
                 WorkflowToolProvider.tenant_id == tenant_id,
@@ -154,12 +157,12 @@ class ToolManager:
                 db_provider=workflow_provider
                 db_provider=workflow_provider
             )
             )
 
 
-            return controller.get_tools(user_id=None, tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
+            return cast(WorkflowTool, controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime(runtime={
                 'tenant_id': tenant_id,
                 'tenant_id': tenant_id,
                 'credentials': {},
                 'credentials': {},
                 'invoke_from': invoke_from,
                 'invoke_from': invoke_from,
                 'tool_invoke_from': tool_invoke_from,
                 'tool_invoke_from': tool_invoke_from,
-            })
+            }))
         elif provider_type == ToolProviderType.APP:
         elif provider_type == ToolProviderType.APP:
             raise NotImplementedError('app provider not implemented')
             raise NotImplementedError('app provider not implemented')
         else:
         else:
@@ -220,7 +223,10 @@ class ToolManager:
             identity_id=f'AGENT.{app_id}'
             identity_id=f'AGENT.{app_id}'
         )
         )
         runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
         runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
-
+        
+        if not tool_entity.runtime:
+            raise Exception("tool missing runtime")
+        
         tool_entity.runtime.runtime_parameters.update(runtime_parameters)
         tool_entity.runtime.runtime_parameters.update(runtime_parameters)
         return tool_entity
         return tool_entity
 
 
@@ -258,6 +264,9 @@ class ToolManager:
         if runtime_parameters:
         if runtime_parameters:
             runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
             runtime_parameters = encryption_manager.decrypt_tool_parameters(runtime_parameters)
 
 
+        if not tool_entity.runtime:
+            raise Exception("tool missing runtime")
+        
         tool_entity.runtime.runtime_parameters.update(runtime_parameters)
         tool_entity.runtime.runtime_parameters.update(runtime_parameters)
         return tool_entity
         return tool_entity
 
 
@@ -304,20 +313,20 @@ class ToolManager:
         """
         """
             list all the builtin providers
             list all the builtin providers
         """
         """
-        for provider in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
-            if provider.startswith('__'):
+        for provider_path in listdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin')):
+            if provider_path.startswith('__'):
                 continue
                 continue
 
 
-            if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider)):
-                if provider.startswith('__'):
+            if path.isdir(path.join(path.dirname(path.realpath(__file__)), 'provider', 'builtin', provider_path)):
+                if provider_path.startswith('__'):
                     continue
                     continue
 
 
                 # init provider
                 # init provider
                 try:
                 try:
                     provider_class = load_single_subclass_from_source(
                     provider_class = load_single_subclass_from_source(
-                        module_name=f'core.tools.provider.builtin.{provider}.{provider}',
+                        module_name=f'core.tools.provider.builtin.{provider_path}.{provider_path}',
                         script_path=path.join(path.dirname(path.realpath(__file__)),
                         script_path=path.join(path.dirname(path.realpath(__file__)),
-                                              'provider', 'builtin', provider, f'{provider}.py'),
+                                              'provider', 'builtin', provider_path, f'{provider_path}.py'),
                         parent_type=BuiltinToolProviderController)
                         parent_type=BuiltinToolProviderController)
                     provider: BuiltinToolProviderController = provider_class()
                     provider: BuiltinToolProviderController = provider_class()
                     cls._builtin_providers[provider.identity.name] = provider
                     cls._builtin_providers[provider.identity.name] = provider
@@ -387,8 +396,8 @@ class ToolManager:
             for provider in builtin_providers:
             for provider in builtin_providers:
                 # handle include, exclude
                 # handle include, exclude
                 if is_filtered(
                 if is_filtered(
-                        include_set=dify_config.POSITION_TOOL_INCLUDES_SET,
-                        exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET,
+                        include_set=dify_config.POSITION_TOOL_INCLUDES_SET, # type: ignore
+                        exclude_set=dify_config.POSITION_TOOL_EXCLUDES_SET, # type: ignore
                         data=provider,
                         data=provider,
                         name_func=lambda x: x.identity.name
                         name_func=lambda x: x.identity.name
                 ):
                 ):
@@ -461,7 +470,7 @@ class ToolManager:
 
 
             :return: the provider controller, the credentials
             :return: the provider controller, the credentials
         """
         """
-        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
+        provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter(
             ApiToolProvider.id == provider_id,
             ApiToolProvider.id == provider_id,
             ApiToolProvider.tenant_id == tenant_id,
             ApiToolProvider.tenant_id == tenant_id,
         ).first()
         ).first()
@@ -486,22 +495,22 @@ class ToolManager:
         """
         """
             get tool provider
             get tool provider
         """
         """
-        provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
+        provider_obj: ApiToolProvider| None = db.session.query(ApiToolProvider).filter(
             ApiToolProvider.tenant_id == tenant_id,
             ApiToolProvider.tenant_id == tenant_id,
             ApiToolProvider.name == provider,
             ApiToolProvider.name == provider,
         ).first()
         ).first()
 
 
-        if provider is None:
+        if provider_obj is None:
             raise ValueError(f'you have not added provider {provider}')
             raise ValueError(f'you have not added provider {provider}')
 
 
         try:
         try:
-            credentials = json.loads(provider.credentials_str) or {}
+            credentials = json.loads(provider_obj.credentials_str) or {}
         except:
         except:
             credentials = {}
             credentials = {}
 
 
         # package tool provider controller
         # package tool provider controller
         controller = ApiToolProviderController.from_db(
         controller = ApiToolProviderController.from_db(
-            provider, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
+            provider_obj, ApiProviderAuthType.API_KEY if credentials['auth_type'] == 'api_key' else ApiProviderAuthType.NONE
         )
         )
         # init tool configuration
         # init tool configuration
         tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
         tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=controller)
@@ -510,7 +519,7 @@ class ToolManager:
         masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
         masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
 
 
         try:
         try:
-            icon = json.loads(provider.icon)
+            icon = json.loads(provider_obj.icon)
         except:
         except:
             icon = {
             icon = {
                 "background": "#252525",
                 "background": "#252525",
@@ -521,14 +530,14 @@ class ToolManager:
         labels = ToolLabelManager.get_tool_labels(controller)
         labels = ToolLabelManager.get_tool_labels(controller)
 
 
         return jsonable_encoder({
         return jsonable_encoder({
-            'schema_type': provider.schema_type,
-            'schema': provider.schema,
-            'tools': provider.tools,
+            'schema_type': provider_obj.schema_type,
+            'schema': provider_obj.schema,
+            'tools': provider_obj.tools,
             'icon': icon,
             'icon': icon,
-            'description': provider.description,
+            'description': provider_obj.description,
             'credentials': masked_credentials,
             'credentials': masked_credentials,
-            'privacy_policy': provider.privacy_policy,
-            'custom_disclaimer': provider.custom_disclaimer,
+            'privacy_policy': provider_obj.privacy_policy,
+            'custom_disclaimer': provider_obj.custom_disclaimer,
             'labels': labels,
             'labels': labels,
         })
         })
 
 
@@ -551,25 +560,29 @@ class ToolManager:
                     + "/icon")
                     + "/icon")
         elif provider_type == ToolProviderType.API:
         elif provider_type == ToolProviderType.API:
             try:
             try:
-                provider: ApiToolProvider = db.session.query(ApiToolProvider).filter(
+                api_provider: ApiToolProvider | None = db.session.query(ApiToolProvider).filter(
                     ApiToolProvider.tenant_id == tenant_id,
                     ApiToolProvider.tenant_id == tenant_id,
                     ApiToolProvider.id == provider_id
                     ApiToolProvider.id == provider_id
                 ).first()
                 ).first()
-                return json.loads(provider.icon)
+                if not api_provider:
+                    raise ValueError("api tool not found")
+                
+                return json.loads(api_provider.icon)
             except:
             except:
                 return {
                 return {
                     "background": "#252525",
                     "background": "#252525",
                     "content": "\ud83d\ude01"
                     "content": "\ud83d\ude01"
                 }
                 }
         elif provider_type == ToolProviderType.WORKFLOW:
         elif provider_type == ToolProviderType.WORKFLOW:
-            provider: WorkflowToolProvider = db.session.query(WorkflowToolProvider).filter(
+            workflow_provider: WorkflowToolProvider | None = db.session.query(WorkflowToolProvider).filter(
                 WorkflowToolProvider.tenant_id == tenant_id,
                 WorkflowToolProvider.tenant_id == tenant_id,
                 WorkflowToolProvider.id == provider_id
                 WorkflowToolProvider.id == provider_id
             ).first()
             ).first()
-            if provider is None:
+
+            if workflow_provider is None:
                 raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
                 raise ToolProviderNotFoundError(f'workflow provider {provider_id} not found')
 
 
-            return json.loads(provider.icon)
+            return json.loads(workflow_provider.icon)
         else:
         else:
             raise ValueError(f"provider type {provider_type} not found")
             raise ValueError(f"provider type {provider_type} not found")
 
 

+ 4 - 4
api/core/tools/utils/configuration.py

@@ -7,8 +7,8 @@ from core.helper import encrypter
 from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
 from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
 from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
 from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
 from core.tools.entities.tool_entities import (
 from core.tools.entities.tool_entities import (
+    ProviderConfig,
     ToolParameter,
     ToolParameter,
-    ToolProviderCredentials,
     ToolProviderType,
     ToolProviderType,
 )
 )
 from core.tools.provider.tool_provider import ToolProviderController
 from core.tools.provider.tool_provider import ToolProviderController
@@ -36,7 +36,7 @@ class ToolConfigurationManager(BaseModel):
         # get fields need to be decrypted
         # get fields need to be decrypted
         fields = self.provider_controller.get_credentials_schema()
         fields = self.provider_controller.get_credentials_schema()
         for field_name, field in fields.items():
         for field_name, field in fields.items():
-            if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
+            if field.type == ProviderConfig.Type.SECRET_INPUT:
                 if field_name in credentials:
                 if field_name in credentials:
                     encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
                     encrypted = encrypter.encrypt_token(self.tenant_id, credentials[field_name])
                     credentials[field_name] = encrypted
                     credentials[field_name] = encrypted
@@ -54,7 +54,7 @@ class ToolConfigurationManager(BaseModel):
         # get fields need to be decrypted
         # get fields need to be decrypted
         fields = self.provider_controller.get_credentials_schema()
         fields = self.provider_controller.get_credentials_schema()
         for field_name, field in fields.items():
         for field_name, field in fields.items():
-            if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
+            if field.type == ProviderConfig.Type.SECRET_INPUT:
                 if field_name in credentials:
                 if field_name in credentials:
                     if len(credentials[field_name]) > 6:
                     if len(credentials[field_name]) > 6:
                         credentials[field_name] = \
                         credentials[field_name] = \
@@ -84,7 +84,7 @@ class ToolConfigurationManager(BaseModel):
         # get fields need to be decrypted
         # get fields need to be decrypted
         fields = self.provider_controller.get_credentials_schema()
         fields = self.provider_controller.get_credentials_schema()
         for field_name, field in fields.items():
         for field_name, field in fields.items():
-            if field.type == ToolProviderCredentials.CredentialsType.SECRET_INPUT:
+            if field.type == ProviderConfig.Type.SECRET_INPUT:
                 if field_name in credentials:
                 if field_name in credentials:
                     try:
                     try:
                         credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])
                         credentials[field_name] = encrypter.decrypt_token(self.tenant_id, credentials[field_name])

+ 3 - 3
api/core/tools/utils/workflow_configuration_sync.py

@@ -1,3 +1,5 @@
+from collections.abc import Mapping
+
 from core.app.app_config.entities import VariableEntity
 from core.app.app_config.entities import VariableEntity
 from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
 from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
 
 
@@ -13,7 +15,7 @@ class WorkflowToolConfigurationUtils:
                 raise ValueError('invalid parameter configuration')
                 raise ValueError('invalid parameter configuration')
 
 
     @classmethod
     @classmethod
-    def get_workflow_graph_variables(cls, graph: dict) -> list[VariableEntity]:
+    def get_workflow_graph_variables(cls, graph: Mapping) -> list[VariableEntity]:
         """
         """
         get workflow graph variables
         get workflow graph variables
         """
         """
@@ -44,5 +46,3 @@ class WorkflowToolConfigurationUtils:
         for parameter in tool_configurations:
         for parameter in tool_configurations:
             if parameter.name not in variable_names:
             if parameter.name not in variable_names:
                 raise ValueError('parameter configuration mismatch, please republish the tool to update')
                 raise ValueError('parameter configuration mismatch, please republish the tool to update')
-
-        return True

+ 7 - 7
api/services/tools/api_tools_manage_service.py

@@ -10,8 +10,8 @@ from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ApiProviderAuthType,
     ApiProviderSchemaType,
     ApiProviderSchemaType,
+    ProviderConfig,
     ToolCredentialsOption,
     ToolCredentialsOption,
-    ToolProviderCredentials,
 )
 )
 from core.tools.provider.api_tool_provider import ApiToolProviderController
 from core.tools.provider.api_tool_provider import ApiToolProviderController
 from core.tools.tool_label_manager import ToolLabelManager
 from core.tools.tool_label_manager import ToolLabelManager
@@ -39,9 +39,9 @@ class ApiToolManageService:
                 raise ValueError(f"invalid schema: {str(e)}")
                 raise ValueError(f"invalid schema: {str(e)}")
 
 
             credentials_schema = [
             credentials_schema = [
-                ToolProviderCredentials(
+                ProviderConfig(
                     name="auth_type",
                     name="auth_type",
-                    type=ToolProviderCredentials.CredentialsType.SELECT,
+                    type=ProviderConfig.Type.SELECT,
                     required=True,
                     required=True,
                     default="none",
                     default="none",
                     options=[
                     options=[
@@ -50,17 +50,17 @@ class ApiToolManageService:
                     ],
                     ],
                     placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
                     placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"),
                 ),
                 ),
-                ToolProviderCredentials(
+                ProviderConfig(
                     name="api_key_header",
                     name="api_key_header",
-                    type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
+                    type=ProviderConfig.Type.TEXT_INPUT,
                     required=False,
                     required=False,
                     placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
                     placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"),
                     default="api_key",
                     default="api_key",
                     help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
                     help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"),
                 ),
                 ),
-                ToolProviderCredentials(
+                ProviderConfig(
                     name="api_key_value",
                     name="api_key_value",
-                    type=ToolProviderCredentials.CredentialsType.TEXT_INPUT,
+                    type=ProviderConfig.Type.TEXT_INPUT,
                     required=False,
                     required=False,
                     placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
                     placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"),
                     default="",
                     default="",

+ 2 - 2
api/services/tools/tools_transform_service.py

@@ -8,8 +8,8 @@ from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import (
 from core.tools.entities.tool_entities import (
     ApiProviderAuthType,
     ApiProviderAuthType,
+    ProviderConfig,
     ToolParameter,
     ToolParameter,
-    ToolProviderCredentials,
     ToolProviderType,
     ToolProviderType,
 )
 )
 from core.tools.provider.api_tool_provider import ApiToolProviderController
 from core.tools.provider.api_tool_provider import ApiToolProviderController
@@ -92,7 +92,7 @@ class ToolTransformService:
         # get credentials schema
         # get credentials schema
         schema = provider_controller.get_credentials_schema()
         schema = provider_controller.get_credentials_schema()
         for name, value in schema.items():
         for name, value in schema.items():
-            result.masked_credentials[name] = ToolProviderCredentials.CredentialsType.default(value.type)
+            result.masked_credentials[name] = ProviderConfig.Type.default(value.type)
 
 
         # check if the provider need credentials
         # check if the provider need credentials
         if not provider_controller.need_credentials:
         if not provider_controller.need_credentials: