|
@@ -5,6 +5,7 @@ from typing import Optional
|
|
from flask import request
|
|
from flask import request
|
|
from flask_restful import reqparse
|
|
from flask_restful import reqparse
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
|
|
+from sqlalchemy.orm import Session
|
|
|
|
|
|
from extensions.ext_database import db
|
|
from extensions.ext_database import db
|
|
from models.account import Account, Tenant
|
|
from models.account import Account, Tenant
|
|
@@ -12,19 +13,29 @@ from models.model import EndUser
|
|
from services.account_service import AccountService
|
|
from services.account_service import AccountService
|
|
|
|
|
|
|
|
|
|
-def get_user(user_id: str | None) -> Account | EndUser:
|
|
|
|
|
|
+def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
|
|
try:
|
|
try:
|
|
- if not user_id:
|
|
|
|
- user_id = "DEFAULT-USER"
|
|
|
|
-
|
|
|
|
- if user_id == "DEFAULT-USER":
|
|
|
|
- user_model = db.session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
|
|
|
|
- else:
|
|
|
|
- user_model = AccountService.load_user(user_id)
|
|
|
|
- if not user_model:
|
|
|
|
- user_model = db.session.query(EndUser).filter(EndUser.id == user_id).first()
|
|
|
|
- if not user_model:
|
|
|
|
- raise ValueError("user not found")
|
|
|
|
|
|
+ with Session(db.engine) as session:
|
|
|
|
+ if not user_id:
|
|
|
|
+ user_id = "DEFAULT-USER"
|
|
|
|
+
|
|
|
|
+ if user_id == "DEFAULT-USER":
|
|
|
|
+ user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
|
|
|
|
+ if not user_model:
|
|
|
|
+ user_model = EndUser(
|
|
|
|
+ tenant_id=tenant_id,
|
|
|
|
+ type="service_api",
|
|
|
|
+ is_anonymous=True if user_id == "DEFAULT-USER" else False,
|
|
|
|
+ session_id=user_id,
|
|
|
|
+ )
|
|
|
|
+ session.add(user_model)
|
|
|
|
+ session.commit()
|
|
|
|
+ else:
|
|
|
|
+ user_model = AccountService.load_user(user_id)
|
|
|
|
+ if not user_model:
|
|
|
|
+ user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
|
|
|
|
+ if not user_model:
|
|
|
|
+ raise ValueError("user not found")
|
|
except Exception:
|
|
except Exception:
|
|
raise ValueError("user not found")
|
|
raise ValueError("user not found")
|
|
|
|
|
|
@@ -45,6 +56,12 @@ def get_user_tenant(view: Optional[Callable] = None):
|
|
user_id = kwargs.get("user_id")
|
|
user_id = kwargs.get("user_id")
|
|
tenant_id = kwargs.get("tenant_id")
|
|
tenant_id = kwargs.get("tenant_id")
|
|
|
|
|
|
|
|
+ if not tenant_id:
|
|
|
|
+ raise ValueError("tenant_id is required")
|
|
|
|
+
|
|
|
|
+ if not user_id:
|
|
|
|
+ user_id = "DEFAULT-USER"
|
|
|
|
+
|
|
del kwargs["tenant_id"]
|
|
del kwargs["tenant_id"]
|
|
del kwargs["user_id"]
|
|
del kwargs["user_id"]
|
|
|
|
|
|
@@ -63,7 +80,7 @@ def get_user_tenant(view: Optional[Callable] = None):
|
|
raise ValueError("tenant not found")
|
|
raise ValueError("tenant not found")
|
|
|
|
|
|
kwargs["tenant_model"] = tenant_model
|
|
kwargs["tenant_model"] = tenant_model
|
|
- kwargs["user_model"] = get_user(user_id)
|
|
|
|
|
|
+ kwargs["user_model"] = get_user(tenant_id, user_id)
|
|
|
|
|
|
return view_func(*args, **kwargs)
|
|
return view_func(*args, **kwargs)
|
|
|
|
|