|
@@ -36,8 +36,7 @@ def email(email):
|
|
if re.match(pattern, email) is not None:
|
|
if re.match(pattern, email) is not None:
|
|
return email
|
|
return email
|
|
|
|
|
|
- error = ('{email} is not a valid email.'
|
|
|
|
- .format(email=email))
|
|
|
|
|
|
+ error = '{email} is not a valid email.'.format(email=email)
|
|
raise ValueError(error)
|
|
raise ValueError(error)
|
|
|
|
|
|
|
|
|
|
@@ -49,10 +48,10 @@ def uuid_value(value):
|
|
uuid_obj = uuid.UUID(value)
|
|
uuid_obj = uuid.UUID(value)
|
|
return str(uuid_obj)
|
|
return str(uuid_obj)
|
|
except ValueError:
|
|
except ValueError:
|
|
- error = ('{value} is not a valid uuid.'
|
|
|
|
- .format(value=value))
|
|
|
|
|
|
+ error = '{value} is not a valid uuid.'.format(value=value)
|
|
raise ValueError(error)
|
|
raise ValueError(error)
|
|
|
|
|
|
|
|
+
|
|
def alphanumeric(value: str):
|
|
def alphanumeric(value: str):
|
|
# check if the value is alphanumeric and underlined
|
|
# check if the value is alphanumeric and underlined
|
|
if re.match(r'^[a-zA-Z0-9_]+$', value):
|
|
if re.match(r'^[a-zA-Z0-9_]+$', value):
|
|
@@ -60,6 +59,7 @@ def alphanumeric(value: str):
|
|
|
|
|
|
raise ValueError(f'{value} is not a valid alphanumeric value')
|
|
raise ValueError(f'{value} is not a valid alphanumeric value')
|
|
|
|
|
|
|
|
+
|
|
def timestamp_value(timestamp):
|
|
def timestamp_value(timestamp):
|
|
try:
|
|
try:
|
|
int_timestamp = int(timestamp)
|
|
int_timestamp = int(timestamp)
|
|
@@ -67,13 +67,12 @@ def timestamp_value(timestamp):
|
|
raise ValueError
|
|
raise ValueError
|
|
return int_timestamp
|
|
return int_timestamp
|
|
except ValueError:
|
|
except ValueError:
|
|
- error = ('{timestamp} is not a valid timestamp.'
|
|
|
|
- .format(timestamp=timestamp))
|
|
|
|
|
|
+ error = '{timestamp} is not a valid timestamp.'.format(timestamp=timestamp)
|
|
raise ValueError(error)
|
|
raise ValueError(error)
|
|
|
|
|
|
|
|
|
|
class str_len:
|
|
class str_len:
|
|
- """ Restrict input to an integer in a range (inclusive) """
|
|
|
|
|
|
+ """Restrict input to an integer in a range (inclusive)"""
|
|
|
|
|
|
def __init__(self, max_length, argument='argument'):
|
|
def __init__(self, max_length, argument='argument'):
|
|
self.max_length = max_length
|
|
self.max_length = max_length
|
|
@@ -82,15 +81,17 @@ class str_len:
|
|
def __call__(self, value):
|
|
def __call__(self, value):
|
|
length = len(value)
|
|
length = len(value)
|
|
if length > self.max_length:
|
|
if length > self.max_length:
|
|
- error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}'
|
|
|
|
- .format(arg=self.argument, val=value, length=self.max_length))
|
|
|
|
|
|
+ error = 'Invalid {arg}: {val}. {arg} cannot exceed length {length}'.format(
|
|
|
|
+ arg=self.argument, val=value, length=self.max_length
|
|
|
|
+ )
|
|
raise ValueError(error)
|
|
raise ValueError(error)
|
|
|
|
|
|
return value
|
|
return value
|
|
|
|
|
|
|
|
|
|
class float_range:
|
|
class float_range:
|
|
- """ Restrict input to an float in a range (inclusive) """
|
|
|
|
|
|
+ """Restrict input to an float in a range (inclusive)"""
|
|
|
|
+
|
|
def __init__(self, low, high, argument='argument'):
|
|
def __init__(self, low, high, argument='argument'):
|
|
self.low = low
|
|
self.low = low
|
|
self.high = high
|
|
self.high = high
|
|
@@ -99,8 +100,9 @@ class float_range:
|
|
def __call__(self, value):
|
|
def __call__(self, value):
|
|
value = _get_float(value)
|
|
value = _get_float(value)
|
|
if value < self.low or value > self.high:
|
|
if value < self.low or value > self.high:
|
|
- error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
|
|
|
|
- .format(arg=self.argument, val=value, lo=self.low, hi=self.high))
|
|
|
|
|
|
+ error = 'Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'.format(
|
|
|
|
+ arg=self.argument, val=value, lo=self.low, hi=self.high
|
|
|
|
+ )
|
|
raise ValueError(error)
|
|
raise ValueError(error)
|
|
|
|
|
|
return value
|
|
return value
|
|
@@ -115,8 +117,9 @@ class datetime_string:
|
|
try:
|
|
try:
|
|
datetime.strptime(value, self.format)
|
|
datetime.strptime(value, self.format)
|
|
except ValueError:
|
|
except ValueError:
|
|
- error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}'
|
|
|
|
- .format(arg=self.argument, val=value, format=self.format))
|
|
|
|
|
|
+ error = 'Invalid {arg}: {val}. {arg} must be conform to the format {format}'.format(
|
|
|
|
+ arg=self.argument, val=value, format=self.format
|
|
|
|
+ )
|
|
raise ValueError(error)
|
|
raise ValueError(error)
|
|
|
|
|
|
return value
|
|
return value
|
|
@@ -128,18 +131,18 @@ def _get_float(value):
|
|
except (TypeError, ValueError):
|
|
except (TypeError, ValueError):
|
|
raise ValueError('{} is not a valid float'.format(value))
|
|
raise ValueError('{} is not a valid float'.format(value))
|
|
|
|
|
|
|
|
+
|
|
def timezone(timezone_string):
|
|
def timezone(timezone_string):
|
|
if timezone_string and timezone_string in available_timezones():
|
|
if timezone_string and timezone_string in available_timezones():
|
|
return timezone_string
|
|
return timezone_string
|
|
|
|
|
|
- error = ('{timezone_string} is not a valid timezone.'
|
|
|
|
- .format(timezone_string=timezone_string))
|
|
|
|
|
|
+ error = '{timezone_string} is not a valid timezone.'.format(timezone_string=timezone_string)
|
|
raise ValueError(error)
|
|
raise ValueError(error)
|
|
|
|
|
|
|
|
|
|
def generate_string(n):
|
|
def generate_string(n):
|
|
letters_digits = string.ascii_letters + string.digits
|
|
letters_digits = string.ascii_letters + string.digits
|
|
- result = ""
|
|
|
|
|
|
+ result = ''
|
|
for i in range(n):
|
|
for i in range(n):
|
|
result += random.choice(letters_digits)
|
|
result += random.choice(letters_digits)
|
|
|
|
|
|
@@ -149,8 +152,8 @@ def generate_string(n):
|
|
def get_remote_ip(request) -> str:
|
|
def get_remote_ip(request) -> str:
|
|
if request.headers.get('CF-Connecting-IP'):
|
|
if request.headers.get('CF-Connecting-IP'):
|
|
return request.headers.get('Cf-Connecting-Ip')
|
|
return request.headers.get('Cf-Connecting-Ip')
|
|
- elif request.headers.getlist("X-Forwarded-For"):
|
|
|
|
- return request.headers.getlist("X-Forwarded-For")[0]
|
|
|
|
|
|
+ elif request.headers.getlist('X-Forwarded-For'):
|
|
|
|
+ return request.headers.getlist('X-Forwarded-For')[0]
|
|
else:
|
|
else:
|
|
return request.remote_addr
|
|
return request.remote_addr
|
|
|
|
|
|
@@ -160,19 +163,24 @@ def generate_text_hash(text: str) -> str:
|
|
return sha256(hash_text.encode()).hexdigest()
|
|
return sha256(hash_text.encode()).hexdigest()
|
|
|
|
|
|
|
|
|
|
-def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response:
|
|
|
|
|
|
+def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
|
|
if isinstance(response, dict):
|
|
if isinstance(response, dict):
|
|
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
|
return Response(response=json.dumps(response), status=200, mimetype='application/json')
|
|
else:
|
|
else:
|
|
|
|
+
|
|
def generate() -> Generator:
|
|
def generate() -> Generator:
|
|
- yield from response
|
|
|
|
|
|
+ for data in response:
|
|
|
|
+ if isinstance(data, dict):
|
|
|
|
+ yield json.dumps(data).encode()
|
|
|
|
+ if isinstance(data, str):
|
|
|
|
+ yield data.encode()
|
|
|
|
+ else:
|
|
|
|
+ yield data
|
|
|
|
|
|
- return Response(stream_with_context(generate()), status=200,
|
|
|
|
- mimetype='text/event-stream')
|
|
|
|
|
|
+ return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
|
|
|
|
|
|
|
|
|
|
class TokenManager:
|
|
class TokenManager:
|
|
-
|
|
|
|
@classmethod
|
|
@classmethod
|
|
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
|
|
def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
|
|
old_token = cls._get_current_token_for_account(account.id, token_type)
|
|
old_token = cls._get_current_token_for_account(account.id, token_type)
|
|
@@ -182,21 +190,13 @@ class TokenManager:
|
|
cls.revoke_token(old_token, token_type)
|
|
cls.revoke_token(old_token, token_type)
|
|
|
|
|
|
token = str(uuid.uuid4())
|
|
token = str(uuid.uuid4())
|
|
- token_data = {
|
|
|
|
- 'account_id': account.id,
|
|
|
|
- 'email': account.email,
|
|
|
|
- 'token_type': token_type
|
|
|
|
- }
|
|
|
|
|
|
+ token_data = {'account_id': account.id, 'email': account.email, 'token_type': token_type}
|
|
if additional_data:
|
|
if additional_data:
|
|
token_data.update(additional_data)
|
|
token_data.update(additional_data)
|
|
|
|
|
|
expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
|
|
expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
|
|
token_key = cls._get_token_key(token, token_type)
|
|
token_key = cls._get_token_key(token, token_type)
|
|
- redis_client.setex(
|
|
|
|
- token_key,
|
|
|
|
- expiry_hours * 60 * 60,
|
|
|
|
- json.dumps(token_data)
|
|
|
|
- )
|
|
|
|
|
|
+ redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
|
|
|
|
|
|
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
|
|
cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
|
|
return token
|
|
return token
|
|
@@ -215,7 +215,7 @@ class TokenManager:
|
|
key = cls._get_token_key(token, token_type)
|
|
key = cls._get_token_key(token, token_type)
|
|
token_data_json = redis_client.get(key)
|
|
token_data_json = redis_client.get(key)
|
|
if token_data_json is None:
|
|
if token_data_json is None:
|
|
- logging.warning(f"{token_type} token {token} not found with key {key}")
|
|
|
|
|
|
+ logging.warning(f'{token_type} token {token} not found with key {key}')
|
|
return None
|
|
return None
|
|
token_data = json.loads(token_data_json)
|
|
token_data = json.loads(token_data_json)
|
|
return token_data
|
|
return token_data
|
|
@@ -243,7 +243,7 @@ class RateLimiter:
|
|
self.time_window = time_window
|
|
self.time_window = time_window
|
|
|
|
|
|
def _get_key(self, email: str) -> str:
|
|
def _get_key(self, email: str) -> str:
|
|
- return f"{self.prefix}:{email}"
|
|
|
|
|
|
+ return f'{self.prefix}:{email}'
|
|
|
|
|
|
def is_rate_limited(self, email: str) -> bool:
|
|
def is_rate_limited(self, email: str) -> bool:
|
|
key = self._get_key(email)
|
|
key = self._get_key(email)
|