wraps.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. from collections.abc import Callable
  2. from functools import wraps
  3. from typing import Optional
  4. from flask import request
  5. from flask_restful import reqparse
  6. from pydantic import BaseModel
  7. from extensions.ext_database import db
  8. from models.account import Account, Tenant
  9. from models.model import EndUser
  10. from services.account_service import AccountService
  11. def get_user(user_id: str | None) -> Account | EndUser:
  12. try:
  13. if not user_id:
  14. user_id = "DEFAULT-USER"
  15. if user_id == "DEFAULT-USER":
  16. user_model = db.session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
  17. else:
  18. user_model = AccountService.load_user(user_id)
  19. if not user_model:
  20. user_model = db.session.query(EndUser).filter(EndUser.id == user_id).first()
  21. if not user_model:
  22. raise ValueError("user not found")
  23. except Exception:
  24. raise ValueError("user not found")
  25. return user_model
  26. def get_user_tenant(view: Optional[Callable] = None):
  27. def decorator(view_func):
  28. @wraps(view_func)
  29. def decorated_view(*args, **kwargs):
  30. # fetch json body
  31. parser = reqparse.RequestParser()
  32. parser.add_argument("tenant_id", type=str, required=True, location="json")
  33. parser.add_argument("user_id", type=str, required=True, location="json")
  34. kwargs = parser.parse_args()
  35. user_id = kwargs.get("user_id")
  36. tenant_id = kwargs.get("tenant_id")
  37. del kwargs["tenant_id"]
  38. del kwargs["user_id"]
  39. try:
  40. tenant_model = (
  41. db.session.query(Tenant)
  42. .filter(
  43. Tenant.id == tenant_id,
  44. )
  45. .first()
  46. )
  47. except Exception:
  48. raise ValueError("tenant not found")
  49. if not tenant_model:
  50. raise ValueError("tenant not found")
  51. kwargs["tenant_model"] = tenant_model
  52. kwargs["user_model"] = get_user(user_id)
  53. return view_func(*args, **kwargs)
  54. return decorated_view
  55. if view is None:
  56. return decorator
  57. else:
  58. return decorator(view)
  59. def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
  60. def decorator(view_func):
  61. def decorated_view(*args, **kwargs):
  62. try:
  63. data = request.get_json()
  64. except Exception:
  65. raise ValueError("invalid json")
  66. try:
  67. payload = payload_type(**data)
  68. except Exception as e:
  69. raise ValueError(f"invalid payload: {str(e)}")
  70. kwargs["payload"] = payload
  71. return view_func(*args, **kwargs)
  72. return decorated_view
  73. if view is None:
  74. return decorator
  75. else:
  76. return decorator(view)