# orchestrator_rule_parser.py

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. import json
  2. import threading
  3. from typing import Optional, List
  4. from flask import Flask
  5. from langchain import WikipediaAPIWrapper
  6. from langchain.callbacks.manager import Callbacks
  7. from langchain.memory.chat_memory import BaseChatMemory
  8. from langchain.tools import BaseTool, Tool, WikipediaQueryRun
  9. from pydantic import BaseModel, Field
  10. from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
  11. from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
  12. from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
  13. from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
  14. from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
  15. from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
  16. from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
  17. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  18. from core.conversation_message_task import ConversationMessageTask
  19. from core.model_providers.error import ProviderTokenNotInitError
  20. from core.model_providers.model_factory import ModelFactory
  21. from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
  22. from core.model_providers.models.llm.base import BaseLLM
  23. from core.tool.current_datetime_tool import DatetimeTool
  24. from core.tool.dataset_multi_retriever_tool import DatasetMultiRetrieverTool
  25. from core.tool.dataset_retriever_tool import DatasetRetrieverTool
  26. from core.tool.provider.serpapi_provider import SerpAPIToolProvider
  27. from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
  28. from core.tool.web_reader_tool import WebReaderTool
  29. from extensions.ext_database import db
  30. from models.dataset import Dataset, DatasetProcessRule
  31. from models.model import AppModelConfig
  32. default_retrieval_model = {
  33. 'search_method': 'semantic_search',
  34. 'reranking_enable': False,
  35. 'reranking_model': {
  36. 'reranking_provider_name': '',
  37. 'reranking_model_name': ''
  38. },
  39. 'top_k': 2,
  40. 'score_threshold_enable': False
  41. }
  42. class OrchestratorRuleParser:
  43. """Parse the orchestrator rule to entities."""
  44. def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
  45. self.tenant_id = tenant_id
  46. self.app_model_config = app_model_config
  47. def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
  48. rest_tokens: int, chain_callback: MainChainGatherCallbackHandler, tenant_id: str,
  49. retriever_from: str = 'dev') -> Optional[AgentExecutor]:
  50. if not self.app_model_config.agent_mode_dict:
  51. return None
  52. agent_mode_config = self.app_model_config.agent_mode_dict
  53. model_dict = self.app_model_config.model_dict
  54. return_resource = self.app_model_config.retriever_resource_dict.get('enabled', False)
  55. chain = None
  56. if agent_mode_config and agent_mode_config.get('enabled'):
  57. tool_configs = agent_mode_config.get('tools', [])
  58. agent_provider_name = model_dict.get('provider', 'openai')
  59. agent_model_name = model_dict.get('name', 'gpt-4')
  60. dataset_configs = self.app_model_config.dataset_configs_dict
  61. agent_model_instance = ModelFactory.get_text_generation_model(
  62. tenant_id=self.tenant_id,
  63. model_provider_name=agent_provider_name,
  64. model_name=agent_model_name,
  65. model_kwargs=ModelKwargs(
  66. temperature=0.2,
  67. top_p=0.3,
  68. max_tokens=1500
  69. )
  70. )
  71. # add agent callback to record agent thoughts
  72. agent_callback = AgentLoopGatherCallbackHandler(
  73. model_instance=agent_model_instance,
  74. conversation_message_task=conversation_message_task
  75. )
  76. chain_callback.agent_callback = agent_callback
  77. agent_model_instance.add_callbacks([agent_callback])
  78. planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
  79. # only OpenAI chat model (include Azure) support function call, use ReACT instead
  80. if not agent_model_instance.support_function_call:
  81. if planning_strategy == PlanningStrategy.FUNCTION_CALL:
  82. planning_strategy = PlanningStrategy.REACT
  83. elif planning_strategy == PlanningStrategy.ROUTER:
  84. planning_strategy = PlanningStrategy.REACT_ROUTER
  85. try:
  86. summary_model_instance = ModelFactory.get_text_generation_model(
  87. tenant_id=self.tenant_id,
  88. model_provider_name=agent_provider_name,
  89. model_name=agent_model_name,
  90. model_kwargs=ModelKwargs(
  91. temperature=0,
  92. max_tokens=500
  93. ),
  94. deduct_quota=False
  95. )
  96. except ProviderTokenNotInitError as e:
  97. summary_model_instance = None
  98. tools = self.to_tools(
  99. tool_configs=tool_configs,
  100. callbacks=[agent_callback, DifyStdOutCallbackHandler()],
  101. agent_model_instance=agent_model_instance,
  102. conversation_message_task=conversation_message_task,
  103. rest_tokens=rest_tokens,
  104. return_resource=return_resource,
  105. retriever_from=retriever_from,
  106. dataset_configs=dataset_configs,
  107. tenant_id=tenant_id
  108. )
  109. if len(tools) == 0:
  110. return None
  111. agent_configuration = AgentConfiguration(
  112. strategy=planning_strategy,
  113. model_instance=agent_model_instance,
  114. tools=tools,
  115. summary_model_instance=summary_model_instance,
  116. memory=memory,
  117. callbacks=[chain_callback, agent_callback],
  118. max_iterations=10,
  119. max_execution_time=400.0,
  120. early_stopping_method="generate"
  121. )
  122. return AgentExecutor(agent_configuration)
  123. return chain
  124. def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
  125. """
  126. Convert app agent tool configs to tools
  127. :param tool_configs: app agent tool configs
  128. :param callbacks:
  129. :return:
  130. """
  131. tools = []
  132. dataset_tools = []
  133. for tool_config in tool_configs:
  134. tool_type = list(tool_config.keys())[0]
  135. tool_val = list(tool_config.values())[0]
  136. if not tool_val.get("enabled") or tool_val.get("enabled") is not True:
  137. continue
  138. tool = None
  139. if tool_type == "dataset":
  140. dataset_tools.append(tool_config)
  141. elif tool_type == "web_reader":
  142. tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
  143. elif tool_type == "google_search":
  144. tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
  145. elif tool_type == "wikipedia":
  146. tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
  147. elif tool_type == "current_datetime":
  148. tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)
  149. if tool:
  150. if tool.callbacks is not None:
  151. tool.callbacks.extend(callbacks)
  152. else:
  153. tool.callbacks = callbacks
  154. tools.append(tool)
  155. # format dataset tool
  156. if len(dataset_tools) > 0:
  157. dataset_retriever_tools = self.to_dataset_retriever_tool(tool_configs=dataset_tools, **kwargs)
  158. if dataset_retriever_tools:
  159. tools.extend(dataset_retriever_tools)
  160. return tools
  161. def to_dataset_retriever_tool(self, tool_configs: List, conversation_message_task: ConversationMessageTask,
  162. return_resource: bool = False, retriever_from: str = 'dev',
  163. **kwargs) \
  164. -> Optional[List[BaseTool]]:
  165. """
  166. A dataset tool is a tool that can be used to retrieve information from a dataset
  167. :param tool_configs:
  168. :param conversation_message_task:
  169. :param return_resource:
  170. :param retriever_from:
  171. :return:
  172. """
  173. dataset_configs = kwargs['dataset_configs']
  174. retrieval_model = dataset_configs.get('retrieval_model', 'single')
  175. tools = []
  176. dataset_ids = []
  177. tenant_id = None
  178. for tool_config in tool_configs:
  179. # get dataset from dataset id
  180. dataset = db.session.query(Dataset).filter(
  181. Dataset.tenant_id == self.tenant_id,
  182. Dataset.id == tool_config.get('dataset').get("id")
  183. ).first()
  184. if not dataset:
  185. continue
  186. if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
  187. continue
  188. dataset_ids.append(dataset.id)
  189. if retrieval_model == 'single':
  190. retrieval_model = dataset.retrieval_model if dataset.retrieval_model else default_retrieval_model
  191. top_k = retrieval_model['top_k']
  192. # dynamically adjust top_k when the remaining token number is not enough to support top_k
  193. # top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
  194. score_threshold = None
  195. score_threshold_enable = retrieval_model.get("score_threshold_enable")
  196. if score_threshold_enable:
  197. score_threshold = retrieval_model.get("score_threshold")
  198. tool = DatasetRetrieverTool.from_dataset(
  199. dataset=dataset,
  200. top_k=top_k,
  201. score_threshold=score_threshold,
  202. callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
  203. conversation_message_task=conversation_message_task,
  204. return_resource=return_resource,
  205. retriever_from=retriever_from
  206. )
  207. tools.append(tool)
  208. if retrieval_model == 'multiple':
  209. tool = DatasetMultiRetrieverTool.from_dataset(
  210. dataset_ids=dataset_ids,
  211. tenant_id=kwargs['tenant_id'],
  212. top_k=dataset_configs.get('top_k', 2),
  213. score_threshold=dataset_configs.get('score_threshold', 0.5) if dataset_configs.get('score_threshold_enable', False) else None,
  214. callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
  215. conversation_message_task=conversation_message_task,
  216. return_resource=return_resource,
  217. retriever_from=retriever_from,
  218. reranking_provider_name=dataset_configs.get('reranking_model').get('reranking_provider_name'),
  219. reranking_model_name=dataset_configs.get('reranking_model').get('reranking_model_name')
  220. )
  221. tools.append(tool)
  222. return tools
  223. def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
  224. """
  225. A tool for reading web pages
  226. :return:
  227. """
  228. try:
  229. summary_model_instance = ModelFactory.get_text_generation_model(
  230. tenant_id=self.tenant_id,
  231. model_provider_name=agent_model_instance.model_provider.provider_name,
  232. model_name=agent_model_instance.name,
  233. model_kwargs=ModelKwargs(
  234. temperature=0,
  235. max_tokens=500
  236. ),
  237. deduct_quota=False
  238. )
  239. except ProviderTokenNotInitError:
  240. summary_model_instance = None
  241. tool = WebReaderTool(
  242. model_instance=summary_model_instance if summary_model_instance else None,
  243. max_chunk_length=4000,
  244. continue_reading=True
  245. )
  246. return tool
  247. def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
  248. tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
  249. func_kwargs = tool_provider.credentials_to_func_kwargs()
  250. if not func_kwargs:
  251. return None
  252. tool = Tool(
  253. name="google_search",
  254. description="A tool for performing a Google search and extracting snippets and webpages "
  255. "when you need to search for something you don't know or when your information "
  256. "is not up to date. "
  257. "Input should be a search query.",
  258. func=OptimizedSerpAPIWrapper(**func_kwargs).run,
  259. args_schema=OptimizedSerpAPIInput
  260. )
  261. return tool
  262. def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
  263. tool = DatetimeTool()
  264. return tool
  265. def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
  266. class WikipediaInput(BaseModel):
  267. query: str = Field(..., description="search query.")
  268. return WikipediaQueryRun(
  269. name="wikipedia",
  270. api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
  271. args_schema=WikipediaInput
  272. )
  273. @classmethod
  274. def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
  275. if rest_tokens == -1:
  276. return top_k
  277. processing_rule = dataset.latest_process_rule
  278. if not processing_rule:
  279. return top_k
  280. if processing_rule.mode == "custom":
  281. rules = processing_rule.rules_dict
  282. if not rules:
  283. return top_k
  284. segmentation = rules["segmentation"]
  285. segment_max_tokens = segmentation["max_tokens"]
  286. else:
  287. segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
  288. # when rest_tokens is less than default context tokens
  289. if rest_tokens < segment_max_tokens * top_k:
  290. return rest_tokens // segment_max_tokens
  291. return min(top_k, 10)