structured_chat.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. import re
  2. from typing import Any, List, Optional, Sequence, Tuple, Union, cast
  3. from core.agent.agent.agent_llm_callback import AgentLLMCallback
  4. from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
  5. from core.chain.llm_chain import LLMChain
  6. from core.entities.application_entities import ModelConfigEntity
  7. from core.entities.message_entities import lc_messages_to_prompt_messages
  8. from langchain import BasePromptTemplate, PromptTemplate
  9. from langchain.agents import Agent, AgentOutputParser, StructuredChatAgent
  10. from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
  11. from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
  12. from langchain.callbacks.base import BaseCallbackManager
  13. from langchain.callbacks.manager import Callbacks
  14. from langchain.memory.prompt import SUMMARY_PROMPT
  15. from langchain.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, SystemMessagePromptTemplate
  16. from langchain.schema import (AgentAction, AgentFinish, AIMessage, BaseMessage, HumanMessage, OutputParserException,
  17. get_buffer_string)
  18. from langchain.tools import BaseTool
  19. FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
  20. The nouns in the format of "Thought", "Action", "Action Input", "Final Answer" must be expressed in English.
  21. Valid "action" values: "Final Answer" or {tool_names}
  22. Provide only ONE action per $JSON_BLOB, as shown:
  23. ```
  24. {{{{
  25. "action": $TOOL_NAME,
  26. "action_input": $INPUT
  27. }}}}
  28. ```
  29. Follow this format:
  30. Question: input question to answer
  31. Thought: consider previous and subsequent steps
  32. Action:
  33. ```
  34. $JSON_BLOB
  35. ```
  36. Observation: action result
  37. ... (repeat Thought/Action/Observation N times)
  38. Thought: I know what to respond
  39. Action:
  40. ```
  41. {{{{
  42. "action": "Final Answer",
  43. "action_input": "Final response to human"
  44. }}}}
  45. ```"""
  46. class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
  47. moving_summary_buffer: str = ""
  48. moving_summary_index: int = 0
  49. summary_model_config: ModelConfigEntity = None
  50. class Config:
  51. """Configuration for this pydantic object."""
  52. arbitrary_types_allowed = True
  53. def should_use_agent(self, query: str):
  54. """
  55. return should use agent
  56. Using the ReACT mode to determine whether an agent is needed is costly,
  57. so it's better to just use an Agent for reasoning, which is cheaper.
  58. :param query:
  59. :return:
  60. """
  61. return True
  62. def plan(
  63. self,
  64. intermediate_steps: List[Tuple[AgentAction, str]],
  65. callbacks: Callbacks = None,
  66. **kwargs: Any,
  67. ) -> Union[AgentAction, AgentFinish]:
  68. """Given input, decided what to do.
  69. Args:
  70. intermediate_steps: Steps the LLM has taken to date,
  71. along with observatons
  72. callbacks: Callbacks to run.
  73. **kwargs: User inputs.
  74. Returns:
  75. Action specifying what tool to use.
  76. """
  77. full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
  78. prompts, _ = self.llm_chain.prep_prompts(input_list=[self.llm_chain.prep_inputs(full_inputs)])
  79. messages = []
  80. if prompts:
  81. messages = prompts[0].to_messages()
  82. prompt_messages = lc_messages_to_prompt_messages(messages)
  83. rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_config, prompt_messages)
  84. if rest_tokens < 0:
  85. full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
  86. try:
  87. full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
  88. except Exception as e:
  89. raise e
  90. try:
  91. agent_decision = self.output_parser.parse(full_output)
  92. if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
  93. tool_inputs = agent_decision.tool_input
  94. if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
  95. tool_inputs['query'] = kwargs['input']
  96. agent_decision.tool_input = tool_inputs
  97. return agent_decision
  98. except OutputParserException:
  99. return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
  100. "I don't know how to respond to that."}, "")
  101. def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
  102. if len(intermediate_steps) >= 2 and self.summary_model_config:
  103. should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
  104. should_summary_messages = [AIMessage(content=observation)
  105. for _, observation in should_summary_intermediate_steps]
  106. if self.moving_summary_index == 0:
  107. should_summary_messages.insert(0, HumanMessage(content=kwargs.get("input")))
  108. self.moving_summary_index = len(intermediate_steps)
  109. else:
  110. error_msg = "Exceeded LLM tokens limit, stopped."
  111. raise ExceededLLMTokensLimitError(error_msg)
  112. if self.moving_summary_buffer and 'chat_history' in kwargs:
  113. kwargs["chat_history"].pop()
  114. self.moving_summary_buffer = self.predict_new_summary(
  115. messages=should_summary_messages,
  116. existing_summary=self.moving_summary_buffer
  117. )
  118. if 'chat_history' in kwargs:
  119. kwargs["chat_history"].append(AIMessage(content=self.moving_summary_buffer))
  120. return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
  121. def predict_new_summary(
  122. self, messages: List[BaseMessage], existing_summary: str
  123. ) -> str:
  124. new_lines = get_buffer_string(
  125. messages,
  126. human_prefix="Human",
  127. ai_prefix="AI",
  128. )
  129. chain = LLMChain(model_config=self.summary_model_config, prompt=SUMMARY_PROMPT)
  130. return chain.predict(summary=existing_summary, new_lines=new_lines)
  131. @classmethod
  132. def create_prompt(
  133. cls,
  134. tools: Sequence[BaseTool],
  135. prefix: str = PREFIX,
  136. suffix: str = SUFFIX,
  137. human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
  138. format_instructions: str = FORMAT_INSTRUCTIONS,
  139. input_variables: Optional[List[str]] = None,
  140. memory_prompts: Optional[List[BasePromptTemplate]] = None,
  141. ) -> BasePromptTemplate:
  142. tool_strings = []
  143. for tool in tools:
  144. args_schema = re.sub("}", "}}}}", re.sub("{", "{{{{", str(tool.args)))
  145. tool_strings.append(f"{tool.name}: {tool.description}, args: {args_schema}")
  146. formatted_tools = "\n".join(tool_strings)
  147. tool_names = ", ".join([('"' + tool.name + '"') for tool in tools])
  148. format_instructions = format_instructions.format(tool_names=tool_names)
  149. template = "\n\n".join([prefix, formatted_tools, format_instructions, suffix])
  150. if input_variables is None:
  151. input_variables = ["input", "agent_scratchpad"]
  152. _memory_prompts = memory_prompts or []
  153. messages = [
  154. SystemMessagePromptTemplate.from_template(template),
  155. *_memory_prompts,
  156. HumanMessagePromptTemplate.from_template(human_message_template),
  157. ]
  158. return ChatPromptTemplate(input_variables=input_variables, messages=messages)
  159. @classmethod
  160. def create_completion_prompt(
  161. cls,
  162. tools: Sequence[BaseTool],
  163. prefix: str = PREFIX,
  164. format_instructions: str = FORMAT_INSTRUCTIONS,
  165. input_variables: Optional[List[str]] = None,
  166. ) -> PromptTemplate:
  167. """Create prompt in the style of the zero shot agent.
  168. Args:
  169. tools: List of tools the agent will have access to, used to format the
  170. prompt.
  171. prefix: String to put before the list of tools.
  172. input_variables: List of input variables the final prompt will expect.
  173. Returns:
  174. A PromptTemplate with the template assembled from the pieces here.
  175. """
  176. suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
  177. Question: {input}
  178. Thought: {agent_scratchpad}
  179. """
  180. tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
  181. tool_names = ", ".join([tool.name for tool in tools])
  182. format_instructions = format_instructions.format(tool_names=tool_names)
  183. template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
  184. if input_variables is None:
  185. input_variables = ["input", "agent_scratchpad"]
  186. return PromptTemplate(template=template, input_variables=input_variables)
  187. def _construct_scratchpad(
  188. self, intermediate_steps: List[Tuple[AgentAction, str]]
  189. ) -> str:
  190. agent_scratchpad = ""
  191. for action, observation in intermediate_steps:
  192. agent_scratchpad += action.log
  193. agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
  194. if not isinstance(agent_scratchpad, str):
  195. raise ValueError("agent_scratchpad should be of type string.")
  196. if agent_scratchpad:
  197. llm_chain = cast(LLMChain, self.llm_chain)
  198. if llm_chain.model_config.mode == "chat":
  199. return (
  200. f"This was your previous work "
  201. f"(but I haven't seen any of it! I only see what "
  202. f"you return as final answer):\n{agent_scratchpad}"
  203. )
  204. else:
  205. return agent_scratchpad
  206. else:
  207. return agent_scratchpad
  208. @classmethod
  209. def from_llm_and_tools(
  210. cls,
  211. model_config: ModelConfigEntity,
  212. tools: Sequence[BaseTool],
  213. callback_manager: Optional[BaseCallbackManager] = None,
  214. output_parser: Optional[AgentOutputParser] = None,
  215. prefix: str = PREFIX,
  216. suffix: str = SUFFIX,
  217. human_message_template: str = HUMAN_MESSAGE_TEMPLATE,
  218. format_instructions: str = FORMAT_INSTRUCTIONS,
  219. input_variables: Optional[List[str]] = None,
  220. memory_prompts: Optional[List[BasePromptTemplate]] = None,
  221. agent_llm_callback: Optional[AgentLLMCallback] = None,
  222. **kwargs: Any,
  223. ) -> Agent:
  224. """Construct an agent from an LLM and tools."""
  225. cls._validate_tools(tools)
  226. if model_config.mode == "chat":
  227. prompt = cls.create_prompt(
  228. tools,
  229. prefix=prefix,
  230. suffix=suffix,
  231. human_message_template=human_message_template,
  232. format_instructions=format_instructions,
  233. input_variables=input_variables,
  234. memory_prompts=memory_prompts,
  235. )
  236. else:
  237. prompt = cls.create_completion_prompt(
  238. tools,
  239. prefix=prefix,
  240. format_instructions=format_instructions,
  241. input_variables=input_variables,
  242. )
  243. llm_chain = LLMChain(
  244. model_config=model_config,
  245. prompt=prompt,
  246. callback_manager=callback_manager,
  247. agent_llm_callback=agent_llm_callback,
  248. parameters={
  249. 'temperature': 0.2,
  250. 'top_p': 0.3,
  251. 'max_tokens': 1500
  252. }
  253. )
  254. tool_names = [tool.name for tool in tools]
  255. _output_parser = output_parser
  256. return cls(
  257. llm_chain=llm_chain,
  258. allowed_tools=tool_names,
  259. output_parser=_output_parser,
  260. **kwargs,
  261. )