provider_entities.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. from enum import Enum
  2. from typing import Optional, Union
  3. from pydantic import BaseModel, ConfigDict, Field
  4. from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
  5. from core.model_runtime.entities.model_entities import ModelType
  6. from core.tools.entities.common_entities import I18nObject
  7. class ProviderQuotaType(Enum):
  8. PAID = "paid"
  9. """hosted paid quota"""
  10. FREE = "free"
  11. """third-party free quota"""
  12. TRIAL = "trial"
  13. """hosted trial quota"""
  14. @staticmethod
  15. def value_of(value):
  16. for member in ProviderQuotaType:
  17. if member.value == value:
  18. return member
  19. raise ValueError(f"No matching enum found for value '{value}'")
  20. class QuotaUnit(Enum):
  21. TIMES = "times"
  22. TOKENS = "tokens"
  23. CREDITS = "credits"
  24. class SystemConfigurationStatus(Enum):
  25. """
  26. Enum class for system configuration status.
  27. """
  28. ACTIVE = "active"
  29. QUOTA_EXCEEDED = "quota-exceeded"
  30. UNSUPPORTED = "unsupported"
  31. class RestrictModel(BaseModel):
  32. model: str
  33. base_model_name: Optional[str] = None
  34. model_type: ModelType
  35. # pydantic configs
  36. model_config = ConfigDict(protected_namespaces=())
  37. class QuotaConfiguration(BaseModel):
  38. """
  39. Model class for provider quota configuration.
  40. """
  41. quota_type: ProviderQuotaType
  42. quota_unit: QuotaUnit
  43. quota_limit: int
  44. quota_used: int
  45. is_valid: bool
  46. restrict_models: list[RestrictModel] = []
  47. class SystemConfiguration(BaseModel):
  48. """
  49. Model class for provider system configuration.
  50. """
  51. enabled: bool
  52. current_quota_type: Optional[ProviderQuotaType] = None
  53. quota_configurations: list[QuotaConfiguration] = []
  54. credentials: Optional[dict] = None
  55. class CustomProviderConfiguration(BaseModel):
  56. """
  57. Model class for provider custom configuration.
  58. """
  59. credentials: dict
  60. class CustomModelConfiguration(BaseModel):
  61. """
  62. Model class for provider custom model configuration.
  63. """
  64. model: str
  65. model_type: ModelType
  66. credentials: dict
  67. # pydantic configs
  68. model_config = ConfigDict(protected_namespaces=())
  69. class CustomConfiguration(BaseModel):
  70. """
  71. Model class for provider custom configuration.
  72. """
  73. provider: Optional[CustomProviderConfiguration] = None
  74. models: list[CustomModelConfiguration] = []
  75. class ModelLoadBalancingConfiguration(BaseModel):
  76. """
  77. Class for model load balancing configuration.
  78. """
  79. id: str
  80. name: str
  81. credentials: dict
  82. class ModelSettings(BaseModel):
  83. """
  84. Model class for model settings.
  85. """
  86. model: str
  87. model_type: ModelType
  88. enabled: bool = True
  89. load_balancing_configs: list[ModelLoadBalancingConfiguration] = []
  90. # pydantic configs
  91. model_config = ConfigDict(protected_namespaces=())
  92. class BasicProviderConfig(BaseModel):
  93. """
  94. Base model class for common provider settings like credentials
  95. """
  96. class Type(Enum):
  97. SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
  98. TEXT_INPUT = CommonParameterType.TEXT_INPUT.value
  99. SELECT = CommonParameterType.SELECT.value
  100. BOOLEAN = CommonParameterType.BOOLEAN.value
  101. APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
  102. MODEL_CONFIG = CommonParameterType.MODEL_CONFIG.value
  103. @classmethod
  104. def value_of(cls, value: str) -> "ProviderConfig.Type":
  105. """
  106. Get value of given mode.
  107. :param value: mode value
  108. :return: mode
  109. """
  110. for mode in cls:
  111. if mode.value == value:
  112. return mode
  113. raise ValueError(f"invalid mode value {value}")
  114. type: Type = Field(..., description="The type of the credentials")
  115. name: str = Field(..., description="The name of the credentials")
  116. class ProviderConfig(BasicProviderConfig):
  117. """
  118. Model class for common provider settings like credentials
  119. """
  120. class Option(BaseModel):
  121. value: str = Field(..., description="The value of the option")
  122. label: I18nObject = Field(..., description="The label of the option")
  123. scope: AppSelectorScope | ModelConfigScope | None = None
  124. required: bool = False
  125. default: Optional[Union[int, str]] = None
  126. options: Optional[list[Option]] = None
  127. label: Optional[I18nObject] = None
  128. help: Optional[I18nObject] = None
  129. url: Optional[str] = None
  130. placeholder: Optional[I18nObject] = None
  131. def to_basic_provider_config(self) -> BasicProviderConfig:
  132. return BasicProviderConfig(type=self.type, name=self.name)