llm_generator.py

import json
import logging

from langchain.schema import OutputParserException

from core.model_providers.error import LLMError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
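

# LLMGenerator bundles one-off LLM helper calls (conversation titles,
# follow-up question suggestions, rule-config generation, and Q&A document
# generation) behind classmethods that resolve the tenant's configured
# text-generation model via ModelFactory.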
class LLMGenerator:
    @classmethod
    def generate_conversation_name(cls, tenant_id: str, query):
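        """Generate a short conversation title from the user's first query.

        Long queries are truncated to their head and tail before being
        appended to CONVERSATION_TITLE_PROMPT; the model's JSON answer is
        parsed and the resulting name is capped at 75 characters.
        """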
        prompt = CONVERSATION_TITLE_PROMPT

        if len(query) > 2000:
            query = query[:300] + "...[TRUNCATED]..." + query[-300:]

        query = query.replace("\n", " ")

        prompt += query + "\n"

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                temperature=1,
                max_tokens=100
            )
        )

        prompts = [PromptMessage(content=prompt)]
        response = model_instance.run(prompts)
        answer = response.content

        result_dict = json.loads(answer)
        answer = result_dict['Your Output']
        name = answer.strip()

        if len(name) > 75:
            name = name[:75] + '...'

        return name

    @classmethod
    def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
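        """Suggest follow-up questions for a conversation history.

        Returns an empty list when the tenant's provider credentials are
        not initialized or when the model call or output parsing fails.
        """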
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

        prompt_template = PromptTemplateParser(
            template="{{histories}}\n{{format_instructions}}\nquestions:\n"
        )

        prompt = prompt_template.format({
            "histories": histories,
            "format_instructions": format_instructions
        })

        try:
            model_instance = ModelFactory.get_text_generation_model(
                tenant_id=tenant_id,
                model_kwargs=ModelKwargs(
                    max_tokens=256,
                    temperature=0
                )
            )
        except ProviderTokenNotInitError:
            return []

        prompt_messages = [PromptMessage(content=prompt)]

        try:
            output = model_instance.run(prompt_messages)
            questions = output_parser.parse(output.content)
        except LLMError:
            questions = []
        except Exception as e:
            logging.exception(e)
            questions = []

        return questions

    @classmethod
    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
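        """Generate an app rule config (prompt, variables, opening statement).

        Placeholders such as {{variable}} are passed through as literal
        values (with remove_template_variables=False) so they survive
        formatting and reach the model unexpanded.
        """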
        output_parser = RuleConfigGeneratorOutputParser()

        prompt_template = PromptTemplateParser(
            template=output_parser.get_format_instructions()
        )

        prompt = prompt_template.format(
            inputs={
                "audiences": audiences,
                "hoping_to_solve": hoping_to_solve,
                "variable": "{{variable}}",
                "lanA": "{{lanA}}",
                "lanB": "{{lanB}}",
                "topic": "{{topic}}"
            },
            remove_template_variables=False
        )

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=512,
                temperature=0
            )
        )

        prompt_messages = [PromptMessage(content=prompt)]

        try:
            output = model_instance.run(prompt_messages)
            rule_config = output_parser.parse(output.content)
        except LLMError:
            raise
        except OutputParserException:
            raise ValueError('Please provide valid input for the intended audience or the problem you hope to solve.')
        except Exception as e:
            logging.exception(e)
            rule_config = {
                "prompt": "",
                "variables": [],
                "opening_statement": ""
            }

        return rule_config

    @classmethod
    def generate_qa_document(cls, tenant_id: str, query, document_language: str):
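        """Generate question/answer pairs from a document chunk.

        GENERATOR_QA_PROMPT, formatted with the target document language,
        is sent as the system message and the source text as the user
        message; the model's answer is returned stripped.
        """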
        prompt = GENERATOR_QA_PROMPT.format(language=document_language)

        model_instance = ModelFactory.get_text_generation_model(
            tenant_id=tenant_id,
            model_kwargs=ModelKwargs(
                max_tokens=2000
            )
        )

        prompts = [
            PromptMessage(content=prompt, type=MessageType.SYSTEM),
            PromptMessage(content=query)
        ]

        response = model_instance.run(prompts)
        answer = response.content

        return answer.strip()
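

# Usage sketch (illustrative only, not part of the original module). Assumes
# a tenant whose text-generation provider credentials are already configured;
# `tenant_id`, `histories`, and `chunk_text` below are placeholders.
#
#     name = LLMGenerator.generate_conversation_name(tenant_id, "How do I deploy with Docker?")
#     questions = LLMGenerator.generate_suggested_questions_after_answer(tenant_id, histories)
#     rule_config = LLMGenerator.generate_rule_config(tenant_id, "new developers", "onboarding to the API")
#     qa_text = LLMGenerator.generate_qa_document(tenant_id, chunk_text, "English")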