Parcourir la source

refactor: using DeclarativeBase as parent class of models, refactored tools

Yeuoly il y a 8 mois
Parent
commit
e9e5c8806a

+ 2 - 2
api/controllers/console/setup.py

@@ -6,7 +6,7 @@ from flask_restful import Resource, reqparse
 from configs import dify_config
 from libs.helper import StrLen, email, get_remote_ip
 from libs.password import valid_password
-from models.model import DifySetup
+from models.model import DifySetup, db
 from services.account_service import RegisterService, TenantService
 
 from . import api
@@ -69,7 +69,7 @@ def setup_required(view):
 
 def get_setup_status():
     if dify_config.EDITION == "SELF_HOSTED":
-        return DifySetup.query.first()
+        return db.session.query(DifySetup).first()
     else:
         return True
 

+ 7 - 6
api/controllers/console/workspace/tool_providers.py

@@ -610,16 +610,17 @@ class ToolLabelsApi(Resource):
 api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
 
 # builtin tool provider
-api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<provider>/tools")
-api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<provider>/delete")
-api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<provider>/update")
+api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
+api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
+api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
 api.add_resource(
-    ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials"
+    ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
 )
 api.add_resource(
-    ToolBuiltinProviderCredentialsSchemaApi, "/workspaces/current/tool-provider/builtin/<provider>/credentials_schema"
+    ToolBuiltinProviderCredentialsSchemaApi,
+    "/workspaces/current/tool-provider/builtin/<path:provider>/credentials_schema",
 )
-api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<provider>/icon")
+api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
 
 # api tool provider
 api.add_resource(ToolApiProviderAddApi, "/workspaces/current/tool-provider/api/add")

+ 18 - 3
api/core/plugin/manager/tool.py

@@ -14,9 +14,9 @@ class PluginToolManager(BasePluginManager):
         provider follows format: plugin_id/provider_name
         """
         if "/" in provider:
-            parts = provider.split("/", 1)
-            if len(parts) == 2:
-                return parts[0], parts[1]
+            parts = provider.split("/", -1)
+            if len(parts) >= 2:
+                return "/".join(parts[:-1]), parts[-1]
             raise ValueError(f"invalid provider format: {provider}")
 
         raise ValueError(f"invalid provider format: {provider}")
@@ -46,6 +46,10 @@ class PluginToolManager(BasePluginManager):
         for provider in response:
             provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
 
+            # override the provider name for each tool to plugin_id/provider_name
+            for tool in provider.declaration.tools:
+                tool.identity.provider = provider.declaration.identity.name
+
         return response
 
     def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
@@ -54,15 +58,26 @@ class PluginToolManager(BasePluginManager):
         """
         plugin_id, provider_name = self._split_provider(provider)
 
+        def transformer(json_response: dict[str, Any]) -> dict:
+            for tool in json_response.get("data", {}).get("declaration", {}).get("tools", []):
+                tool["identity"]["provider"] = provider_name
+
+            return json_response
+
         response = self._request_with_plugin_daemon_response(
             "GET",
             f"plugin/{tenant_id}/management/tool",
             PluginToolProviderEntity,
             params={"provider": provider_name, "plugin_id": plugin_id},
+            transformer=transformer,
         )
 
         response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
 
+        # override the provider name for each tool to plugin_id/provider_name
+        for tool in response.declaration.tools:
+            tool.identity.provider = response.declaration.identity.name
+
         return response
 
     def invoke(

+ 1 - 6
api/core/tools/plugin_tool/provider.py

@@ -11,12 +11,10 @@ from core.tools.plugin_tool.tool import PluginTool
 class PluginToolProviderController(BuiltinToolProviderController):
     entity: ToolProviderEntityWithPlugin
     tenant_id: str
-    plugin_id: str
 
-    def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str, plugin_id: str) -> None:
+    def __init__(self, entity: ToolProviderEntityWithPlugin, tenant_id: str) -> None:
         self.entity = entity
         self.tenant_id = tenant_id
-        self.plugin_id = plugin_id
 
     @property
     def provider_type(self) -> ToolProviderType:
@@ -35,7 +33,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
         if not manager.validate_provider_credentials(
             tenant_id=self.tenant_id,
             user_id=user_id,
-            plugin_id=self.plugin_id,
             provider=self.entity.identity.name,
             credentials=credentials,
         ):
@@ -54,7 +51,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
             entity=tool_entity,
             runtime=ToolRuntime(tenant_id=self.tenant_id),
             tenant_id=self.tenant_id,
-            plugin_id=self.plugin_id,
         )
 
     def get_tools(self) -> list[PluginTool]:
@@ -66,7 +62,6 @@ class PluginToolProviderController(BuiltinToolProviderController):
                 entity=tool_entity,
                 runtime=ToolRuntime(tenant_id=self.tenant_id),
                 tenant_id=self.tenant_id,
-                plugin_id=self.plugin_id,
             )
             for tool_entity in self.entity.tools
         ]

+ 1 - 5
api/core/tools/plugin_tool/tool.py

@@ -9,12 +9,10 @@ from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, Too
 
 class PluginTool(Tool):
     tenant_id: str
-    plugin_id: str
 
-    def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, plugin_id: str) -> None:
+    def __init__(self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str) -> None:
         super().__init__(entity, runtime)
         self.tenant_id = tenant_id
-        self.plugin_id = plugin_id
 
     @property
     def tool_provider_type(self) -> ToolProviderType:
@@ -25,7 +23,6 @@ class PluginTool(Tool):
         return manager.invoke(
             tenant_id=self.tenant_id,
             user_id=user_id,
-            plugin_id=self.plugin_id,
             tool_provider=self.entity.identity.provider,
             tool_name=self.entity.identity.name,
             credentials=self.runtime.credentials,
@@ -37,5 +34,4 @@ class PluginTool(Tool):
             entity=self.entity,
             runtime=runtime,
             tenant_id=self.tenant_id,
-            plugin_id=self.plugin_id,
         )

+ 4 - 7
api/core/tools/tool_manager.py

@@ -86,7 +86,6 @@ class ToolManager:
         return PluginToolProviderController(
             entity=provider_entity.declaration,
             tenant_id=tenant_id,
-            plugin_id=provider_entity.plugin_id,
         )
 
     @classmethod
@@ -158,12 +157,11 @@ class ToolManager:
 
             # decrypt the credentials
             credentials = builtin_provider.credentials
-            controller = cls.get_builtin_provider(provider_id, tenant_id)
             tool_configuration = ProviderConfigEncrypter(
                 tenant_id=tenant_id,
-                config=controller.get_credentials_schema(),
-                provider_type=controller.provider_type.value,
-                provider_identity=controller.entity.identity.name,
+                config=provider_controller.get_credentials_schema(),
+                provider_type=provider_controller.provider_type.value,
+                provider_identity=provider_controller.entity.identity.name,
             )
 
             decrypted_credentials = tool_configuration.decrypt(credentials)
@@ -400,7 +398,6 @@ class ToolManager:
             PluginToolProviderController(
                 entity=provider.declaration,
                 tenant_id=tenant_id,
-                plugin_id=provider.plugin_id,
             )
             for provider in provider_entities
         ]
@@ -525,7 +522,7 @@ class ToolManager:
                 )
 
                 if isinstance(provider, PluginToolProviderController):
-                    result_providers[f"plugin_provider.{user_provider.name}.{provider.plugin_id}"] = user_provider
+                    result_providers[f"plugin_provider.{user_provider.name}"] = user_provider
                 else:
                     result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
 

+ 2 - 5
api/migrations/env.py

@@ -31,19 +31,16 @@ def get_engine_url():
 # from myapp import mymodel
 # target_metadata = mymodel.Base.metadata
 config.set_main_option('sqlalchemy.url', get_engine_url())
-target_db = current_app.extensions['migrate'].db
 
 # other values from the config, defined by the needs of env.py,
 # can be acquired:
 # my_important_option = config.get_main_option("my_important_option")
 # ... etc.
 
+from models.base import Base
 
 def get_metadata():
-    if hasattr(target_db, 'metadatas'):
-        return target_db.metadatas[None]
-    return target_db.metadata
-
+    return Base.metadata
 
 def include_object(object, name, type_, reflected, compare_to):
     if type_ == "foreign_key_constraint":

+ 39 - 0
api/migrations/versions/2024_09_29_0835-ddcc8bbef391_increase_max_length_of_builtin_tool_provider.py

@@ -0,0 +1,39 @@
+"""increase max length of builtin tool provider
+
+Revision ID: ddcc8bbef391
+Revises: d57ba9ebb251
+Create Date: 2024-09-29 08:35:58.062698
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'ddcc8bbef391'
+down_revision = 'd57ba9ebb251'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
+        batch_op.alter_column('provider',
+               existing_type=sa.VARCHAR(length=40),
+               type_=sa.String(length=256),
+               existing_nullable=False)
+
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
+        batch_op.alter_column('provider',
+               existing_type=sa.String(length=256),
+               type_=sa.VARCHAR(length=40),
+               existing_nullable=False)
+
+    # ### end Alembic commands ###

+ 3 - 3
api/models/base.py

@@ -1,5 +1,5 @@
-from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import declarative_base
 
+from extensions.ext_database import metadata
 
-class Base(DeclarativeBase):
-    pass
+Base = declarative_base(metadata=metadata)

+ 56 - 51
api/models/model.py

@@ -2,12 +2,15 @@ import json
 import re
 import uuid
 from enum import Enum
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
+
+if TYPE_CHECKING:
+    from models.workflow import Workflow
 
 from flask import request
 from flask_login import UserMixin
-from sqlalchemy import Float, func, text
-from sqlalchemy.orm import Mapped, mapped_column, relationship
+from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
+from sqlalchemy.orm import Mapped, mapped_column
 
 from configs import dify_config
 from core.file.tool_file_parser import ToolFileParser
@@ -20,7 +23,7 @@ from .account import Account, Tenant
 from .types import StringUUID
 
 
-class DifySetup(db.Model):
+class DifySetup(Base):
     __tablename__ = "dify_setups"
     __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
 
@@ -55,7 +58,7 @@ class IconType(Enum):
     EMOJI = "emoji"
 
 
-class App(db.Model):
+class App(Base):
     __tablename__ = "apps"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
 
@@ -133,7 +136,8 @@ class App(db.Model):
             return False
         if not app_model_config.agent_mode:
             return False
-        if self.app_model_config.agent_mode_dict.get("enabled", False) and self.app_model_config.agent_mode_dict.get(
+
+        if app_model_config.agent_mode_dict.get("enabled", False) and app_model_config.agent_mode_dict.get(
             "strategy", ""
         ) in {"function_call", "react"}:
             self.mode = AppMode.AGENT_CHAT.value
@@ -250,7 +254,7 @@ class AppModelConfig(Base):
         return app
 
     @property
-    def model_dict(self) -> dict:
+    def model_dict(self):
         return json.loads(self.model) if self.model else None
 
     @property
@@ -284,6 +288,9 @@ class AppModelConfig(Base):
         )
         if annotation_setting:
             collection_binding_detail = annotation_setting.collection_binding_detail
+            if not collection_binding_detail:
+                raise ValueError("Collection binding detail not found")
+
             return {
                 "id": annotation_setting.id,
                 "enabled": True,
@@ -314,7 +321,7 @@ class AppModelConfig(Base):
         return json.loads(self.external_data_tools) if self.external_data_tools else []
 
     @property
-    def user_input_form_list(self) -> dict:
+    def user_input_form_list(self):
         return json.loads(self.user_input_form) if self.user_input_form else []
 
     @property
@@ -458,7 +465,7 @@ class AppModelConfig(Base):
         return new_app_model_config
 
 
-class RecommendedApp(db.Model):
+class RecommendedApp(Base):
     __tablename__ = "recommended_apps"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="recommended_app_pkey"),
@@ -486,7 +493,7 @@ class RecommendedApp(db.Model):
         return app
 
 
-class InstalledApp(db.Model):
+class InstalledApp(Base):
     __tablename__ = "installed_apps"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="installed_app_pkey"),
@@ -522,7 +529,7 @@ class Conversation(Base):
         db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
     )
 
-    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     app_id = db.Column(StringUUID, nullable=False)
     app_model_config_id = db.Column(StringUUID, nullable=True)
     model_provider = db.Column(db.String(255), nullable=True)
@@ -546,10 +553,8 @@ class Conversation(Base):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
-    messages: Mapped[list["Message"]] = relationship(
-        "Message", backref="conversation", lazy="select", passive_deletes="all"
-    )
-    message_annotations: Mapped[list["MessageAnnotation"]] = relationship(
+    messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
+    message_annotations = db.relationship(
         "MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
     )
 
@@ -578,7 +583,7 @@ class Conversation(Base):
                 )
 
                 if not app_model_config:
-                    raise ValueError("app config not found")
+                    return {}
 
                 model_config = app_model_config.to_dict()
 
@@ -692,12 +697,12 @@ class Conversation(Base):
 class Message(Base):
     __tablename__ = "messages"
     __table_args__ = (
-        db.PrimaryKeyConstraint("id", name="message_pkey"),
-        db.Index("message_app_id_idx", "app_id", "created_at"),
-        db.Index("message_conversation_id_idx", "conversation_id"),
-        db.Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"),
-        db.Index("message_account_idx", "app_id", "from_source", "from_account_id"),
-        db.Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
+        PrimaryKeyConstraint("id", name="message_pkey"),
+        Index("message_app_id_idx", "app_id", "created_at"),
+        Index("message_conversation_id_idx", "conversation_id"),
+        Index("message_end_user_idx", "app_id", "from_source", "from_end_user_id"),
+        Index("message_account_idx", "app_id", "from_source", "from_account_id"),
+        Index("message_workflow_run_id_idx", "conversation_id", "workflow_run_id"),
     )
 
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
@@ -705,10 +710,10 @@ class Message(Base):
     model_provider = db.Column(db.String(255), nullable=True)
     model_id = db.Column(db.String(255), nullable=True)
     override_model_configs = db.Column(db.Text)
-    conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
-    inputs: Mapped[str] = mapped_column(db.JSON)
-    query: Mapped[str] = mapped_column(db.Text, nullable=False)
-    message: Mapped[str] = mapped_column(db.JSON, nullable=False)
+    conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
+    inputs = db.Column(db.JSON)
+    query = db.Column(db.Text, nullable=False)
+    message = db.Column(db.JSON, nullable=False)
     message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
     message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
     message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
@@ -974,7 +979,7 @@ class Message(Base):
         )
 
 
-class MessageFeedback(db.Model):
+class MessageFeedback(Base):
     __tablename__ = "message_feedbacks"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@@ -1009,15 +1014,15 @@ class MessageFile(Base):
         db.Index("message_file_created_by_idx", "created_by"),
     )
 
-    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
-    message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
-    type: Mapped[str] = mapped_column(db.String(255), nullable=False)
-    transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
-    url: Mapped[str] = mapped_column(db.Text, nullable=True)
-    belongs_to: Mapped[str] = mapped_column(db.String(255), nullable=True)
-    upload_file_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
-    created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
-    created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    message_id = db.Column(StringUUID, nullable=False)
+    type = db.Column(db.String(255), nullable=False)
+    transfer_method = db.Column(db.String(255), nullable=False)
+    url = db.Column(db.Text, nullable=True)
+    belongs_to = db.Column(db.String(255), nullable=True)
+    upload_file_id = db.Column(StringUUID, nullable=True)
+    created_by_role = db.Column(db.String(255), nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
@@ -1032,7 +1037,7 @@ class MessageAnnotation(Base):
 
     id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
     app_id = db.Column(StringUUID, nullable=False)
-    conversation_id: Mapped[str] = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=True)
+    conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True)
     message_id = db.Column(StringUUID, nullable=True)
     question = db.Column(db.Text, nullable=True)
     content = db.Column(db.Text, nullable=False)
@@ -1052,7 +1057,7 @@ class MessageAnnotation(Base):
         return account
 
 
-class AppAnnotationHitHistory(db.Model):
+class AppAnnotationHitHistory(Base):
     __tablename__ = "app_annotation_hit_histories"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@@ -1090,7 +1095,7 @@ class AppAnnotationHitHistory(db.Model):
         return account
 
 
-class AppAnnotationSetting(db.Model):
+class AppAnnotationSetting(Base):
     __tablename__ = "app_annotation_settings"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@@ -1138,7 +1143,7 @@ class AppAnnotationSetting(db.Model):
         return collection_binding_detail
 
 
-class OperationLog(db.Model):
+class OperationLog(Base):
     __tablename__ = "operation_logs"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="operation_log_pkey"),
@@ -1155,7 +1160,7 @@ class OperationLog(db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class EndUser(UserMixin, db.Model):
+class EndUser(UserMixin, Base):
     __tablename__ = "end_users"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="end_user_pkey"),
@@ -1175,7 +1180,7 @@ class EndUser(UserMixin, db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class Site(db.Model):
+class Site(Base):
     __tablename__ = "sites"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="site_pkey"),
@@ -1222,7 +1227,7 @@ class Site(db.Model):
         return dify_config.APP_WEB_URL or request.url_root.rstrip("/")
 
 
-class ApiToken(db.Model):
+class ApiToken(Base):
     __tablename__ = "api_tokens"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="api_token_pkey"),
@@ -1249,7 +1254,7 @@ class ApiToken(db.Model):
             return result
 
 
-class UploadFile(db.Model):
+class UploadFile(Base):
     __tablename__ = "upload_files"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="upload_file_pkey"),
@@ -1273,7 +1278,7 @@ class UploadFile(db.Model):
     hash = db.Column(db.String(255), nullable=True)
 
 
-class ApiRequest(db.Model):
+class ApiRequest(Base):
     __tablename__ = "api_requests"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="api_request_pkey"),
@@ -1290,7 +1295,7 @@ class ApiRequest(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class MessageChain(db.Model):
+class MessageChain(Base):
     __tablename__ = "message_chains"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="message_chain_pkey"),
@@ -1395,7 +1400,7 @@ class MessageAgentThought(Base):
             return {}
 
     @property
-    def tool_outputs_dict(self) -> dict:
+    def tool_outputs_dict(self):
         tools = self.tools
         try:
             if self.observation:
@@ -1417,7 +1422,7 @@ class MessageAgentThought(Base):
                 return dict.fromkeys(tools, self.observation)
 
 
-class DatasetRetrieverResource(db.Model):
+class DatasetRetrieverResource(Base):
     __tablename__ = "dataset_retriever_resources"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="dataset_retriever_resource_pkey"),
@@ -1444,7 +1449,7 @@ class DatasetRetrieverResource(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
 
 
-class Tag(db.Model):
+class Tag(Base):
     __tablename__ = "tags"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="tag_pkey"),
@@ -1462,7 +1467,7 @@ class Tag(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class TagBinding(db.Model):
+class TagBinding(Base):
     __tablename__ = "tag_bindings"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="tag_binding_pkey"),
@@ -1478,7 +1483,7 @@ class TagBinding(db.Model):
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class TraceAppConfig(db.Model):
+class TraceAppConfig(Base):
     __tablename__ = "trace_app_config"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="tracing_app_config_pkey"),

+ 8 - 7
api/models/provider.py

@@ -1,6 +1,7 @@
 from enum import Enum
 
 from extensions.ext_database import db
+from models.base import Base
 
 from .types import StringUUID
 
@@ -35,7 +36,7 @@ class ProviderQuotaType(Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class Provider(db.Model):
+class Provider(Base):
     """
     Provider model representing the API providers and their configurations.
     """
@@ -88,7 +89,7 @@ class Provider(db.Model):
             return self.is_valid and self.token_is_set
 
 
-class ProviderModel(db.Model):
+class ProviderModel(Base):
     """
     Provider model representing the API provider_models and their configurations.
     """
@@ -113,7 +114,7 @@ class ProviderModel(db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class TenantDefaultModel(db.Model):
+class TenantDefaultModel(Base):
     __tablename__ = "tenant_default_models"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="tenant_default_model_pkey"),
@@ -129,7 +130,7 @@ class TenantDefaultModel(db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class TenantPreferredModelProvider(db.Model):
+class TenantPreferredModelProvider(Base):
     __tablename__ = "tenant_preferred_model_providers"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="tenant_preferred_model_provider_pkey"),
@@ -144,7 +145,7 @@ class TenantPreferredModelProvider(db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class ProviderOrder(db.Model):
+class ProviderOrder(Base):
     __tablename__ = "provider_orders"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="provider_order_pkey"),
@@ -169,7 +170,7 @@ class ProviderOrder(db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class ProviderModelSetting(db.Model):
+class ProviderModelSetting(Base):
     """
     Provider model settings for record the model enabled status and load balancing status.
     """
@@ -191,7 +192,7 @@ class ProviderModelSetting(db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class LoadBalancingModelConfig(db.Model):
+class LoadBalancingModelConfig(Base):
     """
     Configurations for load balancing models.
     """

+ 3 - 2
api/models/source.py

@@ -3,11 +3,12 @@ import json
 from sqlalchemy.dialects.postgresql import JSONB
 
 from extensions.ext_database import db
+from models.base import Base
 
 from .types import StringUUID
 
 
-class DataSourceOauthBinding(db.Model):
+class DataSourceOauthBinding(Base):
     __tablename__ = "data_source_oauth_bindings"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="source_binding_pkey"),
@@ -25,7 +26,7 @@ class DataSourceOauthBinding(db.Model):
     disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
 
 
-class DataSourceApiKeyAuthBinding(db.Model):
+class DataSourceApiKeyAuthBinding(Base):
     __tablename__ = "data_source_api_key_auth_bindings"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="data_source_api_key_auth_binding_pkey"),

+ 3 - 2
api/models/task.py

@@ -3,9 +3,10 @@ from datetime import datetime, timezone
 from celery import states
 
 from extensions.ext_database import db
+from models.base import Base
 
 
-class CeleryTask(db.Model):
+class CeleryTask(Base):
     """Task result/status."""
 
     __tablename__ = "celery_taskmeta"
@@ -29,7 +30,7 @@ class CeleryTask(db.Model):
     queue = db.Column(db.String(155), nullable=True)
 
 
-class CeleryTaskSet(db.Model):
+class CeleryTaskSet(Base):
     """TaskSet result."""
 
     __tablename__ = "celery_tasksetmeta"

+ 2 - 1
api/models/tool.py

@@ -2,6 +2,7 @@ import json
 from enum import Enum
 
 from extensions.ext_database import db
+from models.base import Base
 
 from .types import StringUUID
 
@@ -17,7 +18,7 @@ class ToolProviderName(Enum):
         raise ValueError(f"No matching enum found for value '{value}'")
 
 
-class ToolProvider(db.Model):
+class ToolProvider(Base):
     __tablename__ = "tool_providers"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="tool_provider_pkey"),

+ 49 - 3
api/models/tools.py

@@ -1,8 +1,11 @@
 import json
 from datetime import datetime
 
+from deprecated import deprecated
+from sqlalchemy import ForeignKey
 from sqlalchemy.orm import Mapped, mapped_column
 
+from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_bundle import ApiToolBundle
 from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
 from extensions.ext_database import db
@@ -31,7 +34,7 @@ class BuiltinToolProvider(Base):
     # who created this tool provider
     user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     # name of the tool provider
-    provider: Mapped[str] = mapped_column(db.String(40), nullable=False)
+    provider: Mapped[str] = mapped_column(db.String(256), nullable=False)
     # credential of the tool provider
     encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
     created_at: Mapped[datetime] = mapped_column(
@@ -182,7 +185,7 @@ class WorkflowToolProvider(Base):
         return db.session.query(App).filter(App.id == self.app_id).first()
 
 
-class ToolModelInvoke(db.Model):
+class ToolModelInvoke(Base):
     """
     store the invoke logs from tool invoke
     """
@@ -219,7 +222,7 @@ class ToolModelInvoke(db.Model):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
-class ToolConversationVariables(db.Model):
+class ToolConversationVariables(Base):
     """
     store the conversation variables from tool invoke
     """
@@ -275,3 +278,46 @@ class ToolFile(Base):
     mimetype: Mapped[str] = mapped_column(db.String(255), nullable=False)
     # original url
     original_url: Mapped[str] = mapped_column(db.String(2048), nullable=True)
+
+
+@deprecated
+class DeprecatedPublishedAppTool(Base):
+    """
+    The table stores the apps published as a tool for each person.
+    """
+
+    __tablename__ = "tool_published_apps"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="published_app_tool_pkey"),
+        db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
+    )
+
+    # id of the tool provider
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    # id of the app
+    app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
+    # who published this tool
+    user_id = db.Column(StringUUID, nullable=False)
+    # description of the tool, stored in i18n format, for human
+    description = db.Column(db.Text, nullable=False)
+    # llm_description of the tool, for LLM
+    llm_description = db.Column(db.Text, nullable=False)
+    # query description, query will be seem as a parameter of the tool,
+    # to describe this parameter to llm, we need this field
+    query_description = db.Column(db.Text, nullable=False)
+    # query name, the name of the query parameter
+    query_name = db.Column(db.String(40), nullable=False)
+    # name of the tool provider
+    tool_name = db.Column(db.String(40), nullable=False)
+    # author
+    author = db.Column(db.String(40), nullable=False)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+
+    @property
+    def description_i18n(self) -> I18nObject:
+        return I18nObject(**json.loads(self.description))
+
+    @property
+    def app(self) -> App:
+        return db.session.query(App).filter(App.id == self.app_id).first()

+ 3 - 2
api/models/web.py

@@ -1,10 +1,11 @@
 from extensions.ext_database import db
+from models.base import Base
 
 from .model import Message
 from .types import StringUUID
 
 
-class SavedMessage(db.Model):
+class SavedMessage(Base):
     __tablename__ = "saved_messages"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="saved_message_pkey"),
@@ -23,7 +24,7 @@ class SavedMessage(db.Model):
         return db.session.query(Message).filter(Message.id == self.message_id).first()
 
 
-class PinnedConversation(db.Model):
+class PinnedConversation(Base):
     __tablename__ = "pinned_conversations"
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="pinned_conversation_pkey"),

+ 24 - 15
api/models/workflow.py

@@ -2,10 +2,13 @@ import json
 from collections.abc import Mapping, Sequence
 from datetime import datetime
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union
 
-from sqlalchemy import func
-from sqlalchemy.orm import Mapped
+if TYPE_CHECKING:
+    from models.model import AppMode
+
+from sqlalchemy import Index, PrimaryKeyConstraint, func
+from sqlalchemy.orm import Mapped, mapped_column
 
 import contexts
 from constants import HIDDEN_VALUE
@@ -13,6 +16,7 @@ from core.app.segments import SecretVariable, Variable, factory
 from core.helper import encrypter
 from extensions.ext_database import db
 from libs import helper
+from models.base import Base
 
 from .account import Account
 from .types import StringUUID
@@ -75,7 +79,7 @@ class WorkflowType(Enum):
         return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
 
 
-class Workflow(db.Model):
+class Workflow(Base):
     """
     Workflow, for `Workflow App` and `Chat App workflow mode`.
 
@@ -345,7 +349,7 @@ class WorkflowRunStatus(Enum):
         raise ValueError(f"invalid workflow run status value {value}")
 
 
-class WorkflowRun(db.Model):
+class WorkflowRun(Base):
     """
     Workflow Run
 
@@ -436,7 +440,7 @@ class WorkflowRun(db.Model):
         return json.loads(self.outputs) if self.outputs else None
 
     @property
-    def message(self) -> Optional["Message"]:
+    def message(self):
         from models.model import Message
 
         return (
@@ -542,7 +546,7 @@ class WorkflowNodeExecutionStatus(Enum):
         raise ValueError(f"invalid workflow node execution status value {value}")
 
 
-class WorkflowNodeExecution(db.Model):
+class WorkflowNodeExecution(Base):
     """
     Workflow Node Execution
 
@@ -708,7 +712,7 @@ class WorkflowAppLogCreatedFrom(Enum):
         raise ValueError(f"invalid workflow app log created from value {value}")
 
 
-class WorkflowAppLog(db.Model):
+class WorkflowAppLog(Base):
     """
     Workflow App execution log, excluding workflow debugging records.
 
@@ -770,15 +774,20 @@ class WorkflowAppLog(db.Model):
         return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None
 
 
-class ConversationVariable(db.Model):
+class ConversationVariable(Base):
     __tablename__ = "workflow_conversation_variables"
+    __table_args__ = (
+        PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"),
+        Index("workflow__conversation_variables_app_id_idx", "app_id"),
+        Index("workflow__conversation_variables_created_at_idx", "created_at"),
+    )
 
-    id: Mapped[str] = db.Column(StringUUID, primary_key=True)
-    conversation_id: Mapped[str] = db.Column(StringUUID, nullable=False, primary_key=True)
-    app_id: Mapped[str] = db.Column(StringUUID, nullable=False, index=True)
-    data = db.Column(db.Text, nullable=False)
-    created_at = db.Column(db.DateTime, nullable=False, index=True, server_default=db.text("CURRENT_TIMESTAMP(0)"))
-    updated_at = db.Column(
+    id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
+    conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True)
+    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    data = mapped_column(db.Text, nullable=False)
+    created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+    updated_at = mapped_column(
         db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
     )