message_entities.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. from abc import ABC
  2. from collections.abc import Sequence
  3. from enum import Enum, StrEnum
  4. from typing import Literal, Optional
  5. from pydantic import BaseModel, Field, field_validator
  6. class PromptMessageRole(Enum):
  7. """
  8. Enum class for prompt message.
  9. """
  10. SYSTEM = "system"
  11. USER = "user"
  12. ASSISTANT = "assistant"
  13. TOOL = "tool"
  14. @classmethod
  15. def value_of(cls, value: str) -> "PromptMessageRole":
  16. """
  17. Get value of given mode.
  18. :param value: mode value
  19. :return: mode
  20. """
  21. for mode in cls:
  22. if mode.value == value:
  23. return mode
  24. raise ValueError(f"invalid prompt message type value {value}")
  25. class PromptMessageTool(BaseModel):
  26. """
  27. Model class for prompt message tool.
  28. """
  29. name: str
  30. description: str
  31. parameters: dict
  32. class PromptMessageFunction(BaseModel):
  33. """
  34. Model class for prompt message function.
  35. """
  36. type: str = "function"
  37. function: PromptMessageTool
  38. class PromptMessageContentType(StrEnum):
  39. """
  40. Enum class for prompt message content type.
  41. """
  42. TEXT = "text"
  43. IMAGE = "image"
  44. AUDIO = "audio"
  45. VIDEO = "video"
  46. DOCUMENT = "document"
  47. class PromptMessageContent(BaseModel):
  48. """
  49. Model class for prompt message content.
  50. """
  51. type: PromptMessageContentType
  52. data: str
  53. class TextPromptMessageContent(PromptMessageContent):
  54. """
  55. Model class for text prompt message content.
  56. """
  57. type: PromptMessageContentType = PromptMessageContentType.TEXT
  58. class VideoPromptMessageContent(PromptMessageContent):
  59. type: PromptMessageContentType = PromptMessageContentType.VIDEO
  60. data: str = Field(..., description="Base64 encoded video data")
  61. format: str = Field(..., description="Video format")
  62. class AudioPromptMessageContent(PromptMessageContent):
  63. type: PromptMessageContentType = PromptMessageContentType.AUDIO
  64. data: str = Field(..., description="Base64 encoded audio data")
  65. format: str = Field(..., description="Audio format")
  66. class ImagePromptMessageContent(PromptMessageContent):
  67. """
  68. Model class for image prompt message content.
  69. """
  70. class DETAIL(StrEnum):
  71. LOW = "low"
  72. HIGH = "high"
  73. type: PromptMessageContentType = PromptMessageContentType.IMAGE
  74. detail: DETAIL = DETAIL.LOW
  75. format: str = Field("jpg", description="Image format")
  76. class DocumentPromptMessageContent(PromptMessageContent):
  77. type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
  78. encode_format: Literal["base64"]
  79. data: str
  80. format: str = Field(..., description="Document format")
  81. class PromptMessage(ABC, BaseModel):
  82. """
  83. Model class for prompt message.
  84. """
  85. role: PromptMessageRole
  86. content: Optional[str | Sequence[PromptMessageContent]] = None
  87. name: Optional[str] = None
  88. def is_empty(self) -> bool:
  89. """
  90. Check if prompt message is empty.
  91. :return: True if prompt message is empty, False otherwise
  92. """
  93. return not self.content
  94. class UserPromptMessage(PromptMessage):
  95. """
  96. Model class for user prompt message.
  97. """
  98. role: PromptMessageRole = PromptMessageRole.USER
  99. class AssistantPromptMessage(PromptMessage):
  100. """
  101. Model class for assistant prompt message.
  102. """
  103. class ToolCall(BaseModel):
  104. """
  105. Model class for assistant prompt message tool call.
  106. """
  107. class ToolCallFunction(BaseModel):
  108. """
  109. Model class for assistant prompt message tool call function.
  110. """
  111. name: str
  112. arguments: str
  113. id: str
  114. type: str
  115. function: ToolCallFunction
  116. @field_validator("id", mode="before")
  117. @classmethod
  118. def transform_id_to_str(cls, value) -> str:
  119. if not isinstance(value, str):
  120. return str(value)
  121. else:
  122. return value
  123. role: PromptMessageRole = PromptMessageRole.ASSISTANT
  124. tool_calls: list[ToolCall] = []
  125. def is_empty(self) -> bool:
  126. """
  127. Check if prompt message is empty.
  128. :return: True if prompt message is empty, False otherwise
  129. """
  130. if not super().is_empty() and not self.tool_calls:
  131. return False
  132. return True
  133. class SystemPromptMessage(PromptMessage):
  134. """
  135. Model class for system prompt message.
  136. """
  137. role: PromptMessageRole = PromptMessageRole.SYSTEM
  138. class ToolPromptMessage(PromptMessage):
  139. """
  140. Model class for tool prompt message.
  141. """
  142. role: PromptMessageRole = PromptMessageRole.TOOL
  143. tool_call_id: str
  144. def is_empty(self) -> bool:
  145. """
  146. Check if prompt message is empty.
  147. :return: True if prompt message is empty, False otherwise
  148. """
  149. if not super().is_empty() and not self.tool_call_id:
  150. return False
  151. return True