|
@@ -27,6 +27,9 @@ class RateLimit:
|
|
|
|
|
|
def __init__(self, client_id: str, max_active_requests: int):
|
|
|
self.max_active_requests = max_active_requests
|
|
|
+ # must be called after max_active_requests is set
|
|
|
+ if self.disabled():
|
|
|
+ return
|
|
|
if hasattr(self, "initialized"):
|
|
|
return
|
|
|
self.initialized = True
|
|
@@ -37,6 +40,8 @@ class RateLimit:
|
|
|
self.flush_cache(use_local_value=True)
|
|
|
|
|
|
def flush_cache(self, use_local_value=False):
|
|
|
+ if self.disabled():
|
|
|
+ return
|
|
|
self.last_recalculate_time = time.time()
|
|
|
# flush max active requests
|
|
|
if use_local_value or not redis_client.exists(self.max_active_requests_key):
|
|
@@ -59,18 +64,18 @@ class RateLimit:
|
|
|
redis_client.hdel(self.active_requests_key, *timeout_requests)
|
|
|
|
|
|
def enter(self, request_id: Optional[str] = None) -> str:
|
|
|
+ if self.disabled():
|
|
|
+ return RateLimit._UNLIMITED_REQUEST_ID
|
|
|
if time.time() - self.last_recalculate_time > RateLimit._ACTIVE_REQUESTS_COUNT_FLUSH_INTERVAL:
|
|
|
self.flush_cache()
|
|
|
- if self.max_active_requests <= 0:
|
|
|
- return RateLimit._UNLIMITED_REQUEST_ID
|
|
|
if not request_id:
|
|
|
request_id = RateLimit.gen_request_key()
|
|
|
|
|
|
active_requests_count = redis_client.hlen(self.active_requests_key)
|
|
|
if active_requests_count >= self.max_active_requests:
|
|
|
raise AppInvokeQuotaExceededError(
|
|
|
- "Too many requests. Please try again later. The current maximum "
|
|
|
- "concurrent requests allowed is {}.".format(self.max_active_requests)
|
|
|
+ f"Too many requests. Please try again later. The current maximum concurrent requests allowed "
|
|
|
+ f"for {self.client_id} is {self.max_active_requests}."
|
|
|
)
|
|
|
redis_client.hset(self.active_requests_key, request_id, str(time.time()))
|
|
|
return request_id
|
|
@@ -80,6 +85,9 @@ class RateLimit:
|
|
|
return
|
|
|
redis_client.hdel(self.active_requests_key, request_id)
|
|
|
|
|
|
+ def disabled(self):
|
|
|
+ return self.max_active_requests <= 0
|
|
|
+
|
|
|
@staticmethod
|
|
|
def gen_request_key() -> str:
|
|
|
return str(uuid.uuid4())
|