llm_generator.py

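"""LLMGenerator: helper class that uses an LLM to generate auxiliary content,
including conversation titles and summaries, app introductions, suggested
follow-up questions, and prompt/rule configurations."""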
import logging

from langchain.chat_models.base import BaseChatModel
from langchain.schema import HumanMessage, OutputParserException

from core.constant import llm_constant
from core.llm.llm_builder import LLMBuilder
from core.llm.streamable_open_ai import StreamableOpenAI
from core.llm.token_calculator import TokenCalculator
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT

# gpt-3.5-turbo does not work well for these generation tasks
generate_base_model = 'text-davinci-003'


class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query, answer):
        prompt = CONVERSATION_TITLE_PROMPT
        prompt = prompt.format(query=query, answer=answer)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            max_tokens=50
        )

        # Chat models expect a list of messages rather than a plain string prompt.
        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_conversation_summary(cls, tenant_id: str, messages):
        max_tokens = 200

        prompt = CONVERSATION_SUMMARY_PROMPT
        prompt_with_empty_context = prompt.format(context='')
        prompt_tokens = TokenCalculator.get_num_tokens(generate_base_model, prompt_with_empty_context)
        rest_tokens = llm_constant.max_context_token_length[generate_base_model] - prompt_tokens - max_tokens

        # Pack as many Q&A pairs into the prompt as the remaining context window allows.
        context = ''
        for message in messages:
            if not message.answer:
                continue

            message_qa_text = "Human:" + message.query + "\nAI:" + message.answer + "\n"
            if rest_tokens - TokenCalculator.get_num_tokens(generate_base_model, context + message_qa_text) > 0:
                context += message_qa_text

        prompt = prompt.format(context=context)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            max_tokens=max_tokens
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_introduction(cls, tenant_id: str, pre_prompt: str):
        prompt = INTRODUCTION_GENERATE_PROMPT
        prompt = prompt.format(prompt=pre_prompt)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
        )

        if isinstance(llm, BaseChatModel):
            prompt = [HumanMessage(content=prompt)]

        response = llm.generate([prompt])
        answer = response.generations[0][0].text
        return answer.strip()

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt = OutLinePromptTemplate(
            template="{histories}\n{format_instructions}\nquestions:\n",
            input_variables=["histories"],
            partial_variables={"format_instructions": format_instructions}
        )

        _input = prompt.format_prompt(histories=histories)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            temperature=0,
            max_tokens=256
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)
            questions = output_parser.parse(output)
        except Exception:
            logging.exception("Error generating suggested questions after answer")
            questions = []

        return questions

    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
        output_parser = RuleConfigGeneratorOutputParser()

        prompt = OutLinePromptTemplate(
            template=output_parser.get_format_instructions(),
            input_variables=["audiences", "hoping_to_solve"],
            partial_variables={
                "variable": '{variable}',
                "lanA": '{lanA}',
                "lanB": '{lanB}',
                "topic": '{topic}'
            },
            validate_template=False
        )

        _input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)

        llm: StreamableOpenAI = LLMBuilder.to_llm(
            tenant_id=tenant_id,
            model_name=generate_base_model,
            temperature=0,
            max_tokens=512
        )

        if isinstance(llm, BaseChatModel):
            query = [HumanMessage(content=_input.to_string())]
        else:
            query = _input.to_string()

        try:
            output = llm(query)
            rule_config = output_parser.parse(output)
        except OutputParserException:
            raise ValueError('Please provide a valid description of the intended audience or the problem you hope to solve.')
        except Exception:
            logging.exception("Error generating prompt")
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config
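

# Illustrative usage sketch (assumes the tenant has an OpenAI-compatible
# provider configured for LLMBuilder; the arguments shown are placeholders):
#
#   name = LLMGenerator.generate_conversation_name(
#       tenant_id="...", query="user question", answer="assistant reply"
#   )
#   questions = LLMGenerator.generate_suggested_questions_after_answer(
#       tenant_id="...", histories="Human: ...\nAI: ...\n"
#   )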