import re
import string
import threading
from decimal import Decimal, ROUND_HALF_UP
from typing import Dict, List, Optional, Any, Mapping

from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, ChatResult, \
    ChatGeneration
from langchain.utils import get_from_dict_or_env
from pydantic import root_validator

from core.third_party.spark.spark_llm import SparkLLMClient

# Splits text into word runs, single punctuation characters and single
# whitespace characters. Compiled once because the pattern never changes.
_TOKEN_PATTERN = re.compile(r'\b\w+\b|[{}]|\s'.format(re.escape(string.punctuation)))

# Per-piece weights used by ChatSpark.get_num_tokens.
_CJK_WEIGHT = Decimal('1.5')
_DEFAULT_WEIGHT = Decimal('0.8')


class ChatSpark(BaseChatModel):
    r"""Wrapper around Spark's large language model.

    To use, pass `app_id`, `api_key` and `api_secret` as named parameters to
    the constructor, or provide them through the `SPARK_APP_ID`,
    `SPARK_API_KEY` and `SPARK_API_SECRET` environment variables.

    Example:
        .. code-block:: python

            model = ChatSpark(
                app_id="",
                api_key="",
                api_secret=""
            )
    """
    client: Any = None  #: :meta private:

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = None
    """A non-negative float that tunes the degree of randomness in generation."""

    top_k: Optional[int] = None
    """Number of most likely tokens to consider at each step."""

    user_id: Optional[str] = None
    """User ID to use for the model."""

    streaming: bool = False
    """Whether to stream the results."""

    app_id: Optional[str] = None
    api_key: Optional[str] = None
    api_secret: Optional[str] = None
    api_domain: Optional[str] = None

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the Spark credentials and build the underlying client.

        Each credential is looked up through ``get_from_dict_or_env`` using
        the field name and its ``SPARK_*`` environment variable, then a
        ``SparkLLMClient`` is stored under ``values["client"]``.

        Args:
            values: Field values collected by pydantic for this model.

        Returns:
            The same mapping with credentials resolved and ``client`` set.
        """
        # NOTE(review): get_from_dict_or_env is expected to raise when a
        # credential is found in neither place - confirm for the pinned
        # langchain version.
        for field, env_var in (
            ("app_id", "SPARK_APP_ID"),
            ("api_key", "SPARK_API_KEY"),
            ("api_secret", "SPARK_API_SECRET"),
        ):
            values[field] = get_from_dict_or_env(values, field, env_var)

        values["client"] = SparkLLMClient(
            app_id=values["app_id"],
            api_key=values["api_key"],
            api_secret=values["api_secret"],
            api_domain=values.get('api_domain')
        )
        return values

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling the Spark API.

        ``max_tokens`` is always present; ``temperature`` and ``top_k`` are
        included only when they have been set.
        """
        params: Dict[str, Any] = {"max_tokens": self.max_tokens}
        if self.temperature is not None:
            params["temperature"] = self.temperature
        if self.top_k is not None:
            params["top_k"] = self.top_k
        return params

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters (a copy of the default parameters)."""
        return dict(self._default_params)

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map secret field names to the identifiers used for them."""
        return {"api_key": "API_KEY", "api_secret": "API_SECRET"}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "spark-chat"

    @property
    def lc_serializable(self) -> bool:
        """Mark this model as serializable."""
        return True

    def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> list[dict]:
        """Format a list of messages into a list of role/content dicts.

        ``ChatMessage``, ``HumanMessage`` and ``SystemMessage`` are all sent
        with the ``user`` role; ``AIMessage`` is sent as ``assistant``.

        Args:
            messages (List[BaseMessage]): List of BaseMessage to convert.

        Returns:
            list[dict]: One ``{'role': ..., 'content': ...}`` dict per message.

        Raises:
            ValueError: If a message is of none of the handled types.
        """
        converted = []
        for message in messages:
            # The user-role check comes first so that ChatMessage keeps
            # taking precedence over the AIMessage branch.
            if isinstance(message, (ChatMessage, HumanMessage, SystemMessage)):
                role = 'user'
            elif isinstance(message, AIMessage):
                role = 'assistant'
            else:
                raise ValueError(f"Got unknown type {message}")
            converted.append({'role': role, 'content': message.content})
        return converted

    def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> ChatResult:
        """Run a completion on a worker thread and collect the streamed output.

        Args:
            messages: Conversation to send to the model.
            stop: Optional stop sequences applied to the final text.
            run_manager: Optional callback manager that receives each new
                chunk when ``streaming`` is enabled.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            A ``ChatResult`` holding a single ``AIMessage`` generation.
        """
        payload = self._convert_messages_to_dicts(messages)

        # client.run executes on its own thread while this thread consumes
        # whatever client.subscribe() yields.
        worker = threading.Thread(target=self.client.run, args=(
            payload,
            self.user_id,
            self._default_params,
            self.streaming
        ))
        worker.start()

        # NOTE(review): if subscribe() raises, the worker is never joined;
        # whether joining there is safe depends on SparkLLMClient - confirm.
        chunks = []
        for content in self.client.subscribe():
            # Items arrive either as a dict carrying the text under 'data'
            # or as the text itself.
            delta = content['data'] if isinstance(content, dict) else content
            chunks.append(delta)
            if self.streaming and run_manager:
                run_manager.on_llm_new_token(
                    delta,
                )

        worker.join()

        completion = "".join(chunks)
        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)

        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])

    async def _agenerate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> ChatResult:
        """Async counterpart of ``_generate``.

        NOTE(review): this is a placeholder. It never calls the client and
        always returns a single generation with empty content.
        """
        message = AIMessage(content='')
        return ChatResult(generations=[ChatGeneration(message=message)])

    def get_num_tokens(self, text: str) -> int:
        """Estimate the number of tokens in ``text``.

        The text is split into word runs, punctuation characters and
        whitespace characters. Each piece adds 1.5 when it compares within
        ``'\u4e00'..'\u9fff'`` and 0.8 otherwise; the sum is truncated to int.

        Args:
            text: Text to measure.

        Returns:
            The truncated weighted count (0 for empty text).
        """
        # NOTE(review): the range check is a lexicographic comparison of the
        # whole piece, so a run of CJK characters counts once as 1.5 rather
        # than per character - confirm this is intended.
        # Decimal keeps the running sum exact before truncation.
        total = sum(
            (
                _CJK_WEIGHT if '\u4e00' <= piece <= '\u9fff' else _DEFAULT_WEIGHT
                for piece in _TOKEN_PATTERN.findall(text)
            ),
            Decimal(0),
        )
        return int(total)