model_entities.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. from decimal import Decimal
  2. from enum import Enum, StrEnum
  3. from typing import Any, Optional
  4. from pydantic import BaseModel, ConfigDict
  5. from core.model_runtime.entities.common_entities import I18nObject
  6. class ModelType(Enum):
  7. """
  8. Enum class for model type.
  9. """
  10. LLM = "llm"
  11. TEXT_EMBEDDING = "text-embedding"
  12. RERANK = "rerank"
  13. SPEECH2TEXT = "speech2text"
  14. MODERATION = "moderation"
  15. TTS = "tts"
  16. @classmethod
  17. def value_of(cls, origin_model_type: str) -> "ModelType":
  18. """
  19. Get model type from origin model type.
  20. :return: model type
  21. """
  22. if origin_model_type in {"text-generation", cls.LLM.value}:
  23. return cls.LLM
  24. elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING.value}:
  25. return cls.TEXT_EMBEDDING
  26. elif origin_model_type in {"reranking", cls.RERANK.value}:
  27. return cls.RERANK
  28. elif origin_model_type in {"speech2text", cls.SPEECH2TEXT.value}:
  29. return cls.SPEECH2TEXT
  30. elif origin_model_type in {"tts", cls.TTS.value}:
  31. return cls.TTS
  32. elif origin_model_type == cls.MODERATION.value:
  33. return cls.MODERATION
  34. else:
  35. raise ValueError(f"invalid origin model type {origin_model_type}")
  36. def to_origin_model_type(self) -> str:
  37. """
  38. Get origin model type from model type.
  39. :return: origin model type
  40. """
  41. if self == self.LLM:
  42. return "text-generation"
  43. elif self == self.TEXT_EMBEDDING:
  44. return "embeddings"
  45. elif self == self.RERANK:
  46. return "reranking"
  47. elif self == self.SPEECH2TEXT:
  48. return "speech2text"
  49. elif self == self.TTS:
  50. return "tts"
  51. elif self == self.MODERATION:
  52. return "moderation"
  53. else:
  54. raise ValueError(f"invalid model type {self}")
  55. class FetchFrom(Enum):
  56. """
  57. Enum class for fetch from.
  58. """
  59. PREDEFINED_MODEL = "predefined-model"
  60. CUSTOMIZABLE_MODEL = "customizable-model"
  61. class ModelFeature(Enum):
  62. """
  63. Enum class for llm feature.
  64. """
  65. TOOL_CALL = "tool-call"
  66. MULTI_TOOL_CALL = "multi-tool-call"
  67. AGENT_THOUGHT = "agent-thought"
  68. VISION = "vision"
  69. STREAM_TOOL_CALL = "stream-tool-call"
  70. DOCUMENT = "document"
  71. VIDEO = "video"
  72. AUDIO = "audio"
  73. class DefaultParameterName(StrEnum):
  74. """
  75. Enum class for parameter template variable.
  76. """
  77. TEMPERATURE = "temperature"
  78. TOP_P = "top_p"
  79. TOP_K = "top_k"
  80. PRESENCE_PENALTY = "presence_penalty"
  81. FREQUENCY_PENALTY = "frequency_penalty"
  82. MAX_TOKENS = "max_tokens"
  83. RESPONSE_FORMAT = "response_format"
  84. JSON_SCHEMA = "json_schema"
  85. @classmethod
  86. def value_of(cls, value: Any) -> "DefaultParameterName":
  87. """
  88. Get parameter name from value.
  89. :param value: parameter value
  90. :return: parameter name
  91. """
  92. for name in cls:
  93. if name.value == value:
  94. return name
  95. raise ValueError(f"invalid parameter name {value}")
  96. class ParameterType(Enum):
  97. """
  98. Enum class for parameter type.
  99. """
  100. FLOAT = "float"
  101. INT = "int"
  102. STRING = "string"
  103. BOOLEAN = "boolean"
  104. TEXT = "text"
  105. class ModelPropertyKey(Enum):
  106. """
  107. Enum class for model property key.
  108. """
  109. MODE = "mode"
  110. CONTEXT_SIZE = "context_size"
  111. MAX_CHUNKS = "max_chunks"
  112. FILE_UPLOAD_LIMIT = "file_upload_limit"
  113. SUPPORTED_FILE_EXTENSIONS = "supported_file_extensions"
  114. MAX_CHARACTERS_PER_CHUNK = "max_characters_per_chunk"
  115. DEFAULT_VOICE = "default_voice"
  116. VOICES = "voices"
  117. WORD_LIMIT = "word_limit"
  118. AUDIO_TYPE = "audio_type"
  119. MAX_WORKERS = "max_workers"
  120. class ProviderModel(BaseModel):
  121. """
  122. Model class for provider model.
  123. """
  124. model: str
  125. label: I18nObject
  126. model_type: ModelType
  127. features: Optional[list[ModelFeature]] = None
  128. fetch_from: FetchFrom
  129. model_properties: dict[ModelPropertyKey, Any]
  130. deprecated: bool = False
  131. model_config = ConfigDict(protected_namespaces=())
  132. class ParameterRule(BaseModel):
  133. """
  134. Model class for parameter rule.
  135. """
  136. name: str
  137. use_template: Optional[str] = None
  138. label: I18nObject
  139. type: ParameterType
  140. help: Optional[I18nObject] = None
  141. required: bool = False
  142. default: Optional[Any] = None
  143. min: Optional[float] = None
  144. max: Optional[float] = None
  145. precision: Optional[int] = None
  146. options: list[str] = []
  147. class PriceConfig(BaseModel):
  148. """
  149. Model class for pricing info.
  150. """
  151. input: Decimal
  152. output: Optional[Decimal] = None
  153. unit: Decimal
  154. currency: str
  155. class AIModelEntity(ProviderModel):
  156. """
  157. Model class for AI model.
  158. """
  159. parameter_rules: list[ParameterRule] = []
  160. pricing: Optional[PriceConfig] = None
  161. class ModelUsage(BaseModel):
  162. pass
  163. class PriceType(Enum):
  164. """
  165. Enum class for price type.
  166. """
  167. INPUT = "input"
  168. OUTPUT = "output"
  169. class PriceInfo(BaseModel):
  170. """
  171. Model class for price info.
  172. """
  173. unit_price: Decimal
  174. unit: Decimal
  175. total_amount: Decimal
  176. currency: str