Pārlūkot izejas kodu

fix: optimize DEFAULT-USER

Yeuoly 5 mēneši atpakaļ
vecāks
revīzija
97a3727962
2 mainītis faili ar 31 papildinājumiem un 14 dzēšanām
  1. 30 13
      api/controllers/inner_api/plugin/wraps.py
  2. 1 1
      api/models/model.py

+ 30 - 13
api/controllers/inner_api/plugin/wraps.py

@@ -5,6 +5,7 @@ from typing import Optional
 from flask import request
 from flask import request
 from flask_restful import reqparse
 from flask_restful import reqparse
 from pydantic import BaseModel
 from pydantic import BaseModel
+from sqlalchemy.orm import Session
 
 
 from extensions.ext_database import db
 from extensions.ext_database import db
 from models.account import Account, Tenant
 from models.account import Account, Tenant
@@ -12,19 +13,29 @@ from models.model import EndUser
 from services.account_service import AccountService
 from services.account_service import AccountService
 
 
 
 
-def get_user(user_id: str | None) -> Account | EndUser:
+def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
     try:
     try:
-        if not user_id:
-            user_id = "DEFAULT-USER"
-
-        if user_id == "DEFAULT-USER":
-            user_model = db.session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
-        else:
-            user_model = AccountService.load_user(user_id)
-            if not user_model:
-                user_model = db.session.query(EndUser).filter(EndUser.id == user_id).first()
-            if not user_model:
-                raise ValueError("user not found")
+        with Session(db.engine) as session:
+            if not user_id:
+                user_id = "DEFAULT-USER"
+
+            if user_id == "DEFAULT-USER":
+                user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
+                if not user_model:
+                    user_model = EndUser(
+                        tenant_id=tenant_id,
+                        type="service_api",
+                        is_anonymous=True if user_id == "DEFAULT-USER" else False,
+                        session_id=user_id,
+                    )
+                    session.add(user_model)
+                    session.commit()
+            else:
+                user_model = AccountService.load_user(user_id)
+                if not user_model:
+                    user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
+                if not user_model:
+                    raise ValueError("user not found")
     except Exception:
     except Exception:
         raise ValueError("user not found")
         raise ValueError("user not found")
 
 
@@ -45,6 +56,12 @@ def get_user_tenant(view: Optional[Callable] = None):
             user_id = kwargs.get("user_id")
             user_id = kwargs.get("user_id")
             tenant_id = kwargs.get("tenant_id")
             tenant_id = kwargs.get("tenant_id")
 
 
+            if not tenant_id:
+                raise ValueError("tenant_id is required")
+
+            if not user_id:
+                user_id = "DEFAULT-USER"
+
             del kwargs["tenant_id"]
             del kwargs["tenant_id"]
             del kwargs["user_id"]
             del kwargs["user_id"]
 
 
@@ -63,7 +80,7 @@ def get_user_tenant(view: Optional[Callable] = None):
                 raise ValueError("tenant not found")
                 raise ValueError("tenant not found")
 
 
             kwargs["tenant_model"] = tenant_model
             kwargs["tenant_model"] = tenant_model
-            kwargs["user_model"] = get_user(user_id)
+            kwargs["user_model"] = get_user(tenant_id, user_id)
 
 
             return view_func(*args, **kwargs)
             return view_func(*args, **kwargs)
 
 

+ 1 - 1
api/models/model.py

@@ -1260,7 +1260,7 @@ class OperationLog(Base):
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
     updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 
 
 
 
-class EndUser(UserMixin, Base):
+class EndUser(Base, UserMixin):
     __tablename__ = "end_users"
     __tablename__ = "end_users"
     __table_args__ = (
     __table_args__ = (
         db.PrimaryKeyConstraint("id", name="end_user_pkey"),
         db.PrimaryKeyConstraint("id", name="end_user_pkey"),