							- """Wrapper around Wenxin APIs."""
 
- from __future__ import annotations
 
- import json
 
- import logging
 
- from json import JSONDecodeError
 
- from typing import (
 
-     Any,
 
-     Dict,
 
-     List,
 
-     Optional, Iterator, Tuple,
 
- )
 
- import requests
 
- from langchain.chat_models.base import BaseChatModel
 
- from langchain.llms.utils import enforce_stop_tokens
 
- from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
 
- from langchain.schema.messages import AIMessageChunk
 
- from langchain.schema.output import GenerationChunk, ChatResult, ChatGenerationChunk, ChatGeneration
 
- from pydantic import BaseModel, Extra, Field, PrivateAttr, root_validator
 
- from langchain.callbacks.manager import (
 
-     CallbackManagerForLLMRun,
 
- )
 
- from langchain.llms.base import LLM
 
- from langchain.utils import get_from_dict_or_env
 
- logger = logging.getLogger(__name__)
 
class _WenxinEndpointClient(BaseModel):
    """An API client that talks to a Wenxin LLM endpoint."""

    base_url: str = (
        "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/"
    )
    secret_key: str
    api_key: str

    def get_access_token(self) -> str:
        """Exchange the API key / secret key pair for an OAuth access token."""
        url = (
            "https://aip.baidubce.com/oauth/2.0/token"
            f"?client_id={self.api_key}"
            f"&client_secret={self.secret_key}"
            "&grant_type=client_credentials"
        )
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        response = requests.post(url, headers=headers)
        if not response.ok:
            raise ValueError(
                f"Wenxin HTTP {response.status_code} error: {response.text}"
            )
        json_response = response.json()
        if "error" in json_response:
            raise ValueError(
                f"Wenxin API {json_response['error']}"
                f" error: {json_response['error_description']}"
            )
        # TODO: cache the token until it expires instead of fetching a new one
        # on every request.
        return json_response["access_token"]

    def post(self, request: dict) -> Any:
        if "model" not in request:
            raise ValueError("Wenxin model name is required")
        # Map public model names to the path segments used by the chat endpoint.
        model_url_map = {
            "ernie-bot-4": "completions_pro",
            "ernie-bot": "completions",
            "ernie-bot-turbo": "eb-instant",
            "bloomz-7b": "bloomz_7b1",
        }
        if request["model"] not in model_url_map:
            raise ValueError(f"Unsupported Wenxin model: {request['model']}")
        stream = bool(request.get("stream"))
        access_token = self.get_access_token()
        api_url = (
            f"{self.base_url}{model_url_map[request['model']]}"
            f"?access_token={access_token}"
        )
        del request["model"]
        headers = {"Content-Type": "application/json"}
        response = requests.post(
            api_url,
            headers=headers,
            json=request,
            stream=stream,
        )
        if not response.ok:
            raise ValueError(
                f"Wenxin HTTP {response.status_code} error: {response.text}"
            )
        if not stream:
            json_response = response.json()
            if "error_code" in json_response:
                raise ValueError(
                    f"Wenxin API {json_response['error_code']}"
                    f" error: {json_response['error_msg']}"
                )
            return json_response
        # For streaming requests, return the raw response so the caller can
        # iterate over the server-sent event lines.
        return response
 

class Wenxin(BaseChatModel):
    """Wrapper around Wenxin large language models."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}

    @property
    def lc_serializable(self) -> bool:
        return True

    _client: _WenxinEndpointClient = PrivateAttr()
    model: str = "ernie-bot"
    """Model name to use."""
    temperature: float = 0.7
    """A non-negative float that tunes the degree of randomness in generation."""
    top_p: float = 0.95
    """Total probability mass of tokens to consider at each step."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for the `create` call not explicitly specified."""
    streaming: bool = False
    """Whether to stream the response or return it all at once."""
    api_key: Optional[str] = None
    secret_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key and secret key exist in the environment."""
        values["api_key"] = get_from_dict_or_env(values, "api_key", "WENXIN_API_KEY")
        values["secret_key"] = get_from_dict_or_env(
            values, "secret_key", "WENXIN_SECRET_KEY"
        )
        return values
 
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Wenxin API."""
        return {
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "stream": self.streaming,
            **self.model_kwargs,
        }

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "wenxin"
 
    def __init__(self, **data: Any):
        super().__init__(**data)
        self._client = _WenxinEndpointClient(
            api_key=self.api_key,
            secret_key=self.secret_key,
        )

    def _convert_message_to_dict(self, message: BaseMessage) -> dict:
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        else:
            raise ValueError(f"Got unknown type {message}")
        return message_dict
 
    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> Tuple[List[Dict[str, Any]], Optional[str]]:
        dict_messages = []
        system = None
        for m in messages:
            message = self._convert_message_to_dict(m)
            if message["role"] == "system":
                # System prompts are sent to Wenxin as a separate request field,
                # so collect them here instead of leaving them in the message list.
                if not system:
                    system = message["content"]
                else:
                    system += f"\n{message['content']}"
                continue
            if dict_messages:
                previous_message = dict_messages[-1]
                if previous_message["role"] == message["role"]:
                    # Merge consecutive messages from the same role so the
                    # conversation alternates between roles.
                    dict_messages[-1]["content"] += f"\n{message['content']}"
                else:
                    dict_messages.append(message)
            else:
                dict_messages.append(message)
        return dict_messages, system
 
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            # Consume the streaming generator and merge the chunks into a single
            # generation, keeping the token usage reported by the final chunk.
            generation: Optional[ChatGenerationChunk] = None
            llm_output: Optional[Dict] = None
            for chunk in self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            ):
                if (
                    chunk.generation_info is not None
                    and "token_usage" in chunk.generation_info
                ):
                    llm_output = {
                        "token_usage": chunk.generation_info["token_usage"],
                        "model_name": self.model,
                    }
                if generation is None:
                    generation = chunk
                else:
                    generation += chunk
            assert generation is not None
            return ChatResult(generations=[generation], llm_output=llm_output)
        else:
            message_dicts, system = self._create_message_dicts(messages)
            request = self._default_params
            request["messages"] = message_dicts
            if system:
                request["system"] = system
            request.update(kwargs)
            response = self._client.post(request)
            return self._create_chat_result(response)
 
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        message_dicts, system = self._create_message_dicts(messages)
        request = self._default_params
        request["messages"] = message_dicts
        if system:
            request["system"] = system
        request.update(kwargs)
        # The streaming endpoint returns server-sent events: each payload line is
        # prefixed with "data:" and carries a JSON fragment of the result.
        for token in self._client.post(request).iter_lines():
            if not token:
                continue
            token = token.decode("utf-8")
            if token.startswith("data:"):
                completion = json.loads(token[5:])
                chunk_dict = {
                    "message": AIMessageChunk(content=completion["result"]),
                }
                if completion["is_end"]:
                    # The final event carries the aggregate token usage.
                    token_usage = completion["usage"]
                    token_usage["completion_tokens"] = (
                        token_usage["total_tokens"] - token_usage["prompt_tokens"]
                    )
                    chunk_dict["generation_info"] = {"token_usage": token_usage}
                yield ChatGenerationChunk(**chunk_dict)
                if run_manager:
                    run_manager.on_llm_new_token(completion["result"])
            else:
                # Anything that is not an SSE data line is treated as an error payload.
                try:
                    json_response = json.loads(token)
                except JSONDecodeError:
                    raise ValueError(f"Wenxin Response Error {token}")
                raise ValueError(
                    f"Wenxin API {json_response['error_code']}"
                    f" error: {json_response['error_msg']}, "
                    "please confirm that the model you have chosen has been paid for."
                )
 
    def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
        generations = [
            ChatGeneration(message=AIMessage(content=response["result"]))
        ]
        token_usage = response.get("usage")
        token_usage["completion_tokens"] = (
            token_usage["total_tokens"] - token_usage["prompt_tokens"]
        )
        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)
 
    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Get the number of tokens in the messages.

        Useful for checking if an input will fit in a model's context window.

        Args:
            messages: The message inputs to tokenize.

        Returns:
            The sum of the number of tokens across the messages.
        """
        return sum(self.get_num_tokens(m.content) for m in messages)
 
    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            for k, v in token_usage.items():
                if k in overall_token_usage:
                    overall_token_usage[k] += v
                else:
                    overall_token_usage[k] = v
        return {"token_usage": overall_token_usage, "model_name": self.model}
 
 