agent_runner.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. import logging
  2. from typing import cast, Optional, List
  3. from langchain import WikipediaAPIWrapper
  4. from langchain.callbacks.base import BaseCallbackHandler
  5. from langchain.tools import BaseTool, WikipediaQueryRun, Tool
  6. from pydantic import BaseModel, Field
  7. from core.agent.agent.agent_llm_callback import AgentLLMCallback
  8. from core.agent.agent_executor import PlanningStrategy, AgentConfiguration, AgentExecutor
  9. from core.application_queue_manager import ApplicationQueueManager
  10. from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
  11. from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
  12. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  13. from core.entities.application_entities import ModelConfigEntity, InvokeFrom, \
  14. AgentEntity, AgentToolEntity, AppOrchestrationConfigEntity
  15. from core.memory.token_buffer_memory import TokenBufferMemory
  16. from core.model_runtime.entities.model_entities import ModelFeature, ModelType
  17. from core.model_runtime.model_providers import model_provider_factory
  18. from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
  19. from core.tools.tool.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool
  20. from extensions.ext_database import db
  21. from models.dataset import Dataset
  22. from models.model import Message
  23. logger = logging.getLogger(__name__)
  24. class AgentRunnerFeature:
  25. def __init__(self, tenant_id: str,
  26. app_orchestration_config: AppOrchestrationConfigEntity,
  27. model_config: ModelConfigEntity,
  28. config: AgentEntity,
  29. queue_manager: ApplicationQueueManager,
  30. message: Message,
  31. user_id: str,
  32. agent_llm_callback: AgentLLMCallback,
  33. callback: AgentLoopGatherCallbackHandler,
  34. memory: Optional[TokenBufferMemory] = None,) -> None:
  35. """
  36. Agent runner
  37. :param tenant_id: tenant id
  38. :param app_orchestration_config: app orchestration config
  39. :param model_config: model config
  40. :param config: dataset config
  41. :param queue_manager: queue manager
  42. :param message: message
  43. :param user_id: user id
  44. :param agent_llm_callback: agent llm callback
  45. :param callback: callback
  46. :param memory: memory
  47. """
  48. self.tenant_id = tenant_id
  49. self.app_orchestration_config = app_orchestration_config
  50. self.model_config = model_config
  51. self.config = config
  52. self.queue_manager = queue_manager
  53. self.message = message
  54. self.user_id = user_id
  55. self.agent_llm_callback = agent_llm_callback
  56. self.callback = callback
  57. self.memory = memory
  58. def run(self, query: str,
  59. invoke_from: InvokeFrom) -> Optional[str]:
  60. """
  61. Retrieve agent loop result.
  62. :param query: query
  63. :param invoke_from: invoke from
  64. :return:
  65. """
  66. provider = self.config.provider
  67. model = self.config.model
  68. tool_configs = self.config.tools
  69. # check model is support tool calling
  70. provider_instance = model_provider_factory.get_provider_instance(provider=provider)
  71. model_type_instance = provider_instance.get_model_instance(ModelType.LLM)
  72. model_type_instance = cast(LargeLanguageModel, model_type_instance)
  73. # get model schema
  74. model_schema = model_type_instance.get_model_schema(
  75. model=model,
  76. credentials=self.model_config.credentials
  77. )
  78. if not model_schema:
  79. return None
  80. planning_strategy = PlanningStrategy.REACT
  81. features = model_schema.features
  82. if features:
  83. if ModelFeature.TOOL_CALL in features \
  84. or ModelFeature.MULTI_TOOL_CALL in features:
  85. planning_strategy = PlanningStrategy.FUNCTION_CALL
  86. tools = self.to_tools(
  87. tool_configs=tool_configs,
  88. invoke_from=invoke_from,
  89. callbacks=[self.callback, DifyStdOutCallbackHandler()],
  90. )
  91. if len(tools) == 0:
  92. return None
  93. agent_configuration = AgentConfiguration(
  94. strategy=planning_strategy,
  95. model_config=self.model_config,
  96. tools=tools,
  97. memory=self.memory,
  98. max_iterations=10,
  99. max_execution_time=400.0,
  100. early_stopping_method="generate",
  101. agent_llm_callback=self.agent_llm_callback,
  102. callbacks=[self.callback, DifyStdOutCallbackHandler()]
  103. )
  104. agent_executor = AgentExecutor(agent_configuration)
  105. try:
  106. # check if should use agent
  107. should_use_agent = agent_executor.should_use_agent(query)
  108. if not should_use_agent:
  109. return None
  110. result = agent_executor.run(query)
  111. return result.output
  112. except Exception as ex:
  113. logger.exception("agent_executor run failed")
  114. return None
  115. def to_dataset_retriever_tool(self, tool_config: dict,
  116. invoke_from: InvokeFrom) \
  117. -> Optional[BaseTool]:
  118. """
  119. A dataset tool is a tool that can be used to retrieve information from a dataset
  120. :param tool_config: tool config
  121. :param invoke_from: invoke from
  122. """
  123. show_retrieve_source = self.app_orchestration_config.show_retrieve_source
  124. hit_callback = DatasetIndexToolCallbackHandler(
  125. queue_manager=self.queue_manager,
  126. app_id=self.message.app_id,
  127. message_id=self.message.id,
  128. user_id=self.user_id,
  129. invoke_from=invoke_from
  130. )
  131. # get dataset from dataset id
  132. dataset = db.session.query(Dataset).filter(
  133. Dataset.tenant_id == self.tenant_id,
  134. Dataset.id == tool_config.get("id")
  135. ).first()
  136. # pass if dataset is not available
  137. if not dataset:
  138. return None
  139. # pass if dataset is not available
  140. if (dataset and dataset.available_document_count == 0
  141. and dataset.available_document_count == 0):
  142. return None
  143. # get retrieval model config
  144. default_retrieval_model = {
  145. 'search_method': 'semantic_search',
  146. 'reranking_enable': False,
  147. 'reranking_model': {
  148. 'reranking_provider_name': '',
  149. 'reranking_model_name': ''
  150. },
  151. 'top_k': 2,
  152. 'score_threshold_enabled': False
  153. }
  154. retrieval_model_config = dataset.retrieval_model \
  155. if dataset.retrieval_model else default_retrieval_model
  156. # get top k
  157. top_k = retrieval_model_config['top_k']
  158. # get score threshold
  159. score_threshold = None
  160. score_threshold_enabled = retrieval_model_config.get("score_threshold_enabled")
  161. if score_threshold_enabled:
  162. score_threshold = retrieval_model_config.get("score_threshold")
  163. tool = DatasetRetrieverTool.from_dataset(
  164. dataset=dataset,
  165. top_k=top_k,
  166. score_threshold=score_threshold,
  167. hit_callbacks=[hit_callback],
  168. return_resource=show_retrieve_source,
  169. retriever_from=invoke_from.to_source()
  170. )
  171. return tool