from abc import abstractmethod
from typing import List, Optional, Any, Union
import decimal
import logging

from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.fake import FakeLLM

logger = logging.getLogger(__name__)


class BaseLLM(BaseProviderModel):
    model_mode: ModelMode = ModelMode.COMPLETION
    name: str
    model_kwargs: ModelKwargs
    credentials: dict
    streaming: bool = False
    type: ModelType = ModelType.TEXT_GENERATION
    deduct_quota: bool = True

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.name = name
        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
            max_tokens=None,
            temperature=None,
            top_p=None,
            presence_penalty=None,
            frequency_penalty=None
        )
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        self.streaming = streaming

        if streaming:
            default_callback = DifyStreamingStdOutCallbackHandler()
        else:
            default_callback = DifyStdOutCallbackHandler()

        if not callbacks:
            callbacks = [default_callback]
        else:
            callbacks.append(default_callback)

        self.callbacks = callbacks

        client = self._init_client()
        super().__init__(model_provider, client)

    @abstractmethod
    def _init_client(self) -> Any:
        raise NotImplementedError

    @property
    def base_model_name(self) -> str:
        """
        Get the LLM base model name.

        :return: str
        """
        return self.name

    @property
    def price_config(self) -> dict:
        def get_or_default():
            default_price_config = {
                'prompt': decimal.Decimal('0'),
                'completion': decimal.Decimal('0'),
                'unit': decimal.Decimal('0'),
                'currency': 'USD'
            }
            rules = self.model_provider.get_rules()
            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
            price_config = {
                'prompt': decimal.Decimal(price_config['prompt']),
                'completion': decimal.Decimal(price_config['completion']),
                'unit': decimal.Decimal(price_config['unit']),
                'currency': price_config['currency']
            }
            return price_config

        # computed lazily on first access, then cached on the instance
        self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()

        logger.debug(f"model: {self.name} price_config: {self._price_config}")
        return self._price_config
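
    # Illustrative only: the provider rules consulted above are expected to carry
    # per-model price entries shaped like the dict below (the model name and all
    # values here are made up for the example):
    #
    #   rules['price_config']['gpt-3.5-turbo'] = {
    #       'prompt': '0.0015',      # price per pricing unit of prompt tokens
    #       'completion': '0.002',   # price per pricing unit of completion tokens
    #       'unit': '0.001',         # pricing unit, i.e. per 1,000 tokens
    #       'currency': 'USD'
    #   }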

    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs) -> LLMRunResult:
        """
        Run prediction from prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        if self.deduct_quota:
            self.model_provider.check_quota_over_limit()

        if not callbacks:
            callbacks = self.callbacks
        else:
            callbacks.extend(self.callbacks)

        if 'fake_response' in kwargs and kwargs['fake_response']:
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=kwargs['fake_response'],
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            result = fake_llm.generate([prompts])
        else:
            try:
                result = self._run(
                    messages=messages,
                    stop=stop,
                    callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
                    **kwargs
                )
            except Exception as ex:
                raise self.handle_exceptions(ex)

        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
        else:
            completion_content = result.generations[0][0].text

        if self.streaming and not self.support_streaming():
            # use FakeLLM to simulate streaming when the current model does not
            # support streaming but streaming is requested
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=completion_content,
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            fake_llm.generate([prompts])

        if result.llm_output and result.llm_output.get('token_usage'):
            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
            completion_tokens = result.llm_output['token_usage']['completion_tokens']
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            # fall back to local token counting when the provider reports no usage
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens(
                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)]
            )
            total_tokens = prompt_tokens + completion_tokens

        self.model_provider.update_last_used()

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)

        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens
        )
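
    # Illustrative only: a typical call on a concrete subclass might look like
    # the following ("model" stands for any concrete BaseLLM implementation):
    #
    #   result = model.run(
    #       [PromptMessage(content='Translate "hello" to French.', type=MessageType.HUMAN)],
    #       stop=['\n\n']
    #   )
    #   print(result.content, result.prompt_tokens, result.completion_tokens)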

    @abstractmethod
    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run prediction from prompt messages and stop words.

        :param messages:
        :param stop:
        :param callbacks:
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the prompt messages.

        :param messages:
        :return:
        """
        raise NotImplementedError

    def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
        """
        Calculate the total price for the given number of tokens.

        :param tokens:
        :param message_type:
        :return:
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']

        unit = self.get_price_unit(message_type)

        total_price = tokens * unit_price * unit
        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price={total_price}")
        return total_price
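
    # Worked example (made-up prices): with unit_price = 0.002 and unit = 0.001
    # (i.e. 0.002 USD per 1,000 tokens), pricing 1,500 prompt tokens gives
    # 1500 * 0.002 * 0.001 = 0.003, quantized to Decimal('0.0030000').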

    def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
        """
        Get the unit price for one token.

        :param message_type:
        :return: unit price quantized to 4 decimal places, e.g. decimal.Decimal('0.0001')
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']
        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"unit_price={unit_price}")
        return unit_price

    def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
        """
        Get the pricing unit.

        :param message_type:
        :return: price unit quantized to 6 decimal places, e.g. decimal.Decimal('0.000001')
        """
        # the same unit applies to both prompt and completion tokens
        price_unit = self.price_config['unit']
        price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"price_unit={price_unit}")
        return price_unit

    def get_currency(self) -> str:
        """
        Get the token price currency.

        :return: currency from the price config, default 'USD'
        """
        return self.price_config['currency']

    def get_model_kwargs(self):
        return self.model_kwargs

    def set_model_kwargs(self, model_kwargs: ModelKwargs):
        self.model_kwargs = model_kwargs
        self._set_model_kwargs(model_kwargs)

    @abstractmethod
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Handle LLM run exceptions.

        :param ex:
        :return:
        """
        raise NotImplementedError

    def add_callbacks(self, callbacks: Callbacks):
        """
        Add callbacks to the client.

        :param callbacks:
        :return:
        """
        if not self.client.callbacks:
            self.client.callbacks = callbacks
        else:
            self.client.callbacks.extend(callbacks)

    @classmethod
    def support_streaming(cls):
        return False

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
        if not model_mode:
            model_mode = self.model_mode

        if model_mode == ModelMode.COMPLETION:
            if len(messages) == 0:
                return ''

            return messages[0].content
        else:
            if len(messages) == 0:
                return []

            chat_messages = []
            for message in messages:
                if message.type == MessageType.HUMAN:
                    chat_messages.append(HumanMessage(content=message.content))
                elif message.type == MessageType.ASSISTANT:
                    chat_messages.append(AIMessage(content=message.content))
                elif message.type == MessageType.SYSTEM:
                    chat_messages.append(SystemMessage(content=message.content))

            return chat_messages

    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
        Convert model kwargs to provider-specific model kwargs.

        :param model_rules:
        :param model_kwargs:
        :return:
        """
        model_kwargs_input = {}
        for key, value in model_kwargs.dict().items():
            rule = getattr(model_rules, key)

            if not rule.enabled:
                continue

            if rule.alias:
                key = rule.alias

            if rule.default is not None and value is None:
                value = rule.default

            # clamp into the rule's allowed range; skip when no value is set,
            # since max(None, rule.min) would raise a TypeError
            if value is not None:
                if rule.min is not None:
                    value = max(value, rule.min)

                if rule.max is not None:
                    value = min(value, rule.max)

            model_kwargs_input[key] = value

        return model_kwargs_input
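
    # Illustrative only (made-up rule values): given a rule for "max_tokens" with
    # alias='max_tokens_to_sample', min=1, max=4096, default=256, the conversion
    # maps ModelKwargs(max_tokens=None, ...) to {'max_tokens_to_sample': 256, ...}
    # and clamps ModelKwargs(max_tokens=99999, ...) to {'max_tokens_to_sample': 4096, ...}.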
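
# Illustrative only: a minimal hypothetical subclass sketch (not part of this
# module) showing which abstract members a concrete provider model implements.
# The name "EchoLLM" and its trivial behavior are assumptions for the example.
#
# class EchoLLM(BaseLLM):
#     def _init_client(self) -> Any:
#         return None  # a real provider would construct its LangChain client here
#
#     def _run(self, messages: List[PromptMessage],
#              stop: Optional[List[str]] = None,
#              callbacks: Callbacks = None,
#              **kwargs) -> LLMResult:
#         from langchain.schema import Generation
#         prompt = self._get_prompt_from_messages(messages, ModelMode.COMPLETION)
#         return LLMResult(generations=[[Generation(text=prompt)]], llm_output=None)
#
#     def get_num_tokens(self, messages: List[PromptMessage]) -> int:
#         return sum(len(message.content) // 4 for message in messages)  # rough heuristic
#
#     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
#         pass
#
#     def handle_exceptions(self, ex: Exception) -> Exception:
#         return ex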