# orchestrator_rule_parser.py

from typing import Optional

from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.tools import BaseTool, Tool, WikipediaQueryRun
from pydantic import BaseModel, Field

from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
from core.conversation_message_task import ConversationMessageTask
from core.model_providers.error import ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.current_datetime_tool import DatetimeTool
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tool.provider.serpapi_provider import SerpAPIToolProvider
from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig


class OrchestratorRuleParser:
    """Parse the orchestrator rule to entities."""

    def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
        self.tenant_id = tenant_id
        self.app_model_config = app_model_config

    def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
                          rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
                          retriever_from: str = 'dev') -> Optional[AgentExecutor]:
        if not self.app_model_config.agent_mode_dict:
            return None

        agent_mode_config = self.app_model_config.agent_mode_dict
        model_dict = self.app_model_config.model_dict
        return_resource = self.app_model_config.retriever_resource_dict.get('enabled', False)

        chain = None
        if agent_mode_config and agent_mode_config.get('enabled'):
            tool_configs = agent_mode_config.get('tools', [])
            agent_provider_name = model_dict.get('provider', 'openai')
            agent_model_name = model_dict.get('name', 'gpt-4')
            dataset_configs = self.app_model_config.dataset_configs_dict

            agent_model_instance = ModelFactory.get_text_generation_model(
                tenant_id=self.tenant_id,
                model_provider_name=agent_provider_name,
                model_name=agent_model_name,
                model_kwargs=ModelKwargs(
                    temperature=0.2,
                    top_p=0.3,
                    max_tokens=1500
                )
            )

            # add agent callback to record agent thoughts
            agent_callback = AgentLoopGatherCallbackHandler(
                model_instance=agent_model_instance,
                conversation_message_task=conversation_message_task
            )

            chain_callback.agent_callback = agent_callback
            agent_model_instance.add_callbacks([agent_callback])

            planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))

            # only OpenAI chat models (including Azure OpenAI) support function calling;
            # fall back to the ReAct variants otherwise
            if not agent_model_instance.support_function_call:
                if planning_strategy == PlanningStrategy.FUNCTION_CALL:
                    planning_strategy = PlanningStrategy.REACT
                elif planning_strategy == PlanningStrategy.ROUTER:
                    planning_strategy = PlanningStrategy.REACT_ROUTER

            try:
                summary_model_instance = ModelFactory.get_text_generation_model(
                    tenant_id=self.tenant_id,
                    model_provider_name=agent_provider_name,
                    model_name=agent_model_name,
                    model_kwargs=ModelKwargs(
                        temperature=0,
                        max_tokens=500
                    ),
                    deduct_quota=False
                )
            except ProviderTokenNotInitError:
                summary_model_instance = None

            tools = self.to_tools(
                tool_configs=tool_configs,
                callbacks=[agent_callback, DifyStdOutCallbackHandler()],
                agent_model_instance=agent_model_instance,
                conversation_message_task=conversation_message_task,
                rest_tokens=rest_tokens,
                return_resource=return_resource,
                retriever_from=retriever_from,
                dataset_configs=dataset_configs
            )

            if len(tools) == 0:
                return None

            agent_configuration = AgentConfiguration(
                strategy=planning_strategy,
                model_instance=agent_model_instance,
                tools=tools,
                summary_model_instance=summary_model_instance,
                memory=memory,
                callbacks=[chain_callback, agent_callback],
                max_iterations=10,
                max_execution_time=400.0,
                early_stopping_method="generate"
            )

            return AgentExecutor(agent_configuration)

        return chain
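
    # Illustrative only (an assumption inferred from the reads in to_agent_executor(),
    # not a schema defined in this module): `agent_mode_config` is expected to look
    # roughly like
    #     {"enabled": True, "strategy": "router", "tools": [...]}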

    def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
        """
        Convert app agent tool configs to tools

        :param tool_configs: app agent tool configs
        :param callbacks: callbacks to attach to each created tool
        :return: list of tools
        """
        tools = []
        for tool_config in tool_configs:
            tool_type = list(tool_config.keys())[0]
            tool_val = list(tool_config.values())[0]

            # skip tools that are not explicitly enabled
            if tool_val.get("enabled") is not True:
                continue

            tool = None
            if tool_type == "dataset":
                tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
            elif tool_type == "web_reader":
                tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
            elif tool_type == "google_search":
                tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
            elif tool_type == "wikipedia":
                tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
            elif tool_type == "current_datetime":
                tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)

            if tool:
                if tool.callbacks is not None:
                    tool.callbacks.extend(callbacks)
                else:
                    tool.callbacks = callbacks

                tools.append(tool)

        return tools
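
    # Illustrative only (an assumption inferred from how to_tools() reads each entry,
    # not a schema defined in this module): `tool_configs` is a list of single-key
    # dicts keyed by tool type, e.g.
    #     [
    #         {"dataset": {"enabled": True, "id": "<dataset-id>"}},
    #         {"web_reader": {"enabled": True}},
    #         {"current_datetime": {"enabled": False}},
    #     ]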

    def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
                                  dataset_configs: dict, rest_tokens: int,
                                  return_resource: bool = False, retriever_from: str = 'dev',
                                  **kwargs) -> Optional[BaseTool]:
        """
        A dataset tool is a tool that can be used to retrieve information from a dataset

        :param rest_tokens: remaining token budget for retrieval
        :param tool_config: dataset tool config (contains the dataset id)
        :param dataset_configs: dataset retrieval configs (top_k, score_threshold)
        :param conversation_message_task: current conversation message task
        :param return_resource: whether to return retriever resource info
        :param retriever_from: caller of the retriever ('dev' by default)
        :return: a dataset retriever tool, or None if the dataset is missing or empty
        """
        # get dataset from dataset id
        dataset = db.session.query(Dataset).filter(
            Dataset.tenant_id == self.tenant_id,
            Dataset.id == tool_config.get("id")
        ).first()

        if not dataset:
            return None

        if dataset.available_document_count == 0:
            return None

        top_k = dataset_configs.get("top_k", 2)

        # dynamically adjust top_k when the remaining token budget cannot support the configured top_k
        top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)

        score_threshold = None
        score_threshold_config = dataset_configs.get("score_threshold")
        if score_threshold_config and score_threshold_config.get("enable"):
            score_threshold = score_threshold_config.get("value")

        tool = DatasetRetrieverTool.from_dataset(
            dataset=dataset,
            top_k=top_k,
            score_threshold=score_threshold,
            callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
            conversation_message_task=conversation_message_task,
            return_resource=return_resource,
            retriever_from=retriever_from
        )

        return tool
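
    # Illustrative only (an assumption inferred from the reads in
    # to_dataset_retriever_tool()): `dataset_configs` is expected to look roughly like
    #     {"top_k": 2, "score_threshold": {"enable": True, "value": 0.5}}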

    def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
        """
        A tool for reading web pages

        :return: a web reader tool
        """
        try:
            summary_model_instance = ModelFactory.get_text_generation_model(
                tenant_id=self.tenant_id,
                model_provider_name=agent_model_instance.model_provider.provider_name,
                model_name=agent_model_instance.name,
                model_kwargs=ModelKwargs(
                    temperature=0,
                    max_tokens=500
                ),
                deduct_quota=False
            )
        except ProviderTokenNotInitError:
            summary_model_instance = None

        tool = WebReaderTool(
            model_instance=summary_model_instance,
            max_chunk_length=4000,
            continue_reading=True
        )

        return tool

    def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
        tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
        func_kwargs = tool_provider.credentials_to_func_kwargs()
        if not func_kwargs:
            return None

        tool = Tool(
            name="google_search",
            description="A tool for performing a Google search and extracting snippets and webpages "
                        "when you need to search for something you don't know or when your information "
                        "is not up to date. "
                        "Input should be a search query.",
            func=OptimizedSerpAPIWrapper(**func_kwargs).run,
            args_schema=OptimizedSerpAPIInput
        )

        return tool

    def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
        tool = DatetimeTool()

        return tool

    def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
        class WikipediaInput(BaseModel):
            query: str = Field(..., description="search query.")

        return WikipediaQueryRun(
            name="wikipedia",
            api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
            args_schema=WikipediaInput
        )

    @classmethod
    def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
        if rest_tokens == -1:
            return top_k

        processing_rule = dataset.latest_process_rule
        if not processing_rule:
            return top_k

        if processing_rule.mode == "custom":
            rules = processing_rule.rules_dict
            if not rules:
                return top_k

            segmentation = rules["segmentation"]
            segment_max_tokens = segmentation["max_tokens"]
        else:
            segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']

        # when the remaining tokens cannot fit top_k full segments,
        # shrink top_k so that top_k * segment_max_tokens <= rest_tokens
        if rest_tokens < segment_max_tokens * top_k:
            return rest_tokens // segment_max_tokens

        return min(top_k, 10)
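

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal sketch of how this parser is typically driven, assuming an
# application context where `db` is initialized and an AppModelConfig row
# exists. The names `tenant_id`, `app_model_config`, `message_task`, `memory`,
# and `chain_callback` below are placeholders, not values defined in this
# module.
#
#     parser = OrchestratorRuleParser(tenant_id=tenant_id, app_model_config=app_model_config)
#     agent_executor = parser.to_agent_executor(
#         conversation_message_task=message_task,
#         memory=memory,
#         rest_tokens=2000,
#         chain_callback=chain_callback,
#         retriever_from='api'
#     )
#     if agent_executor:
#         ...  # run the agent via the returned AgentExecutor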