# llm_router_chain.py
  1. """Base classes for LLM-powered router chains."""
  2. from __future__ import annotations
  3. import json
  4. from typing import Any, Dict, List, Optional, Type, cast, NamedTuple
  5. from langchain.chains.base import Chain
  6. from pydantic import root_validator
  7. from langchain.chains import LLMChain
  8. from langchain.prompts import BasePromptTemplate
  9. from langchain.schema import BaseOutputParser, OutputParserException, BaseLanguageModel
  10. from libs.json_in_md_parser import parse_and_check_json_markdown
class Route(NamedTuple):
    """Routing decision: which destination chain to run next, and with what inputs."""

    # Name of the destination chain, or None to signal "use the default chain".
    destination: Optional[str]
    # Inputs to pass to the destination chain.
    next_inputs: Dict[str, Any]
  14. class LLMRouterChain(Chain):
  15. """A router chain that uses an LLM chain to perform routing."""
  16. llm_chain: LLMChain
  17. """LLM chain used to perform routing"""
  18. @root_validator()
  19. def validate_prompt(cls, values: dict) -> dict:
  20. prompt = values["llm_chain"].prompt
  21. if prompt.output_parser is None:
  22. raise ValueError(
  23. "LLMRouterChain requires base llm_chain prompt to have an output"
  24. " parser that converts LLM text output to a dictionary with keys"
  25. " 'destination' and 'next_inputs'. Received a prompt with no output"
  26. " parser."
  27. )
  28. return values
  29. @property
  30. def input_keys(self) -> List[str]:
  31. """Will be whatever keys the LLM chain prompt expects.
  32. :meta private:
  33. """
  34. return self.llm_chain.input_keys
  35. def _validate_outputs(self, outputs: Dict[str, Any]) -> None:
  36. super()._validate_outputs(outputs)
  37. if not isinstance(outputs["next_inputs"], dict):
  38. raise ValueError
  39. def _call(
  40. self,
  41. inputs: Dict[str, Any]
  42. ) -> Dict[str, Any]:
  43. output = cast(
  44. Dict[str, Any],
  45. self.llm_chain.predict_and_parse(**inputs),
  46. )
  47. return output
  48. @classmethod
  49. def from_llm(
  50. cls, llm: BaseLanguageModel, prompt: BasePromptTemplate, **kwargs: Any
  51. ) -> LLMRouterChain:
  52. """Convenience constructor."""
  53. llm_chain = LLMChain(llm=llm, prompt=prompt)
  54. return cls(llm_chain=llm_chain, **kwargs)
  55. @property
  56. def output_keys(self) -> List[str]:
  57. return ["destination", "next_inputs"]
  58. def route(self, inputs: Dict[str, Any]) -> Route:
  59. result = self(inputs)
  60. return Route(result["destination"], result["next_inputs"])
  61. class RouterOutputParser(BaseOutputParser[Dict[str, str]]):
  62. """Parser for output of router chain int he multi-prompt chain."""
  63. default_destination: str = "DEFAULT"
  64. next_inputs_type: Type = str
  65. next_inputs_inner_key: str = "input"
  66. def parse(self, text: str) -> Dict[str, Any]:
  67. try:
  68. expected_keys = ["destination", "next_inputs"]
  69. parsed = parse_and_check_json_markdown(text, expected_keys)
  70. if not isinstance(parsed["destination"], str):
  71. raise ValueError("Expected 'destination' to be a string.")
  72. if not isinstance(parsed["next_inputs"], self.next_inputs_type):
  73. raise ValueError(
  74. f"Expected 'next_inputs' to be {self.next_inputs_type}."
  75. )
  76. parsed["next_inputs"] = {self.next_inputs_inner_key: parsed["next_inputs"]}
  77. if (
  78. parsed["destination"].strip().lower()
  79. == self.default_destination.lower()
  80. ):
  81. parsed["destination"] = None
  82. else:
  83. parsed["destination"] = parsed["destination"].strip()
  84. return parsed
  85. except Exception as e:
  86. raise OutputParserException(
  87. f"Parsing text\n{text}\n of llm router raised following error:\n{e}"
  88. )