|
@@ -1,7 +1,8 @@
|
|
|
+import base64
|
|
|
from enum import Enum
|
|
|
from typing import Any, Optional, Union, cast
|
|
|
|
|
|
-from pydantic import BaseModel, Field, field_validator
|
|
|
+from pydantic import BaseModel, Field, field_serializer, field_validator
|
|
|
|
|
|
from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
|
|
|
from core.tools.entities.common_entities import I18nObject
|
|
@@ -100,6 +101,26 @@ class ToolInvokeMessage(BaseModel):
|
|
|
class BlobMessage(BaseModel):
|
|
|
blob: bytes
|
|
|
|
|
|
+ class VariableMessage(BaseModel):
|
|
|
+ variable_name: str = Field(..., description="The name of the variable")
|
|
|
+ variable_value: str = Field(..., description="The value of the variable")
|
|
|
+ stream: bool = Field(default=False, description="Whether the variable is streamed")
|
|
|
+
|
|
|
+ @field_validator("variable_value", mode="before")
|
|
|
+ def transform_variable_value(cls, value, values) -> Any:
|
|
|
+ """
|
|
|
+ Only basic types and lists are allowed.
|
|
|
+ """
|
|
|
+ if not isinstance(value, dict | list | str | int | float | bool):
|
|
|
+ raise ValueError("Only basic types and lists are allowed.")
|
|
|
+
|
|
|
+ # if stream is true, the value must be a string
|
|
|
+ if values.get('stream'):
|
|
|
+ if not isinstance(value, str):
|
|
|
+ raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
|
|
+
|
|
|
+ return value
|
|
|
+
|
|
|
class MessageType(Enum):
|
|
|
TEXT = "text"
|
|
|
IMAGE = "image"
|
|
@@ -108,15 +129,34 @@ class ToolInvokeMessage(BaseModel):
|
|
|
JSON = "json"
|
|
|
IMAGE_LINK = "image_link"
|
|
|
FILE_VAR = "file_var"
|
|
|
+ VARIABLE = "variable"
|
|
|
|
|
|
type: MessageType = MessageType.TEXT
|
|
|
"""
|
|
|
plain text, image url or link url
|
|
|
"""
|
|
|
- message: JsonMessage | TextMessage | BlobMessage | None
|
|
|
+ message: JsonMessage | TextMessage | BlobMessage | VariableMessage | None
|
|
|
meta: dict[str, Any] | None = None
|
|
|
save_as: str = ''
|
|
|
|
|
|
+ @field_validator('message', mode='before')
|
|
|
+ @classmethod
|
|
|
+ def decode_blob_message(cls, v):
|
|
|
+ if isinstance(v, dict) and 'blob' in v:
|
|
|
+ try:
|
|
|
+ v['blob'] = base64.b64decode(v['blob'])
|
|
|
+ except Exception:
|
|
|
+ pass
|
|
|
+ return v
|
|
|
+
|
|
|
+ @field_serializer('message')
|
|
|
+ def serialize_message(self, v):
|
|
|
+ if isinstance(v, self.BlobMessage):
|
|
|
+ return {
|
|
|
+ 'blob': base64.b64encode(v.blob).decode('utf-8')
|
|
|
+ }
|
|
|
+ return v
|
|
|
+
|
|
|
class ToolInvokeMessageBinary(BaseModel):
|
|
|
mimetype: str = Field(..., description="The mimetype of the binary")
|
|
|
url: str = Field(..., description="The url of the binary")
|