llm_generator.py

import logging

from langchain.schema import OutputParserException

from core.model_providers.error import LLMError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
    GENERATOR_QA_PROMPT
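

# LLMGenerator groups one-shot LLM helpers: each classmethod renders a prompt,
# runs it against the tenant's configured text-generation model obtained from
# ModelFactory, and post-processes the raw completion.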
class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        prompt = CONVERSATION_TITLE_PROMPT

        if len(query) > 2000:
            query = query[:300] + "...[TRUNCATED]..." + query[-300:]

        prompt = prompt.format(query=query)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=50
            )
        )

        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content

        return answer.strip()
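
    # Summarize a conversation: reserve a token budget for the completion,
    # then pack truncated Q/A pairs into the prompt context only while they
    # still fit within the model's maximum context length.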
    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        max_tokens = 200

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=max_tokens
            )
        )

        prompt = CONVERSATION_SUMMARY_PROMPT
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
        max_context_token_length = model_instance.model_rules.max_tokens.max
        rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1

        context = ''
        for message in messages:
            if not message.answer:
                continue

            if len(message.query) > 2000:
                query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
            else:
                query = message.query

            if len(message.answer) > 2000:
                answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
            else:
                answer = message.answer

            message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
            if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
                context += message_qa_text

        if not context:
            return '[message too long, no summary]'

        prompt = prompt.format(context=context)
        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content

        return answer.strip()
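
    # Generate an opening statement (introduction) from the app's pre-prompt.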
    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id
        )

        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content

        return answer.strip()
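
    # Suggest follow-up questions for the given history. Failures are
    # non-fatal: a missing provider token, an LLM error, or a parsing error
    # simply yields an empty list.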
    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = JinjaPromptTemplate(
            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        try:
            model_instance = ModelFactory.get_text_generation_model(
                tenant_id=tenant_id,
                model_kwargs=ModelKwargs(
                    max_tokens=256,
                    temperature=0
                )
            )
        except ProviderTokenNotInitError:
            return []

        prompts = [PromptMessage(content=_input.to_string())]

        try:
            output = model_instance.run(prompts)
            questions = output_parser.parse(output.content)
        except LLMError:
            questions = []
        except Exception as e:
            logging.exception(e)
            questions = []

        return questions
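
    # Draft a rule config (prompt, variables, opening statement) for the given
    # audience and problem description. LLM errors propagate; parser errors
    # surface as a user-facing ValueError; anything else falls back to an
    # empty config.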
    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
        output_parser = RuleConfigGeneratorOutputParser()

        prompt = OutLinePromptTemplate(
            template=output_parser.get_format_instructions(),
            input_variables=["audiences", "hoping_to_solve"],
            partial_variables={
                "variable": '{variable}',
                "lanA": '{lanA}',
                "lanB": '{lanB}',
                "topic": '{topic}'
            },
            validate_template=False
        )

        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=512,
                temperature=0
            )
        )

        prompts = [PromptMessage(content=_input.to_string())]

        try:
            output = model_instance.run(prompts)
            rule_config = output_parser.parse(output.content)
        except LLMError as e:
            raise e
        except OutputParserException:
            raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
        except Exception as e:
            logging.exception(e)
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config
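
    # Generate Q&A content from the given text: GENERATOR_QA_PROMPT is sent as
    # the system message and the raw text as the user message.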
    @classmethod
    def generate_qa_document(cls, tenant_id: str, query):
        prompt = GENERATOR_QA_PROMPT

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=2000
            )
        )

        prompts = [
            PromptMessage(content=prompt, type=MessageType.SYSTEM),
            PromptMessage(content=query)
        ]

        response = model_instance.run(prompts)
        answer = response.content

        return answer.strip()
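

# Usage sketch (hypothetical caller; `tenant_id` and `history_text` are
# assumptions, not defined in this module):
#
#   name = LLMGenerator.generate_conversation_name(tenant_id, query="How do I reset my password?", answer="...")
#   questions = LLMGenerator.generate_suggested_questions_after_answer(tenant_id, histories=history_text)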