base.py

import json
import logging
import os
import re
from abc import abstractmethod
from typing import List, Optional, Any, Union, Tuple
import decimal

from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration

from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, \
    DifyStdOutCallbackHandler
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.third_party.langchain.llms.fake import FakeLLM

logger = logging.getLogger(__name__)


class BaseLLM(BaseProviderModel):
    model_mode: ModelMode = ModelMode.COMPLETION
    name: str
    model_kwargs: ModelKwargs
    credentials: dict
    streaming: bool = False
    type: ModelType = ModelType.TEXT_GENERATION
    deduct_quota: bool = True

    def __init__(self, model_provider: BaseModelProvider,
                 name: str,
                 model_kwargs: ModelKwargs,
                 streaming: bool = False,
                 callbacks: Callbacks = None):
        self.name = name
        self.model_rules = model_provider.get_model_parameter_rules(name, self.type)
        self.model_kwargs = model_kwargs if model_kwargs else ModelKwargs(
            max_tokens=None,
            temperature=None,
            top_p=None,
            presence_penalty=None,
            frequency_penalty=None
        )
        self.credentials = model_provider.get_model_credentials(
            model_name=name,
            model_type=self.type
        )
        self.streaming = streaming

        if streaming:
            default_callback = DifyStreamingStdOutCallbackHandler()
        else:
            default_callback = DifyStdOutCallbackHandler()

        if not callbacks:
            callbacks = [default_callback]
        else:
            callbacks.append(default_callback)

        self.callbacks = callbacks

        client = self._init_client()
        super().__init__(model_provider, client)

    @abstractmethod
    def _init_client(self) -> Any:
        raise NotImplementedError

    @property
    def base_model_name(self) -> str:
        """
        Get the base model name of this LLM.

        :return: str
        """
        return self.name

    @property
    def price_config(self) -> dict:
        def get_or_default():
            default_price_config = {
                'prompt': decimal.Decimal('0'),
                'completion': decimal.Decimal('0'),
                'unit': decimal.Decimal('0'),
                'currency': 'USD'
            }
            rules = self.model_provider.get_rules()
            price_config = rules['price_config'][self.base_model_name] \
                if 'price_config' in rules else default_price_config
            price_config = {
                'prompt': decimal.Decimal(price_config['prompt']),
                'completion': decimal.Decimal(price_config['completion']),
                'unit': decimal.Decimal(price_config['unit']),
                'currency': price_config['currency']
            }
            return price_config

        # lazily compute the price config on first access and cache it
        if not hasattr(self, '_price_config'):
            self._price_config = get_or_default()

        logger.debug(f"model: {self.name} price_config: {self._price_config}")
        return self._price_config

    def run(self, messages: List[PromptMessage],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs) -> LLMRunResult:
        """
        Run a prediction for the given prompt messages and stop words.

        :param messages: prompt messages
        :param stop: stop words
        :param callbacks: callbacks passed through to the underlying client
        :return: LLMRunResult
        """
        if self.deduct_quota:
            self.model_provider.check_quota_over_limit()

        if not callbacks:
            callbacks = self.callbacks
        else:
            callbacks.extend(self.callbacks)

        if kwargs.get('fake_response'):
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=kwargs['fake_response'],
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            result = fake_llm.generate([prompts])
        else:
            try:
                result = self._run(
                    messages=messages,
                    stop=stop,
                    # suppress client callbacks here if the output will be
                    # replayed through FakeLLM below instead
                    callbacks=callbacks if not (self.streaming and not self.support_streaming()) else None,
                    **kwargs
                )
            except Exception as ex:
                raise self.handle_exceptions(ex)

        if isinstance(result.generations[0][0], ChatGeneration):
            completion_content = result.generations[0][0].message.content
        else:
            completion_content = result.generations[0][0].text

        if self.streaming and not self.support_streaming():
            # streaming was requested but this model does not support it,
            # so use FakeLLM to simulate streaming of the full completion
            prompts = self._get_prompt_from_messages(messages, ModelMode.CHAT)
            fake_llm = FakeLLM(
                response=completion_content,
                num_token_func=self.get_num_tokens,
                streaming=self.streaming,
                callbacks=callbacks
            )
            fake_llm.generate([prompts])

        if result.llm_output and result.llm_output['token_usage']:
            prompt_tokens = result.llm_output['token_usage']['prompt_tokens']
            completion_tokens = result.llm_output['token_usage']['completion_tokens']
            total_tokens = result.llm_output['token_usage']['total_tokens']
        else:
            prompt_tokens = self.get_num_tokens(messages)
            completion_tokens = self.get_num_tokens(
                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
            total_tokens = prompt_tokens + completion_tokens

        self.model_provider.update_last_used()

        if self.deduct_quota:
            self.model_provider.deduct_quota(total_tokens)

        return LLMRunResult(
            content=completion_content,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens
        )
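
    # Illustrative call from the caller's side, a sketch only: it assumes a
    # concrete subclass (named `SomeProviderLLM` purely for illustration)
    # that implements _init_client / _run / get_num_tokens:
    #
    #   model = SomeProviderLLM(model_provider=provider, name='some-model',
    #                           model_kwargs=ModelKwargs(max_tokens=256, temperature=0.7,
    #                                                    top_p=None, presence_penalty=None,
    #                                                    frequency_penalty=None))
    #   result = model.run(
    #       messages=[PromptMessage(content='Hello!', type=MessageType.HUMAN)],
    #       stop=['\nHuman:']
    #   )
    #   print(result.content, result.prompt_tokens, result.completion_tokens)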

    @abstractmethod
    def _run(self, messages: List[PromptMessage],
             stop: Optional[List[str]] = None,
             callbacks: Callbacks = None,
             **kwargs) -> LLMResult:
        """
        Run a prediction for the given prompt messages and stop words.

        :param messages: prompt messages
        :param stop: stop words
        :param callbacks: callbacks passed through to the underlying client
        :return: LLMResult
        """
        raise NotImplementedError

    @abstractmethod
    def get_num_tokens(self, messages: List[PromptMessage]) -> int:
        """
        Get the number of tokens in the given prompt messages.

        :param messages: prompt messages
        :return: token count
        """
        raise NotImplementedError

    def calc_tokens_price(self, tokens: int, message_type: MessageType) -> decimal.Decimal:
        """
        Calculate the total price for the given number of tokens.

        :param tokens: token count
        :param message_type: message type, used to pick the prompt or completion price
        :return: total price
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']

        unit = self.get_price_unit(message_type)

        total_price = tokens * unit_price * unit
        total_price = total_price.quantize(decimal.Decimal('0.0000001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"tokens={tokens}, unit_price={unit_price}, unit={unit}, total_price={total_price}")
        return total_price
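
    # Worked example with hypothetical prices: for 1,000 completion tokens,
    # a 'completion' price of 0.002 and a 'unit' of 0.001 (i.e. the price is
    # quoted per 1K tokens), the total is
    #   1000 * 0.002 * 0.001 = 0.002,
    # quantized to seven decimal places: Decimal('0.0020000').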

    def get_tokens_unit_price(self, message_type: MessageType) -> decimal.Decimal:
        """
        Get the per-token unit price.

        :param message_type: message type, used to pick the prompt or completion price
        :return: e.g. decimal.Decimal('0.0001')
        """
        if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
            unit_price = self.price_config['prompt']
        else:
            unit_price = self.price_config['completion']
        unit_price = unit_price.quantize(decimal.Decimal('0.0001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"unit_price={unit_price}")
        return unit_price

    def get_price_unit(self, message_type: MessageType) -> decimal.Decimal:
        """
        Get the price unit.

        :param message_type: message type (the unit is shared by all message types)
        :return: e.g. decimal.Decimal('0.000001')
        """
        # the price unit is the same for prompt and completion messages
        price_unit = self.price_config['unit']
        price_unit = price_unit.quantize(decimal.Decimal('0.000001'), rounding=decimal.ROUND_HALF_UP)
        logger.debug(f"price_unit={price_unit}")
        return price_unit

    def get_currency(self) -> str:
        """
        Get the token price currency.

        :return: currency read from the price config, default 'USD'
        """
        currency = self.price_config['currency']
        return currency

    def get_model_kwargs(self):
        return self.model_kwargs

    def set_model_kwargs(self, model_kwargs: ModelKwargs):
        self.model_kwargs = model_kwargs
        self._set_model_kwargs(model_kwargs)

    @abstractmethod
    def _set_model_kwargs(self, model_kwargs: ModelKwargs):
        raise NotImplementedError

    @abstractmethod
    def handle_exceptions(self, ex: Exception) -> Exception:
        """
        Handle exceptions raised while the LLM runs.

        :param ex: the original exception
        :return: the exception to raise in its place
        """
        raise NotImplementedError

    def add_callbacks(self, callbacks: Callbacks):
        """
        Add callbacks to the client.

        :param callbacks: callbacks to add
        :return:
        """
        if not self.client.callbacks:
            self.client.callbacks = callbacks
        else:
            self.client.callbacks.extend(callbacks)

    @classmethod
    def support_streaming(cls):
        return False

    def get_prompt(self, mode: str,
                   pre_prompt: str, inputs: dict,
                   query: str,
                   context: Optional[str],
                   memory: Optional[BaseChatMemory]) -> \
            Tuple[List[PromptMessage], Optional[List[str]]]:
        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
        return [PromptMessage(content=prompt)], stops

    def prompt_file_name(self, mode: str) -> str:
        if mode == 'completion':
            return 'common_completion'
        else:
            return 'common_chat'

    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
                             query: str,
                             context: Optional[str],
                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
        context_prompt_content = ''
        if context and 'context_prompt' in prompt_rules:
            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
            context_prompt_content = prompt_template.format(
                context=context
            )

        pre_prompt_content = ''
        if pre_prompt:
            prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
            pre_prompt_content = prompt_template.format(
                **prompt_inputs
            )

        prompt = ''
        for order in prompt_rules['system_prompt_orders']:
            if order == 'context_prompt':
                prompt += context_prompt_content
            elif order == 'pre_prompt':
                prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''

        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'

        if memory and 'histories_prompt' in prompt_rules:
            # append chat histories
            tmp_human_message = PromptBuilder.to_human_message(
                prompt_content=prompt + query_prompt,
                inputs={
                    'query': query
                }
            )

            if self.model_rules.max_tokens.max:
                curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
                # max_tokens may be unset in the model kwargs; treat None as 0
                max_tokens = self.model_kwargs.max_tokens or 0
                rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
                rest_tokens = max(rest_tokens, 0)
            else:
                rest_tokens = 2000

            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'

            histories = self._get_history_messages_from_memory(memory, rest_tokens)
            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
            histories_prompt_content = prompt_template.format(
                histories=histories
            )

            # rebuild the prompt, this time including the history section
            prompt = ''
            for order in prompt_rules['system_prompt_orders']:
                if order == 'context_prompt':
                    prompt += context_prompt_content
                elif order == 'pre_prompt':
                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
                elif order == 'histories_prompt':
                    prompt += histories_prompt_content

        prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
        query_prompt_content = prompt_template.format(
            query=query
        )

        prompt += query_prompt_content

        # strip special tokens such as <|endoftext|> from the assembled prompt
        prompt = re.sub(r'<\|.*?\|>', '', prompt)

        stops = prompt_rules.get('stops')
        if stops is not None and len(stops) == 0:
            stops = None

        return prompt, stops

    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
        # Get the absolute path of the prompt rules subdirectory
        prompt_path = os.path.join(
            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
            'prompt/generate_prompts')

        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')

        # Open the JSON file and read its content
        with open(json_file_path, 'r') as json_file:
            return json.load(json_file)
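
    # A minimal sketch of the JSON shape these rules files are assumed to
    # take, based on the keys consumed by _get_prompt_and_stop (the field
    # values here are hypothetical):
    #
    #   {
    #     "context_prompt": "Use the following context:\n{{context}}\n",
    #     "histories_prompt": "{{histories}}\n",
    #     "query_prompt": "Human: {{query}}\nAssistant: ",
    #     "system_prompt_orders": ["context_prompt", "pre_prompt", "histories_prompt"],
    #     "human_prefix": "Human",
    #     "assistant_prefix": "Assistant",
    #     "stops": ["\nHuman:"]
    #   }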

    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
                                          max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                  model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
        if not model_mode:
            model_mode = self.model_mode

        if model_mode == ModelMode.COMPLETION:
            if len(messages) == 0:
                return ''

            return messages[0].content
        else:
            if len(messages) == 0:
                return []

            chat_messages = []
            for message in messages:
                if message.type == MessageType.HUMAN:
                    chat_messages.append(HumanMessage(content=message.content))
                elif message.type == MessageType.ASSISTANT:
                    chat_messages.append(AIMessage(content=message.content))
                elif message.type == MessageType.SYSTEM:
                    chat_messages.append(SystemMessage(content=message.content))

            return chat_messages

    def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
        """
        Convert model kwargs to provider-specific model kwargs.

        :param model_rules: parameter rules for this model
        :param model_kwargs: generic model kwargs
        :return: dict of kwargs in the provider's expected shape
        """
        model_kwargs_input = {}
        for key, value in model_kwargs.dict().items():
            rule = getattr(model_rules, key)
            if not rule.enabled:
                continue

            if rule.alias:
                key = rule.alias

            if rule.default is not None and value is None:
                value = rule.default

            # only clamp when a value is set; None must pass through unchanged
            if value is not None:
                if rule.min is not None:
                    value = max(value, rule.min)

                if rule.max is not None:
                    value = min(value, rule.max)

            model_kwargs_input[key] = value

        return model_kwargs_input
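
    # Illustrative behaviour, a sketch under the assumption that
    # getattr(model_rules, 'temperature') yields a rule with enabled=True,
    # alias='temp', default=1.0, min=0.0 and max=2.0 (the attributes read
    # above): ModelKwargs(temperature=5.0, ...) would contribute
    # {'temp': 2.0} (clamped to the max), and ModelKwargs(temperature=None, ...)
    # would contribute {'temp': 1.0} (the rule default).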