|
@@ -4,6 +4,8 @@ import json
|
|
|
from flask import request
|
|
|
from flask_login import current_user
|
|
|
from flask_restful import Resource, marshal_with, reqparse
|
|
|
+from sqlalchemy import select
|
|
|
+from sqlalchemy.orm import Session
|
|
|
from werkzeug.exceptions import NotFound
|
|
|
|
|
|
from controllers.console import api
|
|
@@ -77,7 +79,10 @@ class DataSourceApi(Resource):
|
|
|
def patch(self, binding_id, action):
|
|
|
binding_id = str(binding_id)
|
|
|
action = str(action)
|
|
|
- data_source_binding = DataSourceOauthBinding.query.filter_by(id=binding_id).first()
|
|
|
+ with Session(db.engine) as session:
|
|
|
+ data_source_binding = session.execute(
|
|
|
+ select(DataSourceOauthBinding).filter_by(id=binding_id)
|
|
|
+ ).scalar_one_or_none()
|
|
|
if data_source_binding is None:
|
|
|
raise NotFound("Data source binding not found.")
|
|
|
# enable binding
|
|
@@ -109,47 +114,53 @@ class DataSourceNotionListApi(Resource):
|
|
|
def get(self):
|
|
|
dataset_id = request.args.get("dataset_id", default=None, type=str)
|
|
|
exist_page_ids = []
|
|
|
- # import notion in the exist dataset
|
|
|
- if dataset_id:
|
|
|
- dataset = DatasetService.get_dataset(dataset_id)
|
|
|
- if not dataset:
|
|
|
- raise NotFound("Dataset not found.")
|
|
|
- if dataset.data_source_type != "notion_import":
|
|
|
- raise ValueError("Dataset is not notion type.")
|
|
|
- documents = Document.query.filter_by(
|
|
|
- dataset_id=dataset_id,
|
|
|
- tenant_id=current_user.current_tenant_id,
|
|
|
- data_source_type="notion_import",
|
|
|
- enabled=True,
|
|
|
+ with Session(db.engine) as session:
|
|
|
+ # import notion in the exist dataset
|
|
|
+ if dataset_id:
|
|
|
+ dataset = DatasetService.get_dataset(dataset_id)
|
|
|
+ if not dataset:
|
|
|
+ raise NotFound("Dataset not found.")
|
|
|
+ if dataset.data_source_type != "notion_import":
|
|
|
+ raise ValueError("Dataset is not notion type.")
|
|
|
+
|
|
|
+ documents = session.execute(
|
|
|
+ select(Document).filter_by(
|
|
|
+ dataset_id=dataset_id,
|
|
|
+ tenant_id=current_user.current_tenant_id,
|
|
|
+ data_source_type="notion_import",
|
|
|
+ enabled=True,
|
|
|
+ )
|
|
|
+ ).all()
|
|
|
+ if documents:
|
|
|
+ for document in documents:
|
|
|
+ data_source_info = json.loads(document.data_source_info)
|
|
|
+ exist_page_ids.append(data_source_info["notion_page_id"])
|
|
|
+ # get all authorized pages
|
|
|
+ data_source_bindings = session.execute(
|
|
|
+ select(DataSourceOauthBinding).filter_by(
|
|
|
+ tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
|
|
+ )
|
|
|
).all()
|
|
|
- if documents:
|
|
|
- for document in documents:
|
|
|
- data_source_info = json.loads(document.data_source_info)
|
|
|
- exist_page_ids.append(data_source_info["notion_page_id"])
|
|
|
- # get all authorized pages
|
|
|
- data_source_bindings = DataSourceOauthBinding.query.filter_by(
|
|
|
- tenant_id=current_user.current_tenant_id, provider="notion", disabled=False
|
|
|
- ).all()
|
|
|
- if not data_source_bindings:
|
|
|
- return {"notion_info": []}, 200
|
|
|
- pre_import_info_list = []
|
|
|
- for data_source_binding in data_source_bindings:
|
|
|
- source_info = data_source_binding.source_info
|
|
|
- pages = source_info["pages"]
|
|
|
- # Filter out already bound pages
|
|
|
- for page in pages:
|
|
|
- if page["page_id"] in exist_page_ids:
|
|
|
- page["is_bound"] = True
|
|
|
- else:
|
|
|
- page["is_bound"] = False
|
|
|
- pre_import_info = {
|
|
|
- "workspace_name": source_info["workspace_name"],
|
|
|
- "workspace_icon": source_info["workspace_icon"],
|
|
|
- "workspace_id": source_info["workspace_id"],
|
|
|
- "pages": pages,
|
|
|
- }
|
|
|
- pre_import_info_list.append(pre_import_info)
|
|
|
- return {"notion_info": pre_import_info_list}, 200
|
|
|
+ if not data_source_bindings:
|
|
|
+ return {"notion_info": []}, 200
|
|
|
+ pre_import_info_list = []
|
|
|
+ for data_source_binding in data_source_bindings:
|
|
|
+ source_info = data_source_binding.source_info
|
|
|
+ pages = source_info["pages"]
|
|
|
+ # Filter out already bound pages
|
|
|
+ for page in pages:
|
|
|
+ if page["page_id"] in exist_page_ids:
|
|
|
+ page["is_bound"] = True
|
|
|
+ else:
|
|
|
+ page["is_bound"] = False
|
|
|
+ pre_import_info = {
|
|
|
+ "workspace_name": source_info["workspace_name"],
|
|
|
+ "workspace_icon": source_info["workspace_icon"],
|
|
|
+ "workspace_id": source_info["workspace_id"],
|
|
|
+ "pages": pages,
|
|
|
+ }
|
|
|
+ pre_import_info_list.append(pre_import_info)
|
|
|
+ return {"notion_info": pre_import_info_list}, 200
|
|
|
|
|
|
|
|
|
class DataSourceNotionApi(Resource):
|
|
@@ -159,14 +170,17 @@ class DataSourceNotionApi(Resource):
|
|
|
def get(self, workspace_id, page_id, page_type):
|
|
|
workspace_id = str(workspace_id)
|
|
|
page_id = str(page_id)
|
|
|
- data_source_binding = DataSourceOauthBinding.query.filter(
|
|
|
- db.and_(
|
|
|
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
|
|
- DataSourceOauthBinding.provider == "notion",
|
|
|
- DataSourceOauthBinding.disabled == False,
|
|
|
- DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
|
|
- )
|
|
|
- ).first()
|
|
|
+ with Session(db.engine) as session:
|
|
|
+ data_source_binding = session.execute(
|
|
|
+ select(DataSourceOauthBinding).filter(
|
|
|
+ db.and_(
|
|
|
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
|
|
|
+ DataSourceOauthBinding.provider == "notion",
|
|
|
+ DataSourceOauthBinding.disabled == False,
|
|
|
+ DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
|
|
|
+ )
|
|
|
+ )
|
|
|
+ ).scalar_one_or_none()
|
|
|
if not data_source_binding:
|
|
|
raise NotFound("Data source binding not found.")
|
|
|
|