Explorar o código

Migrate to DeclarativeBaseModel

Yeuoly hai 7 meses
pai
achega
11270a7ef2

+ 23 - 7
api/controllers/console/admin.py

@@ -3,6 +3,8 @@ from functools import wraps
 
 from flask import request
 from flask_restful import Resource, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound, Unauthorized
 
 from constants.languages import supported_language
@@ -54,7 +56,8 @@ class InsertExploreAppListApi(Resource):
         parser.add_argument("position", type=int, required=True, nullable=False, location="json")
         args = parser.parse_args()
 
-        app = App.query.filter(App.id == args["app_id"]).first()
+        with Session(db.engine) as session:
+            app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
         if not app:
             raise NotFound(f'App \'{args["app_id"]}\' is not found')
 
@@ -70,7 +73,10 @@ class InsertExploreAppListApi(Resource):
             privacy_policy = site.privacy_policy or args["privacy_policy"] or ""
             custom_disclaimer = site.custom_disclaimer or args["custom_disclaimer"] or ""
 
-        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first()
+        with Session(db.engine) as session:
+            recommended_app = session.execute(
+                select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
+            ).scalar_one_or_none()
 
         if not recommended_app:
             recommended_app = RecommendedApp(
@@ -110,17 +116,27 @@ class InsertExploreAppApi(Resource):
     @only_edition_cloud
     @admin_required
     def delete(self, app_id):
-        recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == str(app_id)).first()
+        with Session(db.engine) as session:
+            recommended_app = session.execute(
+                select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
+            ).scalar_one_or_none()
+
         if not recommended_app:
             return {"result": "success"}, 204
 
-        app = App.query.filter(App.id == recommended_app.app_id).first()
+        with Session(db.engine) as session:
+            app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
+
         if app:
             app.is_public = False
 
-        installed_apps = InstalledApp.query.filter(
-            InstalledApp.app_id == recommended_app.app_id, InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id
-        ).all()
+        with Session(db.engine) as session:
+            installed_apps = session.execute(
+                select(InstalledApp).filter(
+                    InstalledApp.app_id == recommended_app.app_id,
+                    InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
+                )
+            ).all()
 
         for installed_app in installed_apps:
             db.session.delete(installed_app)

+ 4 - 1
api/controllers/console/apikey.py

@@ -33,7 +33,10 @@ def _get_resource(resource_id, tenant_id, resource_model):
                 select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
             ).scalar_one_or_none()
     else:
-        resource = resource_model.query.filter_by(id=resource_id, tenant_id=tenant_id).first()
+        with Session(db.engine) as session:
+            resource = session.execute(
+                select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id)
+            ).scalar_one_or_none()
 
     if resource is None:
         flask_restful.abort(404, message=f"{resource_model.__name__} not found.")

+ 6 - 2
api/controllers/console/auth/forgot_password.py

@@ -3,6 +3,8 @@ import secrets
 
 from flask import request
 from flask_restful import Resource, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 
 from constants.languages import languages
 from controllers.console import api
@@ -41,7 +43,8 @@ class ForgotPasswordSendEmailApi(Resource):
         else:
             language = "en-US"
 
-        account = Account.query.filter_by(email=args["email"]).first()
+        with Session(db.engine) as session:
+            account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
         token = None
         if account is None:
             if FeatureService.get_system_features().is_allow_register:
@@ -108,7 +111,8 @@ class ForgotPasswordResetApi(Resource):
         password_hashed = hash_password(new_password, salt)
         base64_password_hashed = base64.b64encode(password_hashed).decode()
 
-        account = Account.query.filter_by(email=reset_data.get("email")).first()
+        with Session(db.engine) as session:
+            account = session.execute(select(Account).filter_by(email=reset_data.get("email"))).scalar_one_or_none()
         if account:
             account.password = base64_password_hashed
             account.password_salt = base64_salt

+ 4 - 1
api/controllers/console/auth/oauth.py

@@ -5,6 +5,8 @@ from typing import Optional
 import requests
 from flask import current_app, redirect, request
 from flask_restful import Resource
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Unauthorized
 
 from configs import dify_config
@@ -135,7 +137,8 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
     account = Account.get_by_openid(provider, user_info.id)
 
     if not account:
-        account = Account.query.filter_by(email=user_info.email).first()
+        with Session(db.engine) as session:
+            account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
 
     return account
 

+ 63 - 49
api/controllers/console/datasets/data_source.py

@@ -4,6 +4,8 @@ import json
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, marshal_with, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import NotFound
 
 from controllers.console import api
@@ -77,7 +79,10 @@ class DataSourceApi(Resource):
     def patch(self, binding_id, action):
         binding_id = str(binding_id)
         action = str(action)
-        data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
+        with Session(db.engine) as session:
+            data_source_binding = session.execute(
+                select(DataSourceOauthBinding).filter_by(id=binding_id)
+            ).scalar_one_or_none()
         if data_source_binding is None:
             raise NotFound("Data source binding not found.")
         # enable binding
@@ -109,47 +114,53 @@ class DataSourceNotionListApi(Resource):
     def get(self):
         dataset_id = request.args.get("dataset_id", default=None, type=str)
         exist_page_ids = []
-        # import notion in the exist dataset
-        if dataset_id:
-            dataset = DatasetService.get_dataset(dataset_id)
-            if not dataset:
-                raise NotFound("Dataset not found.")
-            if dataset.data_source_type != "notion_import":
-                raise ValueError("Dataset is not notion type.")
-            documents = Document.query.filter_by(
-                dataset_id=dataset_id,
-                tenant_id=current_user.current_tenant_id,
-                data_source_type="notion_import",
-                enabled=True,
+        with Session(db.engine) as session:
+            # import notion in the exist dataset
+            if dataset_id:
+                dataset = DatasetService.get_dataset(dataset_id)
+                if not dataset:
+                    raise NotFound("Dataset not found.")
+                if dataset.data_source_type != "notion_import":
+                    raise ValueError("Dataset is not notion type.")
+
+                documents = session.execute(
+                    select(Document).filter_by(
+                        dataset_id=dataset_id,
+                        tenant_id=current_user.current_tenant_id,
+                        data_source_type="notion_import",
+                        enabled=True,
+                    )
+                ).all()
+                if documents:
+                    for document in documents:
+                        data_source_info = json.loads(document.data_source_info)
+                        exist_page_ids.append(data_source_info["notion_page_id"])
+            # get all authorized pages
+            data_source_bindings = session.execute(
+                select(DataSourceOauthBinding).filter_by(
+                    tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
+                )
             ).all()
-            if documents:
-                for document in documents:
-                    data_source_info = json.loads(document.data_source_info)
-                    exist_page_ids.append(data_source_info["notion_page_id"])
-        # get all authorized pages
-        data_source_bindings = DataSourceOauthBinding.query.filter_by(
-            tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
-        ).all()
-        if not data_source_bindings:
-            return {"notion_info": []}, 200
-        pre_import_info_list = []
-        for data_source_binding in data_source_bindings:
-            source_info = data_source_binding.source_info
-            pages = source_info["pages"]
-            # Filter out already bound pages
-            for page in pages:
-                if page["page_id"] in exist_page_ids:
-                    page["is_bound"] = True
-                else:
-                    page["is_bound"] = False
-            pre_import_info = {
-                "workspace_name": source_info["workspace_name"],
-                "workspace_icon": source_info["workspace_icon"],
-                "workspace_id": source_info["workspace_id"],
-                "pages": pages,
-            }
-            pre_import_info_list.append(pre_import_info)
-        return {"notion_info": pre_import_info_list}, 200
+            if not data_source_bindings:
+                return {"notion_info": []}, 200
+            pre_import_info_list = []
+            for data_source_binding in data_source_bindings:
+                source_info = data_source_binding.source_info
+                pages = source_info["pages"]
+                # Filter out already bound pages
+                for page in pages:
+                    if page["page_id"] in exist_page_ids:
+                        page["is_bound"] = True
+                    else:
+                        page["is_bound"] = False
+                pre_import_info = {
+                    "workspace_name": source_info["workspace_name"],
+                    "workspace_icon": source_info["workspace_icon"],
+                    "workspace_id": source_info["workspace_id"],
+                    "pages": pages,
+                }
+                pre_import_info_list.append(pre_import_info)
+            return {"notion_info": pre_import_info_list}, 200
 
 
 class DataSourceNotionApi(Resource):
@@ -159,14 +170,17 @@ class DataSourceNotionApi(Resource):
     def get(self, workspace_id, page_id, page_type):
         workspace_id = str(workspace_id)
         page_id = str(page_id)
-        data_source_binding = DataSourceOauthBinding.query.filter(
-            db.and_(
-                DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
-                DataSourceOauthBinding.provider == "notion",
-                DataSourceOauthBinding.disabled == False,
-                DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
-            )
-        ).first()
+        with Session(db.engine) as session:
+            data_source_binding = session.execute(
+                select(DataSourceOauthBinding).filter(
+                    db.and_(
+                        DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+                        DataSourceOauthBinding.provider == "notion",
+                        DataSourceOauthBinding.disabled == False,
+                        DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
+                    )
+                )
+            ).scalar_one_or_none()
         if not data_source_binding:
             raise NotFound("Data source binding not found.")
 

+ 27 - 15
api/controllers/console/datasets/datasets_document.py

@@ -5,7 +5,8 @@ from datetime import datetime, timezone
 from flask import request
 from flask_login import current_user
 from flask_restful import Resource, fields, marshal, marshal_with, reqparse
-from sqlalchemy import asc, desc
+from sqlalchemy import asc, desc, select
+from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, NotFound
 
 import services
@@ -104,7 +105,8 @@ class GetProcessRuleApi(Resource):
         rules = DocumentService.DEFAULT_RULES["rules"]
         if document_id:
             # get the latest process rule
-            document = Document.query.get_or_404(document_id)
+            with Session(db.engine) as session:
+                document = session.execute(select(Document).get_or_404(document_id)).scalar_one_or_none()
 
             dataset = DatasetService.get_dataset(document.dataset_id)
 
@@ -167,7 +169,10 @@ class DatasetDocumentListApi(Resource):
         except services.errors.account.NoPermissionError as e:
             raise Forbidden(str(e))
 
-        query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
+        with Session(db.engine) as session:
+            query = session.execute(
+                select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
+            ).all()
 
         if search:
             search = f"%{search}%"
@@ -204,18 +209,25 @@ class DatasetDocumentListApi(Resource):
         paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
         documents = paginated_documents.items
         if fetch:
-            for document in documents:
-                completed_segments = DocumentSegment.query.filter(
-                    DocumentSegment.completed_at.isnot(None),
-                    DocumentSegment.document_id == str(document.id),
-                    DocumentSegment.status != "re_segment",
-                ).count()
-                total_segments = DocumentSegment.query.filter(
-                    DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment"
-                ).count()
-                document.completed_segments = completed_segments
-                document.total_segments = total_segments
-            data = marshal(documents, document_with_segments_fields)
+            with Session(db.engine) as session:
+                for document in documents:
+                    completed_segments = (
+                        session.query(DocumentSegment)
+                        .filter(
+                            DocumentSegment.completed_at.isnot(None),
+                            DocumentSegment.document_id == str(document.id),
+                            DocumentSegment.status != "re_segment",
+                        )
+                        .count()
+                    )
+                    total_segments = (
+                        session.query(DocumentSegment)
+                        .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+                        .count()
+                    )
+                    document.completed_segments = completed_segments
+                    document.total_segments = total_segments
+                data = marshal(documents, document_with_segments_fields)
         else:
             data = marshal(documents, document_fields)
         response = {

+ 8 - 1
api/controllers/console/init_validate.py

@@ -2,8 +2,11 @@ import os
 
 from flask import session
 from flask_restful import Resource, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
 
 from configs import dify_config
+from extensions.ext_database import db
 from libs.helper import StrLen
 from models.model import DifySetup
 from services.account_service import TenantService
@@ -42,7 +45,11 @@ class InitValidateAPI(Resource):
 def get_init_validate_status():
     if dify_config.EDITION == "SELF_HOSTED":
         if os.environ.get("INIT_PASSWORD"):
-            return session.get("is_init_validated") or DifySetup.query.first()
+            if session.get("is_init_validated"):
+                return True
+
+            with Session(db.engine) as db_session:
+                return db_session.execute(select(DifySetup)).scalar_one_or_none()
 
     return True
 

+ 2 - 1
api/models/account.py

@@ -4,6 +4,7 @@ import json
 from flask_login import UserMixin
 
 from extensions.ext_database import db
+from models.base import Base
 
 from .types import StringUUID
 
@@ -16,7 +17,7 @@ class AccountStatus(str, enum.Enum):
     CLOSED = "closed"
 
 
-class Account(UserMixin, db.Model):
+class Account(UserMixin, Base):
     __tablename__ = "accounts"
     __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
 

+ 1 - 1
api/models/model.py

@@ -38,7 +38,7 @@ class FileUploadConfig(BaseModel):
     number_limits: int = Field(default=0, gt=0, le=10)
 
 
-class DifySetup(db.Model):
+class DifySetup(BaseModel):
     __tablename__ = "dify_setups"
     __table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)