123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218 |
- from collections.abc import Sequence
- from enum import Enum, StrEnum
- from typing import Optional
- from pydantic import BaseModel, Field, field_validator
class PromptMessageRole(Enum):
    """
    Role a prompt message plays in a conversation.
    """

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"

    @classmethod
    def value_of(cls, value: str) -> "PromptMessageRole":
        """
        Resolve a raw role string to its enum member.

        :param value: raw role string, e.g. ``"user"``
        :return: the member whose ``value`` equals ``value``
        :raises ValueError: if no member has that value
        """
        found = next((role for role in cls if role.value == value), None)
        if found is None:
            raise ValueError(f"invalid prompt message type value {value}")
        return found
class PromptMessageTool(BaseModel):
    """
    Declaration of a tool that can be offered to a model in a prompt.
    """

    # tool name
    name: str
    # human-readable explanation of what the tool does
    description: str
    # argument specification; presumably a JSON-Schema-like dict — confirm with callers
    parameters: dict
class PromptMessageFunction(BaseModel):
    """
    Wrapper pairing a type tag with a tool declaration.
    """

    # type tag; always "function" unless a caller overrides it
    type: str = "function"
    # the wrapped tool declaration
    function: PromptMessageTool
- class PromptMessageContentType(StrEnum):
- """
- Enum class for prompt message content type.
- """
- TEXT = "text"
- IMAGE = "image"
- AUDIO = "audio"
- VIDEO = "video"
- DOCUMENT = "document"
class PromptMessageContent(BaseModel):
    """
    Base model for a single typed content part of a prompt message.
    """

    # discriminates which kind of content part this is
    type: PromptMessageContentType
class TextPromptMessageContent(PromptMessageContent):
    """
    Plain-text content part.
    """

    type: PromptMessageContentType = PromptMessageContentType.TEXT
    # the text itself
    data: str
class MultiModalPromptMessageContent(PromptMessageContent):
    """
    Base model for file-backed (non-text) content parts.

    The payload is given either by reference (``url``) or inline
    (``base64_data`` together with ``mime_type``).
    """

    type: PromptMessageContentType
    # required: `...` marks the field as having no default
    format: str = Field(..., description="the format of multi-modal file")
    base64_data: str = Field(default="", description="the base64 data of multi-modal file")
    url: str = Field(default="", description="the url of multi-modal file")
    mime_type: str = Field(..., description="the mime type of multi-modal file")

    @property
    def data(self):
        """
        Return the payload as a single string.

        A non-empty ``url`` is returned as-is; otherwise a ``data:`` URI is
        built from ``mime_type`` and ``base64_data``.
        """
        if self.url:
            return self.url
        return f"data:{self.mime_type};base64,{self.base64_data}"
class VideoPromptMessageContent(MultiModalPromptMessageContent):
    """
    Multi-modal content part holding a video.
    """

    type: PromptMessageContentType = PromptMessageContentType.VIDEO
class AudioPromptMessageContent(MultiModalPromptMessageContent):
    """
    Multi-modal content part holding audio.
    """

    type: PromptMessageContentType = PromptMessageContentType.AUDIO
class ImagePromptMessageContent(MultiModalPromptMessageContent):
    """
    Multi-modal content part holding an image.
    """

    class DETAIL(StrEnum):
        """Requested image detail level."""

        LOW = "low"
        HIGH = "high"

    type: PromptMessageContentType = PromptMessageContentType.IMAGE
    # low detail unless the caller asks for more
    detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
    """
    Multi-modal content part holding a document file.
    """

    type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
class PromptMessage(BaseModel):
    """
    Base model shared by all prompt messages, whatever their role.
    """

    role: PromptMessageRole
    # either a plain string or a sequence of typed content parts
    content: Optional[str | Sequence[PromptMessageContent]] = None
    name: Optional[str] = None

    def is_empty(self) -> bool:
        """
        Tell whether the message carries no content.

        :return: True when ``content`` is None, an empty string, or an empty
            sequence; False otherwise
        """
        has_content = bool(self.content)
        return not has_content
class UserPromptMessage(PromptMessage):
    """
    Prompt message authored by the user.
    """

    role: PromptMessageRole = PromptMessageRole.USER
class AssistantPromptMessage(PromptMessage):
    """
    Prompt message produced by the assistant: text content and/or tool calls.
    """

    class ToolCall(BaseModel):
        """
        A single tool invocation requested by the assistant.
        """

        class ToolCallFunction(BaseModel):
            """
            Function name and arguments of a tool call.
            """

            name: str
            # serialized arguments; presumably a JSON string — confirm with providers
            arguments: str

        id: str
        type: str
        function: ToolCallFunction

        @field_validator("id", mode="before")
        @classmethod
        def transform_id_to_str(cls, value) -> str:
            """
            Coerce a non-string id (such as an int) to ``str`` before the
            ``id: str`` field is validated.
            """
            if not isinstance(value, str):
                return str(value)
            return value

    role: PromptMessageRole = PromptMessageRole.ASSISTANT
    # pydantic copies mutable defaults per instance, so this [] is not shared
    tool_calls: list[ToolCall] = []

    def is_empty(self) -> bool:
        """
        Check if prompt message is empty.

        An assistant message is empty only when it has no content AND no
        tool calls; either one alone makes it non-empty.

        :return: True if prompt message is empty, False otherwise
        """
        # The former `if not A and not B: return False; return True` reduced
        # to `content_empty or bool(tool_calls)`, so any message with tool
        # calls was reported empty even when it also had text content.
        return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):
    """
    Prompt message carrying system-level instructions.
    """

    role: PromptMessageRole = PromptMessageRole.SYSTEM
class ToolPromptMessage(PromptMessage):
    """
    Prompt message returning the result of a tool call to the model.
    """

    role: PromptMessageRole = PromptMessageRole.TOOL
    # presumably the ``id`` of the AssistantPromptMessage.ToolCall being answered
    tool_call_id: str

    def is_empty(self) -> bool:
        """
        Check if prompt message is empty.

        A tool message is empty only when it has neither content nor a
        ``tool_call_id``.

        :return: True if prompt message is empty, False otherwise
        """
        # The former `if not A and not B: return False; return True` reduced
        # to `content_empty or bool(tool_call_id)`; since tool_call_id is a
        # required string, practically every tool message was reported empty.
        return super().is_empty() and not self.tool_call_id
|