# prompt_builder.py
import re

from langchain.prompts import AIMessagePromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
from langchain.schema import BaseMessage

from core.prompt.prompt_template import JinjaPromptTemplate
  5. class PromptBuilder:
  6. @classmethod
  7. def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
  8. prompt_template = JinjaPromptTemplate.from_template(prompt_content)
  9. system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
  10. prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
  11. system_message = system_prompt_template.format(**prompt_inputs)
  12. return system_message
  13. @classmethod
  14. def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
  15. prompt_template = JinjaPromptTemplate.from_template(prompt_content)
  16. ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
  17. prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
  18. ai_message = ai_prompt_template.format(**prompt_inputs)
  19. return ai_message
  20. @classmethod
  21. def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
  22. prompt_template = JinjaPromptTemplate.from_template(prompt_content)
  23. human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
  24. human_message = human_prompt_template.format(**inputs)
  25. return human_message
  26. @classmethod
  27. def process_template(cls, template: str):
  28. processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
  29. # processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
  30. # processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
  31. return processed_template