Parcourir la source

refactor: document segment query

Yeuoly il y a 7 mois
Parent
commit
685e8cdc7d
1 fichier modifié avec 46 ajouts et 45 suppressions
  1. 46 45
      api/controllers/console/datasets/datasets_document.py

+ 46 - 45
api/controllers/console/datasets/datasets_document.py

@@ -170,46 +170,47 @@ class DatasetDocumentListApi(Resource):
             raise Forbidden(str(e))
 
         with Session(db.engine) as session:
-            query = session.execute(
-                select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
-            ).all()
+            query = session.query(Document).filter_by(
+                dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id
+            )
 
-        if search:
-            search = f"%{search}%"
-            query = query.filter(Document.name.like(search))
+            if search:
+                search = f"%{search}%"
+                query = query.filter(Document.name.like(search))
 
-        if sort.startswith("-"):
-            sort_logic = desc
-            sort = sort[1:]
-        else:
-            sort_logic = asc
+            if sort.startswith("-"):
+                sort_logic = desc
+                sort = sort[1:]
+            else:
+                sort_logic = asc
 
-        if sort == "hit_count":
-            sub_query = (
-                db.select(DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count"))
-                .group_by(DocumentSegment.document_id)
-                .subquery()
-            )
+            if sort == "hit_count":
+                sub_query = (
+                    db.select(
+                        DocumentSegment.document_id, db.func.sum(DocumentSegment.hit_count).label("total_hit_count")
+                    )
+                    .group_by(DocumentSegment.document_id)
+                    .subquery()
+                )
 
-            query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
-                sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
-                sort_logic(Document.position),
-            )
-        elif sort == "created_at":
-            query = query.order_by(
-                sort_logic(Document.created_at),
-                sort_logic(Document.position),
-            )
-        else:
-            query = query.order_by(
-                desc(Document.created_at),
-                desc(Document.position),
-            )
+                query = query.outerjoin(sub_query, sub_query.c.document_id == Document.id).order_by(
+                    sort_logic(db.func.coalesce(sub_query.c.total_hit_count, 0)),
+                    sort_logic(Document.position),
+                )
+            elif sort == "created_at":
+                query = query.order_by(
+                    sort_logic(Document.created_at),
+                    sort_logic(Document.position),
+                )
+            else:
+                query = query.order_by(
+                    desc(Document.created_at),
+                    desc(Document.position),
+                )
 
-        paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
-        documents = paginated_documents.items
-        if fetch:
-            with Session(db.engine) as session:
+            paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+            documents = paginated_documents.items
+            if fetch:
                 for document in documents:
                     completed_segments = (
                         session.query(DocumentSegment)
@@ -228,17 +229,17 @@ class DatasetDocumentListApi(Resource):
                     document.completed_segments = completed_segments
                     document.total_segments = total_segments
                 data = marshal(documents, document_with_segments_fields)
-        else:
-            data = marshal(documents, document_fields)
-        response = {
-            "data": data,
-            "has_more": len(documents) == limit,
-            "limit": limit,
-            "total": paginated_documents.total,
-            "page": page,
-        }
+            else:
+                data = marshal(documents, document_fields)
+            response = {
+                "data": data,
+                "has_more": len(documents) == limit,
+                "limit": limit,
+                "total": paginated_documents.total,
+                "page": page,
+            }
 
-        return response
+            return response
 
     documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}