# agent_builder.py
  1. from typing import Optional
  2. from langchain import LLMChain
  3. from langchain.agents import ZeroShotAgent, AgentExecutor, ConversationalAgent
  4. from langchain.callbacks.manager import CallbackManager
  5. from langchain.memory.chat_memory import BaseChatMemory
  6. from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
  7. from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
  8. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  9. from core.llm.llm_builder import LLMBuilder
  10. class AgentBuilder:
  11. @classmethod
  12. def to_agent_chain(cls, tenant_id: str, tools, memory: Optional[BaseChatMemory],
  13. dataset_tool_callback_handler: DatasetToolCallbackHandler,
  14. agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler):
  15. llm = LLMBuilder.to_llm(
  16. tenant_id=tenant_id,
  17. model_name=agent_loop_gather_callback_handler.model_name,
  18. temperature=0,
  19. max_tokens=1024,
  20. callbacks=[agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
  21. )
  22. for tool in tools:
  23. tool.callbacks = [
  24. agent_loop_gather_callback_handler,
  25. dataset_tool_callback_handler,
  26. DifyStdOutCallbackHandler()
  27. ]
  28. prompt = cls.build_agent_prompt_template(
  29. tools=tools,
  30. memory=memory,
  31. )
  32. agent_llm_chain = LLMChain(
  33. llm=llm,
  34. prompt=prompt,
  35. )
  36. agent = cls.build_agent(agent_llm_chain=agent_llm_chain, memory=memory)
  37. agent_callback_manager = CallbackManager(
  38. [agent_loop_gather_callback_handler, DifyStdOutCallbackHandler()]
  39. )
  40. agent_chain = AgentExecutor.from_agent_and_tools(
  41. tools=tools,
  42. agent=agent,
  43. memory=memory,
  44. callbacks=agent_callback_manager,
  45. max_iterations=6,
  46. early_stopping_method="generate",
  47. # `generate` will continue to complete the last inference after reaching the iteration limit or request time limit
  48. )
  49. return agent_chain
  50. @classmethod
  51. def build_agent_prompt_template(cls, tools, memory: Optional[BaseChatMemory]):
  52. if memory:
  53. prompt = ConversationalAgent.create_prompt(
  54. tools=tools,
  55. )
  56. else:
  57. prompt = ZeroShotAgent.create_prompt(
  58. tools=tools,
  59. )
  60. return prompt
  61. @classmethod
  62. def build_agent(cls, agent_llm_chain: LLMChain, memory: Optional[BaseChatMemory]):
  63. if memory:
  64. agent = ConversationalAgent(
  65. llm_chain=agent_llm_chain
  66. )
  67. else:
  68. agent = ZeroShotAgent(
  69. llm_chain=agent_llm_chain
  70. )
  71. return agent