prompt_template.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. import re
  2. from typing import Any
  3. from jinja2 import Environment, meta
  4. from langchain import PromptTemplate
  5. from langchain.formatting import StrictFormatter
  6. class JinjaPromptTemplate(PromptTemplate):
  7. template_format: str = "jinja2"
  8. """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
  9. @classmethod
  10. def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
  11. """Load a prompt template from a template."""
  12. env = Environment()
  13. ast = env.parse(template)
  14. input_variables = meta.find_undeclared_variables(ast)
  15. if "partial_variables" in kwargs:
  16. partial_variables = kwargs["partial_variables"]
  17. input_variables = {
  18. var for var in input_variables if var not in partial_variables
  19. }
  20. return cls(
  21. input_variables=list(sorted(input_variables)), template=template, **kwargs
  22. )
  23. class OutLinePromptTemplate(PromptTemplate):
  24. @classmethod
  25. def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
  26. """Load a prompt template from a template."""
  27. input_variables = {
  28. v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
  29. }
  30. return cls(
  31. input_variables=list(sorted(input_variables)), template=template, **kwargs
  32. )
  33. def format(self, **kwargs: Any) -> str:
  34. """Format the prompt with the inputs.
  35. Args:
  36. kwargs: Any arguments to be passed to the prompt template.
  37. Returns:
  38. A formatted string.
  39. Example:
  40. .. code-block:: python
  41. prompt.format(variable1="foo")
  42. """
  43. kwargs = self._merge_partial_and_user_variables(**kwargs)
  44. return OneLineFormatter().format(self.template, **kwargs)
  45. class OneLineFormatter(StrictFormatter):
  46. def parse(self, format_string):
  47. last_end = 0
  48. results = []
  49. for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
  50. field_name = match.group(1)
  51. start, end = match.span()
  52. literal_text = format_string[last_end:start]
  53. last_end = end
  54. results.append((literal_text, field_name, '', None))
  55. remaining_literal_text = format_string[last_end:]
  56. if remaining_literal_text:
  57. results.append((remaining_literal_text, None, None, None))
  58. return results