prompt_template.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. import re
  2. from typing import Any
  3. from jinja2 import Environment, meta
  4. from langchain import PromptTemplate
  5. from langchain.formatting import StrictFormatter
  6. class JinjaPromptTemplate(PromptTemplate):
  7. template_format: str = "jinja2"
  8. """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
  9. @classmethod
  10. def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
  11. """Load a prompt template from a template."""
  12. env = Environment()
  13. template = template.replace("{{}}", "{}")
  14. ast = env.parse(template)
  15. input_variables = meta.find_undeclared_variables(ast)
  16. if "partial_variables" in kwargs:
  17. partial_variables = kwargs["partial_variables"]
  18. input_variables = {
  19. var for var in input_variables if var not in partial_variables
  20. }
  21. return cls(
  22. input_variables=list(sorted(input_variables)), template=template, **kwargs
  23. )
  24. class OutLinePromptTemplate(PromptTemplate):
  25. @classmethod
  26. def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
  27. """Load a prompt template from a template."""
  28. input_variables = {
  29. v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
  30. }
  31. return cls(
  32. input_variables=list(sorted(input_variables)), template=template, **kwargs
  33. )
  34. def format(self, **kwargs: Any) -> str:
  35. """Format the prompt with the inputs.
  36. Args:
  37. kwargs: Any arguments to be passed to the prompt template.
  38. Returns:
  39. A formatted string.
  40. Example:
  41. .. code-block:: python
  42. prompt.format(variable1="foo")
  43. """
  44. kwargs = self._merge_partial_and_user_variables(**kwargs)
  45. return OneLineFormatter().format(self.template, **kwargs)
  46. class OneLineFormatter(StrictFormatter):
  47. def parse(self, format_string):
  48. last_end = 0
  49. results = []
  50. for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
  51. field_name = match.group(1)
  52. start, end = match.span()
  53. literal_text = format_string[last_end:start]
  54. last_end = end
  55. results.append((literal_text, field_name, '', None))
  56. remaining_literal_text = format_string[last_end:]
  57. if remaining_literal_text:
  58. results.append((remaining_literal_text, None, None, None))
  59. return results