spark.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import re
  2. import string
  3. import threading
  4. from _decimal import Decimal, ROUND_HALF_UP
  5. from typing import Dict, List, Optional, Any, Mapping
  6. from langchain.callbacks.manager import CallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun
  7. from langchain.chat_models.base import BaseChatModel
  8. from langchain.llms.utils import enforce_stop_tokens
  9. from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, ChatResult, \
  10. ChatGeneration
  11. from langchain.utils import get_from_dict_or_env
  12. from pydantic import root_validator
  13. from core.third_party.spark.spark_llm import SparkLLMClient
class ChatSpark(BaseChatModel):
    r"""Wrapper around Spark's large language model.

    To use, you should pass `app_id`, `api_key`, `api_secret`
    as a named parameter to the constructor.

    Example:
        .. code-block:: python

            client = SparkLLMClient(
                app_id="<app_id>",
                api_key="<api_key>",
                api_secret="<api_secret>"
            )
    """
    client: Any = None  #: :meta private:

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = None
    """A non-negative float that tunes the degree of randomness in generation."""

    top_k: Optional[int] = None
    """Number of most likely tokens to consider at each step."""

    user_id: Optional[str] = None
    """User ID to use for the model."""

    streaming: bool = False
    """Whether to stream the results."""

    # Credentials / endpoint; filled from kwargs or env vars by
    # validate_environment below.
    app_id: Optional[str] = None
    api_key: Optional[str] = None
    api_secret: Optional[str] = None
    api_domain: Optional[str] = None
  41. @root_validator()
  42. def validate_environment(cls, values: Dict) -> Dict:
  43. """Validate that api key and python package exists in environment."""
  44. values["app_id"] = get_from_dict_or_env(
  45. values, "app_id", "SPARK_APP_ID"
  46. )
  47. values["api_key"] = get_from_dict_or_env(
  48. values, "api_key", "SPARK_API_KEY"
  49. )
  50. values["api_secret"] = get_from_dict_or_env(
  51. values, "api_secret", "SPARK_API_SECRET"
  52. )
  53. values["client"] = SparkLLMClient(
  54. app_id=values["app_id"],
  55. api_key=values["api_key"],
  56. api_secret=values["api_secret"],
  57. api_domain=values.get('api_domain')
  58. )
  59. return values
  60. @property
  61. def _default_params(self) -> Mapping[str, Any]:
  62. """Get the default parameters for calling Anthropic API."""
  63. d = {
  64. "max_tokens": self.max_tokens
  65. }
  66. if self.temperature is not None:
  67. d["temperature"] = self.temperature
  68. if self.top_k is not None:
  69. d["top_k"] = self.top_k
  70. return d
  71. @property
  72. def _identifying_params(self) -> Mapping[str, Any]:
  73. """Get the identifying parameters."""
  74. return {**{}, **self._default_params}
  75. @property
  76. def lc_secrets(self) -> Dict[str, str]:
  77. return {"api_key": "API_KEY", "api_secret": "API_SECRET"}
  78. @property
  79. def _llm_type(self) -> str:
  80. """Return type of chat model."""
  81. return "spark-chat"
  82. @property
  83. def lc_serializable(self) -> bool:
  84. return True
  85. def _convert_messages_to_dicts(self, messages: List[BaseMessage]) -> list[dict]:
  86. """Format a list of messages into a full dict list.
  87. Args:
  88. messages (List[BaseMessage]): List of BaseMessage to combine.
  89. Returns:
  90. list[dict]
  91. """
  92. messages = messages.copy() # don't mutate the original list
  93. new_messages = []
  94. for message in messages:
  95. if isinstance(message, ChatMessage):
  96. new_messages.append({'role': 'user', 'content': message.content})
  97. elif isinstance(message, HumanMessage) or isinstance(message, SystemMessage):
  98. new_messages.append({'role': 'user', 'content': message.content})
  99. elif isinstance(message, AIMessage):
  100. new_messages.append({'role': 'assistant', 'content': message.content})
  101. else:
  102. raise ValueError(f"Got unknown type {message}")
  103. return new_messages
    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call the Spark model and assemble a single chat completion.

        The client request runs on a background thread while this thread
        consumes streamed chunks from ``client.subscribe()``.
        """
        messages = self._convert_messages_to_dicts(messages)

        # Fire the request in the background; results arrive through the
        # client's subscribe() iterator below.
        thread = threading.Thread(target=self.client.run, args=(
            messages,
            self.user_id,
            self._default_params,
            self.streaming
        ))
        thread.start()

        completion = ""
        for content in self.client.subscribe():
            # Chunks may be plain strings or dicts carrying text under 'data'.
            if isinstance(content, dict):
                delta = content['data']
            else:
                delta = content
            completion += delta
            # Forward each incremental token to the callback in streaming mode.
            if self.streaming and run_manager:
                run_manager.on_llm_new_token(
                    delta,
                )
        # subscribe() is exhausted, so the worker is done; reap the thread.
        thread.join()

        if stop is not None:
            completion = enforce_stop_tokens(completion, stop)
        message = AIMessage(content=completion)
        return ChatResult(generations=[ChatGeneration(message=message)])
    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async generation stub: always returns an empty AI message.

        NOTE(review): this ignores its inputs entirely, so callers on the
        async path silently receive an empty completion — consider delegating
        to ``_generate`` or raising ``NotImplementedError`` instead.
        """
        message = AIMessage(content='')
        return ChatResult(generations=[ChatGeneration(message=message)])
  144. def get_num_tokens(self, text: str) -> float:
  145. """Calculate number of tokens."""
  146. total = Decimal(0)
  147. words = re.findall(r'\b\w+\b|[{}]|\s'.format(re.escape(string.punctuation)), text)
  148. for word in words:
  149. if word:
  150. if '\u4e00' <= word <= '\u9fff': # if chinese
  151. total += Decimal('1.5')
  152. else:
  153. total += Decimal('0.8')
  154. return int(total)