| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297 | 
							- from typing import Optional
 
- from langchain.callbacks.base import Callbacks
 
- from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
 
- from core.model_providers.model_provider_factory import ModelProviderFactory, DEFAULT_MODELS
 
- from core.model_providers.models.base import BaseProviderModel
 
- from core.model_providers.models.embedding.base import BaseEmbedding
 
- from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
 
- from core.model_providers.models.llm.base import BaseLLM
 
- from core.model_providers.models.moderation.base import BaseModeration
 
- from core.model_providers.models.speech2text.base import BaseSpeech2Text
 
- from extensions.ext_database import db
 
- from models.provider import TenantDefaultModel
 
class ModelFactory:
    """
    Factory for tenant-scoped model instances.

    All public entry points are classmethods. Providers are resolved through
    ``ModelProviderFactory.get_preferred_model_provider`` and the per-type
    default model of a tenant is persisted as a ``TenantDefaultModel`` row.
    """

    @classmethod
    def get_text_generation_model_from_model_config(cls, tenant_id: str,
                                                    model_config: dict,
                                                    streaming: bool = False,
                                                    callbacks: Callbacks = None) -> Optional[BaseLLM]:
        """
        Build a text generation model from a model config dict.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_config: dict read for the keys ``provider``, ``name`` and
            ``completion_params``.
        :param streaming: forwarded to :meth:`get_text_generation_model`.
        :param callbacks: forwarded to :meth:`get_text_generation_model`.
        :return: whatever :meth:`get_text_generation_model` returns.
        """
        # ``or {}`` also covers a config that carries an explicit
        # ``"completion_params": None``; the default argument of dict.get is
        # only used when the key is missing entirely.
        completion_params = model_config.get("completion_params") or {}

        return cls.get_text_generation_model(
            tenant_id=tenant_id,
            model_provider_name=model_config.get("provider"),
            model_name=model_config.get("name"),
            model_kwargs=ModelKwargs(
                temperature=completion_params.get('temperature', 0),
                max_tokens=completion_params.get('max_tokens', 256),
                top_p=completion_params.get('top_p', 0),
                frequency_penalty=completion_params.get('frequency_penalty', 0.1),
                presence_penalty=completion_params.get('presence_penalty', 0.1)
            ),
            streaming=streaming,
            callbacks=callbacks
        )

    @classmethod
    def get_text_generation_model(cls,
                                  tenant_id: str,
                                  model_provider_name: Optional[str] = None,
                                  model_name: Optional[str] = None,
                                  model_kwargs: Optional[ModelKwargs] = None,
                                  streaming: bool = False,
                                  callbacks: Callbacks = None,
                                  deduct_quota: bool = True) -> Optional[BaseLLM]:
        """
        Get a text generation model.

        When both ``model_provider_name`` and ``model_name`` are None the
        tenant's default text generation model is used.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider to use, or None for the default.
        :param model_name: model to use, or None for the default.
        :param model_kwargs: forwarded to the model class constructor.
        :param streaming: forwarded to the model class constructor.
        :param callbacks: forwarded to the model class constructor.
        :param deduct_quota: when False, ``deduct_quota`` is switched off on
            the returned instance.
        :return: the instantiated model.
        :raises LLMBadRequestError: no default model is available, or the
            default model could not be instantiated.
        :raises ProviderTokenNotInitError: no preferred provider was found.
        """
        # NOTE(review): the default lookup only happens when BOTH names are
        # None; if just one is None it is passed on unchanged.
        is_default_model = False
        if model_provider_name is None and model_name is None:
            default_model = cls._require_default_model(
                tenant_id, ModelType.TEXT_GENERATION, "System Reasoning Model"
            )
            model_provider_name = default_model.provider_name
            model_name = default_model.model_name
            is_default_model = True

        model_provider = cls._get_model_provider(tenant_id, model_provider_name, model_name)

        # init text generation model
        model_class = model_provider.get_model_class(model_type=ModelType.TEXT_GENERATION)
        try:
            model_instance = model_class(
                model_provider=model_provider,
                name=model_name,
                model_kwargs=model_kwargs,
                streaming=streaming,
                callbacks=callbacks
            )
        except LLMBadRequestError as e:
            if not is_default_model:
                raise
            # Replace the message for the default model but keep the original
            # error attached as the explicit cause.
            raise LLMBadRequestError(f"Default model {model_name} is not available. "
                                     f"Please check your model provider credentials.") from e

        # Quota deduction is disabled for the default model as well as on
        # explicit request.
        if is_default_model or not deduct_quota:
            model_instance.deduct_quota = False

        return model_instance

    @classmethod
    def get_embedding_model(cls,
                            tenant_id: str,
                            model_provider_name: Optional[str] = None,
                            model_name: Optional[str] = None) -> Optional[BaseEmbedding]:
        """
        Get an embedding model.

        When both ``model_provider_name`` and ``model_name`` are None the
        tenant's default embedding model is used.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider to use, or None for the default.
        :param model_name: model to use, or None for the default.
        :return: the instantiated model.
        :raises LLMBadRequestError: no default model is available.
        :raises ProviderTokenNotInitError: no preferred provider was found.
        """
        if model_provider_name is None and model_name is None:
            default_model = cls._require_default_model(
                tenant_id, ModelType.EMBEDDINGS, "Embedding Model"
            )
            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        return cls._init_named_model(tenant_id, ModelType.EMBEDDINGS, model_provider_name, model_name)

    @classmethod
    def get_speech2text_model(cls,
                              tenant_id: str,
                              model_provider_name: Optional[str] = None,
                              model_name: Optional[str] = None) -> Optional[BaseSpeech2Text]:
        """
        Get a speech to text model.

        When both ``model_provider_name`` and ``model_name`` are None the
        tenant's default speech to text model is used.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider to use, or None for the default.
        :param model_name: model to use, or None for the default.
        :return: the instantiated model.
        :raises LLMBadRequestError: no default model is available.
        :raises ProviderTokenNotInitError: no preferred provider was found.
        """
        if model_provider_name is None and model_name is None:
            default_model = cls._require_default_model(
                tenant_id, ModelType.SPEECH_TO_TEXT, "Speech-to-Text Model"
            )
            model_provider_name = default_model.provider_name
            model_name = default_model.model_name

        return cls._init_named_model(tenant_id, ModelType.SPEECH_TO_TEXT, model_provider_name, model_name)

    @classmethod
    def get_moderation_model(cls,
                             tenant_id: str,
                             model_provider_name: str,
                             model_name: str) -> Optional[BaseModeration]:
        """
        Get a moderation model. There is no default-model fallback here.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_provider_name: provider to use.
        :param model_name: model to use.
        :return: the instantiated model.
        :raises ProviderTokenNotInitError: no preferred provider was found.
        """
        return cls._init_named_model(tenant_id, ModelType.MODERATION, model_provider_name, model_name)

    @classmethod
    def get_default_model(cls, tenant_id: str, model_type: ModelType) -> Optional[TenantDefaultModel]:
        """
        Get the default model of a model type, creating one if needed.

        If the tenant has no stored default, the provider rules are walked in
        iteration order; the first supported model of the first provider that
        is available and lists any model for ``model_type`` is stored and
        committed as the new default.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_type: model type the default is looked up for.
        :return: the stored or newly created default, or None when no
            provider offers a model of this type.
        """
        default_model = cls._find_default_model(tenant_id, model_type)
        if default_model:
            return default_model

        # Only the provider names are needed, so iterate the keys directly.
        for model_provider_name in ModelProviderFactory.get_provider_rules():
            model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
            if not model_provider:
                continue

            model_list = model_provider.get_supported_model_list(model_type)
            if not model_list:
                continue

            default_model = TenantDefaultModel(
                tenant_id=tenant_id,
                model_type=model_type.value,
                provider_name=model_provider_name,
                model_name=model_list[0]['id']
            )
            db.session.add(default_model)
            db.session.commit()
            return default_model

        return None

    @classmethod
    def update_default_model(cls,
                             tenant_id: str,
                             model_type: ModelType,
                             provider_name: str,
                             model_name: str) -> TenantDefaultModel:
        """
        Update (or create) the default model of a model type.

        :param tenant_id: a string representing the ID of the tenant.
        :param model_type: model type whose default is being set.
        :param provider_name: must be one of
            ``ModelProviderFactory.get_provider_names()``.
        :param model_name: must be the ``id`` of a model the provider lists
            for ``model_type``.
        :return: the updated or newly created default model row.
        :raises ValueError: unknown provider name or model name.
        :raises ProviderTokenNotInitError: no preferred provider was found.
        """
        if provider_name not in ModelProviderFactory.get_provider_names():
            raise ValueError(f'Invalid provider name: {provider_name}')

        model_provider = cls._get_model_provider(tenant_id, provider_name, model_name)

        model_list = model_provider.get_supported_model_list(model_type)
        if not any(model['id'] == model_name for model in model_list):
            raise ValueError(f'Invalid model name: {model_name}')

        default_model = cls._find_default_model(tenant_id, model_type)
        if default_model:
            # update default model
            default_model.provider_name = provider_name
            default_model.model_name = model_name
        else:
            # create default model
            default_model = TenantDefaultModel(
                tenant_id=tenant_id,
                model_type=model_type.value,
                provider_name=provider_name,
                model_name=model_name,
            )
            db.session.add(default_model)

        db.session.commit()
        return default_model

    @classmethod
    def _find_default_model(cls, tenant_id: str, model_type: ModelType) -> Optional[TenantDefaultModel]:
        """Return the first stored default model row for tenant and type, if any."""
        return db.session.query(TenantDefaultModel) \
            .filter(
            TenantDefaultModel.tenant_id == tenant_id,
            TenantDefaultModel.model_type == model_type.value
        ).first()

    @classmethod
    def _require_default_model(cls, tenant_id: str, model_type: ModelType,
                               model_label: str) -> TenantDefaultModel:
        """Return the tenant's default model or raise LLMBadRequestError naming ``model_label``."""
        default_model = cls.get_default_model(tenant_id, model_type)
        if not default_model:
            raise LLMBadRequestError(f"Default model is not available. "
                                     f"Please configure a Default {model_label} "
                                     f"in the Settings -> Model Provider.")
        return default_model

    @classmethod
    def _get_model_provider(cls, tenant_id: str, model_provider_name: Optional[str],
                            model_name: Optional[str]):
        """Return the preferred provider or raise ProviderTokenNotInitError mentioning ``model_name``."""
        model_provider = ModelProviderFactory.get_preferred_model_provider(tenant_id, model_provider_name)
        if not model_provider:
            raise ProviderTokenNotInitError(f"Model {model_name} provider credentials is not initialized.")
        return model_provider

    @classmethod
    def _init_named_model(cls, tenant_id: str, model_type: ModelType,
                          model_provider_name: Optional[str], model_name: Optional[str]):
        """Instantiate the provider's class for ``model_type`` with only provider and name."""
        model_provider = cls._get_model_provider(tenant_id, model_provider_name, model_name)
        model_class = model_provider.get_model_class(model_type=model_type)
        return model_class(
            model_provider=model_provider,
            name=model_name
        )
 
 
  |