wraps.py 2.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. from functools import wraps
  2. from flask import request
  3. from flask_restful import Resource
  4. from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
  5. from controllers.web.error import WebSSOAuthRequiredError
  6. from extensions.ext_database import db
  7. from libs.passport import PassportService
  8. from models.model import App, EndUser, Site
  9. from services.feature_service import FeatureService
  10. def validate_jwt_token(view=None):
  11. def decorator(view):
  12. @wraps(view)
  13. def decorated(*args, **kwargs):
  14. app_model, end_user = decode_jwt_token()
  15. return view(app_model, end_user, *args, **kwargs)
  16. return decorated
  17. if view:
  18. return decorator(view)
  19. return decorator
  20. def decode_jwt_token():
  21. system_features = FeatureService.get_system_features()
  22. try:
  23. auth_header = request.headers.get('Authorization')
  24. if auth_header is None:
  25. raise Unauthorized('Authorization header is missing.')
  26. if ' ' not in auth_header:
  27. raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
  28. auth_scheme, tk = auth_header.split(None, 1)
  29. auth_scheme = auth_scheme.lower()
  30. if auth_scheme != 'bearer':
  31. raise Unauthorized('Invalid Authorization header format. Expected \'Bearer <api-key>\' format.')
  32. decoded = PassportService().verify(tk)
  33. app_code = decoded.get('app_code')
  34. app_model = db.session.query(App).filter(App.id == decoded['app_id']).first()
  35. site = db.session.query(Site).filter(Site.code == app_code).first()
  36. if not app_model:
  37. raise NotFound()
  38. if not app_code or not site:
  39. raise BadRequest('Site URL is no longer valid.')
  40. if app_model.enable_site is False:
  41. raise BadRequest('Site is disabled.')
  42. end_user = db.session.query(EndUser).filter(EndUser.id == decoded['end_user_id']).first()
  43. if not end_user:
  44. raise NotFound()
  45. _validate_web_sso_token(decoded, system_features)
  46. return app_model, end_user
  47. except Unauthorized as e:
  48. if system_features.sso_enforced_for_web:
  49. raise WebSSOAuthRequiredError()
  50. raise Unauthorized(e.description)
  51. def _validate_web_sso_token(decoded, system_features):
  52. # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login
  53. if system_features.sso_enforced_for_web:
  54. source = decoded.get('token_source')
  55. if not source or source != 'sso':
  56. raise WebSSOAuthRequiredError()
  57. # Check if SSO is not enforced for web, and if the token source is SSO, raise an error and redirect to normal passport login
  58. if not system_features.sso_enforced_for_web:
  59. source = decoded.get('token_source')
  60. if source and source == 'sso':
  61. raise Unauthorized('sso token expired.')
  62. class WebApiResource(Resource):
  63. method_decorators = [validate_jwt_token]