Browse Source

feat: support backwards invocation

Yeuoly 10 months ago
parent
commit
d52476c1c9

+ 32 - 13
api/controllers/inner_api/plugin/plugin.py

@@ -1,10 +1,20 @@
+import time
 from flask_restful import Resource, reqparse
 from flask_restful import Resource, reqparse
 
 
 from controllers.console.setup import setup_required
 from controllers.console.setup import setup_required
 from controllers.inner_api import api
 from controllers.inner_api import api
 from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
 from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
 from controllers.inner_api.wraps import plugin_inner_api_only
 from controllers.inner_api.wraps import plugin_inner_api_only
-from core.plugin.entities.request import RequestInvokeLLM, RequestInvokeModeration, RequestInvokeRerank, RequestInvokeSpeech2Text, RequestInvokeTTS, RequestInvokeTextEmbedding, RequestInvokeTool
+from core.plugin.entities.request import (
+    RequestInvokeLLM,
+    RequestInvokeModeration,
+    RequestInvokeRerank,
+    RequestInvokeSpeech2Text,
+    RequestInvokeTextEmbedding,
+    RequestInvokeTool,
+    RequestInvokeTTS,
+)
+from core.tools.entities.tool_entities import ToolInvokeMessage
 from libs.helper import compact_generate_response
 from libs.helper import compact_generate_response
 from models.account import Tenant
 from models.account import Tenant
 from services.plugin.plugin_invoke_service import PluginInvokeService
 from services.plugin.plugin_invoke_service import PluginInvokeService
@@ -18,6 +28,7 @@ class PluginInvokeLLMApi(Resource):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeLLM):
         pass
         pass
 
 
+
 class PluginInvokeTextEmbeddingApi(Resource):
 class PluginInvokeTextEmbeddingApi(Resource):
     @setup_required
     @setup_required
     @plugin_inner_api_only
     @plugin_inner_api_only
@@ -26,6 +37,7 @@ class PluginInvokeTextEmbeddingApi(Resource):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
         pass
         pass
 
 
+
 class PluginInvokeRerankApi(Resource):
 class PluginInvokeRerankApi(Resource):
     @setup_required
     @setup_required
     @plugin_inner_api_only
     @plugin_inner_api_only
@@ -34,6 +46,7 @@ class PluginInvokeRerankApi(Resource):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeRerank):
         pass
         pass
 
 
+
 class PluginInvokeTTSApi(Resource):
 class PluginInvokeTTSApi(Resource):
     @setup_required
     @setup_required
     @plugin_inner_api_only
     @plugin_inner_api_only
@@ -42,6 +55,7 @@ class PluginInvokeTTSApi(Resource):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTTS):
         pass
         pass
 
 
+
 class PluginInvokeSpeech2TextApi(Resource):
 class PluginInvokeSpeech2TextApi(Resource):
     @setup_required
     @setup_required
     @plugin_inner_api_only
     @plugin_inner_api_only
@@ -50,6 +64,7 @@ class PluginInvokeSpeech2TextApi(Resource):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
         pass
         pass
 
 
+
 class PluginInvokeModerationApi(Resource):
 class PluginInvokeModerationApi(Resource):
     @setup_required
     @setup_required
     @plugin_inner_api_only
     @plugin_inner_api_only
@@ -58,23 +73,27 @@ class PluginInvokeModerationApi(Resource):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration):
     def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeModeration):
         pass
         pass
 
 
+
 class PluginInvokeToolApi(Resource):
 class PluginInvokeToolApi(Resource):
     @setup_required
     @setup_required
     @plugin_inner_api_only
     @plugin_inner_api_only
     @get_tenant
     @get_tenant
     @plugin_data(payload_type=RequestInvokeTool)
     @plugin_data(payload_type=RequestInvokeTool)
-    def post(self, user_id: str, tenant_model: Tenant):
-        parser = reqparse.RequestParser()
-        parser.add_argument('provider', type=dict, required=True, location='json')
-        parser.add_argument('tool', type=dict, required=True, location='json')
-        parser.add_argument('parameters', type=dict, required=True, location='json')
-
-        args = parser.parse_args()
-
-        response = PluginInvokeService.invoke_tool(
-            user_id, tenant_model, args['provider'], args['tool'], args['parameters']
-        )
-        return compact_generate_response(response)
+    def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeTool):
+        def generator():
+            for i in range(10):
+                time.sleep(0.1)
+                yield (
+                    ToolInvokeMessage(
+                        type=ToolInvokeMessage.MessageType.TEXT,
+                        message=ToolInvokeMessage.TextMessage(text='helloworld'),
+                    )
+                    .model_dump_json()
+                    .encode()
+                    + b'\n\n'
+                )
+
+        return compact_generate_response(generator())
 
 
 
 
 class PluginInvokeNodeApi(Resource):
 class PluginInvokeNodeApi(Resource):

+ 2 - 2
api/core/app/features/rate_limiting/rate_limit.py

@@ -1,7 +1,7 @@
 import logging
 import logging
 import time
 import time
 import uuid
 import uuid
-from collections.abc import Generator
+from collections.abc import Callable, Generator
 from datetime import timedelta
 from datetime import timedelta
 from typing import Optional, Union
 from typing import Optional, Union
 
 
@@ -91,7 +91,7 @@ class RateLimit:
 
 
 
 
 class RateLimitGenerator:
 class RateLimitGenerator:
-    def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str):
+    def __init__(self, rate_limit: RateLimit, generator: Union[Generator, Callable[[], Generator]], request_id: str):
         self.rate_limit = rate_limit
         self.rate_limit = rate_limit
         if callable(generator):
         if callable(generator):
             self.generator = generator()
             self.generator = generator()

+ 7 - 1
api/core/tools/entities/tool_entities.py

@@ -90,6 +90,12 @@ class ApiProviderAuthType(Enum):
         raise ValueError(f'invalid mode value {value}')
         raise ValueError(f'invalid mode value {value}')
 
 
 class ToolInvokeMessage(BaseModel):
 class ToolInvokeMessage(BaseModel):
+    class TextMessage(BaseModel):
+        text: str
+
+    class JsonMessage(BaseModel):
+        json_object: dict
+
     class MessageType(Enum):
     class MessageType(Enum):
         TEXT = "text"
         TEXT = "text"
         IMAGE = "image"
         IMAGE = "image"
@@ -103,7 +109,7 @@ class ToolInvokeMessage(BaseModel):
     """
     """
         plain text, image url or link url
         plain text, image url or link url
     """
     """
-    message: Optional[Union[str, bytes, dict]] = None
+    message: JsonMessage | TextMessage
     meta: Optional[dict[str, Any]] = None
     meta: Optional[dict[str, Any]] = None
     save_as: str = ''
     save_as: str = ''
 
 

+ 36 - 36
api/libs/helper.py

@@ -36,8 +36,7 @@ def email(email):
     if re.match(pattern, email) is not None:
     if re.match(pattern, email) is not None:
         return email
         return email
 
 
-    error = ('{email} is not a valid email.'
-             .format(email=email))
+    error = '{email} is not a valid email.'.format(email=email)
     raise ValueError(error)
     raise ValueError(error)
 
 
 
 
@@ -49,10 +48,10 @@ def uuid_value(value):
         uuid_obj = uuid.UUID(value)
         uuid_obj = uuid.UUID(value)
         return str(uuid_obj)
         return str(uuid_obj)
     except ValueError:
     except ValueError:
-        error = ('{value} is not a valid uuid.'
-                 .format(value=value))
+        error = '{value} is not a valid uuid.'.format(value=value)
         raise ValueError(error)
         raise ValueError(error)
 
 
+
 def alphanumeric(value: str):
 def alphanumeric(value: str):
     # check if the value is alphanumeric and underlined
     # check if the value is alphanumeric and underlined
     if re.match(r'^[a-zA-Z0-9_]+$', value):
     if re.match(r'^[a-zA-Z0-9_]+$', value):
@@ -60,6 +59,7 @@ def alphanumeric(value: str):
 
 
     raise ValueError(f'{value} is not a valid alphanumeric value')
     raise ValueError(f'{value} is not a valid alphanumeric value')
 
 
+
 def timestamp_value(timestamp):
 def timestamp_value(timestamp):
     try:
     try:
         int_timestamp = int(timestamp)
         int_timestamp = int(timestamp)
@@ -67,13 +67,12 @@ def timestamp_value(timestamp):
             raise ValueError
             raise ValueError
         return int_timestamp
         return int_timestamp
     except ValueError:
     except ValueError:
-        error = ('{timestamp} is not a valid timestamp.'
-                 .format(timestamp=timestamp))
+        error = '{timestamp} is not a valid timestamp.'.format(timestamp=timestamp)
         raise ValueError(error)
         raise ValueError(error)
 
 
 
 
 class str_len:
 class str_len:
-    """ Restrict input to an integer in a range (inclusive) """
+    """Restrict input to an integer in a range (inclusive)"""
 
 
     def __init__(self, max_length, argument='argument'):
     def __init__(self, max_length, argument='argument'):
         self.max_length = max_length
         self.max_length = max_length
@@ -82,15 +81,17 @@ class str_len:
     def __call__(self, value):
     def __call__(self, value):
         length = len(value)
         length = len(value)
         if length > self.max_length:
         if length > self.max_length:
-            error = ('Invalid {arg}: {val}. {arg} cannot exceed length {length}'
-                     .format(arg=self.argument, val=value, length=self.max_length))
+            error = 'Invalid {arg}: {val}. {arg} cannot exceed length {length}'.format(
+                arg=self.argument, val=value, length=self.max_length
+            )
             raise ValueError(error)
             raise ValueError(error)
 
 
         return value
         return value
 
 
 
 
 class float_range:
 class float_range:
-    """ Restrict input to an float in a range (inclusive) """
+    """Restrict input to an float in a range (inclusive)"""
+
     def __init__(self, low, high, argument='argument'):
     def __init__(self, low, high, argument='argument'):
         self.low = low
         self.low = low
         self.high = high
         self.high = high
@@ -99,8 +100,9 @@ class float_range:
     def __call__(self, value):
     def __call__(self, value):
         value = _get_float(value)
         value = _get_float(value)
         if value < self.low or value > self.high:
         if value < self.low or value > self.high:
-            error = ('Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'
-                     .format(arg=self.argument, val=value, lo=self.low, hi=self.high))
+            error = 'Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}'.format(
+                arg=self.argument, val=value, lo=self.low, hi=self.high
+            )
             raise ValueError(error)
             raise ValueError(error)
 
 
         return value
         return value
@@ -115,8 +117,9 @@ class datetime_string:
         try:
         try:
             datetime.strptime(value, self.format)
             datetime.strptime(value, self.format)
         except ValueError:
         except ValueError:
-            error = ('Invalid {arg}: {val}. {arg} must be conform to the format {format}'
-                     .format(arg=self.argument, val=value, format=self.format))
+            error = 'Invalid {arg}: {val}. {arg} must be conform to the format {format}'.format(
+                arg=self.argument, val=value, format=self.format
+            )
             raise ValueError(error)
             raise ValueError(error)
 
 
         return value
         return value
@@ -128,18 +131,18 @@ def _get_float(value):
     except (TypeError, ValueError):
     except (TypeError, ValueError):
         raise ValueError('{} is not a valid float'.format(value))
         raise ValueError('{} is not a valid float'.format(value))
 
 
+
 def timezone(timezone_string):
 def timezone(timezone_string):
     if timezone_string and timezone_string in available_timezones():
     if timezone_string and timezone_string in available_timezones():
         return timezone_string
         return timezone_string
 
 
-    error = ('{timezone_string} is not a valid timezone.'
-             .format(timezone_string=timezone_string))
+    error = '{timezone_string} is not a valid timezone.'.format(timezone_string=timezone_string)
     raise ValueError(error)
     raise ValueError(error)
 
 
 
 
 def generate_string(n):
 def generate_string(n):
     letters_digits = string.ascii_letters + string.digits
     letters_digits = string.ascii_letters + string.digits
-    result = ""
+    result = ''
     for i in range(n):
     for i in range(n):
         result += random.choice(letters_digits)
         result += random.choice(letters_digits)
 
 
@@ -149,8 +152,8 @@ def generate_string(n):
 def get_remote_ip(request) -> str:
 def get_remote_ip(request) -> str:
     if request.headers.get('CF-Connecting-IP'):
     if request.headers.get('CF-Connecting-IP'):
         return request.headers.get('Cf-Connecting-Ip')
         return request.headers.get('Cf-Connecting-Ip')
-    elif request.headers.getlist("X-Forwarded-For"):
-        return request.headers.getlist("X-Forwarded-For")[0]
+    elif request.headers.getlist('X-Forwarded-For'):
+        return request.headers.getlist('X-Forwarded-For')[0]
     else:
     else:
         return request.remote_addr
         return request.remote_addr
 
 
@@ -160,19 +163,24 @@ def generate_text_hash(text: str) -> str:
     return sha256(hash_text.encode()).hexdigest()
     return sha256(hash_text.encode()).hexdigest()
 
 
 
 
-def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response:
+def compact_generate_response(response: Union[dict, Generator, RateLimitGenerator]) -> Response:
     if isinstance(response, dict):
     if isinstance(response, dict):
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
         return Response(response=json.dumps(response), status=200, mimetype='application/json')
     else:
     else:
+
         def generate() -> Generator:
         def generate() -> Generator:
-            yield from response
+            for data in response:
+                if isinstance(data, dict):
+                    yield json.dumps(data).encode()
+                if isinstance(data, str):
+                    yield data.encode()
+                else:
+                    yield data
 
 
-        return Response(stream_with_context(generate()), status=200,
-                        mimetype='text/event-stream')
+        return Response(stream_with_context(generate()), status=200, mimetype='text/event-stream')
 
 
 
 
 class TokenManager:
 class TokenManager:
-
     @classmethod
     @classmethod
     def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
     def generate_token(cls, account: Account, token_type: str, additional_data: dict = None) -> str:
         old_token = cls._get_current_token_for_account(account.id, token_type)
         old_token = cls._get_current_token_for_account(account.id, token_type)
@@ -182,21 +190,13 @@ class TokenManager:
             cls.revoke_token(old_token, token_type)
             cls.revoke_token(old_token, token_type)
 
 
         token = str(uuid.uuid4())
         token = str(uuid.uuid4())
-        token_data = {
-            'account_id': account.id,
-            'email': account.email,
-            'token_type': token_type
-        }
+        token_data = {'account_id': account.id, 'email': account.email, 'token_type': token_type}
         if additional_data:
         if additional_data:
             token_data.update(additional_data)
             token_data.update(additional_data)
 
 
         expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
         expiry_hours = current_app.config[f'{token_type.upper()}_TOKEN_EXPIRY_HOURS']
         token_key = cls._get_token_key(token, token_type)
         token_key = cls._get_token_key(token, token_type)
-        redis_client.setex(
-            token_key,
-            expiry_hours * 60 * 60,
-            json.dumps(token_data)
-        )
+        redis_client.setex(token_key, expiry_hours * 60 * 60, json.dumps(token_data))
 
 
         cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
         cls._set_current_token_for_account(account.id, token, token_type, expiry_hours)
         return token
         return token
@@ -215,7 +215,7 @@ class TokenManager:
         key = cls._get_token_key(token, token_type)
         key = cls._get_token_key(token, token_type)
         token_data_json = redis_client.get(key)
         token_data_json = redis_client.get(key)
         if token_data_json is None:
         if token_data_json is None:
-            logging.warning(f"{token_type} token {token} not found with key {key}")
+            logging.warning(f'{token_type} token {token} not found with key {key}')
             return None
             return None
         token_data = json.loads(token_data_json)
         token_data = json.loads(token_data_json)
         return token_data
         return token_data
@@ -243,7 +243,7 @@ class RateLimiter:
         self.time_window = time_window
         self.time_window = time_window
 
 
     def _get_key(self, email: str) -> str:
     def _get_key(self, email: str) -> str:
-        return f"{self.prefix}:{email}"
+        return f'{self.prefix}:{email}'
 
 
     def is_rate_limited(self, email: str) -> bool:
     def is_rate_limited(self, email: str) -> bool:
         key = self._get_key(email)
         key = self._get_key(email)