# orchestrator_rule_parser.py
  1. import math
  2. from typing import Optional
  3. from langchain import WikipediaAPIWrapper
  4. from langchain.callbacks.manager import Callbacks
  5. from langchain.memory.chat_memory import BaseChatMemory
  6. from langchain.tools import BaseTool, Tool, WikipediaQueryRun
  7. from pydantic import BaseModel, Field
  8. from core.agent.agent_executor import AgentExecutor, PlanningStrategy, AgentConfiguration
  9. from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
  10. from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
  11. from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
  12. from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
  13. from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceChain
  14. from core.conversation_message_task import ConversationMessageTask
  15. from core.model_providers.error import ProviderTokenNotInitError
  16. from core.model_providers.model_factory import ModelFactory
  17. from core.model_providers.models.entity.model_params import ModelKwargs, ModelMode
  18. from core.tool.dataset_retriever_tool import DatasetRetrieverTool
  19. from core.tool.provider.serpapi_provider import SerpAPIToolProvider
  20. from core.tool.serpapi_wrapper import OptimizedSerpAPIWrapper, OptimizedSerpAPIInput
  21. from core.tool.web_reader_tool import WebReaderTool
  22. from extensions.ext_database import db
  23. from libs import helper
  24. from models.dataset import Dataset, DatasetProcessRule
  25. from models.model import AppModelConfig
  26. class OrchestratorRuleParser:
  27. """Parse the orchestrator rule to entities."""
  28. def __init__(self, tenant_id: str, app_model_config: AppModelConfig):
  29. self.tenant_id = tenant_id
  30. self.app_model_config = app_model_config
  31. def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
  32. rest_tokens: int, chain_callback: MainChainGatherCallbackHandler) \
  33. -> Optional[AgentExecutor]:
  34. if not self.app_model_config.agent_mode_dict:
  35. return None
  36. agent_mode_config = self.app_model_config.agent_mode_dict
  37. model_dict = self.app_model_config.model_dict
  38. chain = None
  39. if agent_mode_config and agent_mode_config.get('enabled'):
  40. tool_configs = agent_mode_config.get('tools', [])
  41. agent_provider_name = model_dict.get('provider', 'openai')
  42. agent_model_name = model_dict.get('name', 'gpt-4')
  43. agent_model_instance = ModelFactory.get_text_generation_model(
  44. tenant_id=self.tenant_id,
  45. model_provider_name=agent_provider_name,
  46. model_name=agent_model_name,
  47. model_kwargs=ModelKwargs(
  48. temperature=0.2,
  49. top_p=0.3,
  50. max_tokens=1500
  51. )
  52. )
  53. # add agent callback to record agent thoughts
  54. agent_callback = AgentLoopGatherCallbackHandler(
  55. model_instant=agent_model_instance,
  56. conversation_message_task=conversation_message_task
  57. )
  58. chain_callback.agent_callback = agent_callback
  59. agent_model_instance.add_callbacks([agent_callback])
  60. planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
  61. # only OpenAI chat model (include Azure) support function call, use ReACT instead
  62. if agent_model_instance.model_mode != ModelMode.CHAT \
  63. or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
  64. if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
  65. planning_strategy = PlanningStrategy.REACT
  66. elif planning_strategy == PlanningStrategy.ROUTER:
  67. planning_strategy = PlanningStrategy.REACT_ROUTER
  68. try:
  69. summary_model_instance = ModelFactory.get_text_generation_model(
  70. tenant_id=self.tenant_id,
  71. model_kwargs=ModelKwargs(
  72. temperature=0,
  73. max_tokens=500
  74. )
  75. )
  76. except ProviderTokenNotInitError as e:
  77. summary_model_instance = None
  78. tools = self.to_tools(
  79. tool_configs=tool_configs,
  80. conversation_message_task=conversation_message_task,
  81. rest_tokens=rest_tokens,
  82. callbacks=[agent_callback, DifyStdOutCallbackHandler()]
  83. )
  84. if len(tools) == 0:
  85. return None
  86. agent_configuration = AgentConfiguration(
  87. strategy=planning_strategy,
  88. model_instance=agent_model_instance,
  89. tools=tools,
  90. summary_model_instance=summary_model_instance,
  91. memory=memory,
  92. callbacks=[chain_callback, agent_callback],
  93. max_iterations=10,
  94. max_execution_time=400.0,
  95. early_stopping_method="generate"
  96. )
  97. return AgentExecutor(agent_configuration)
  98. return chain
  99. def to_sensitive_word_avoidance_chain(self, callbacks: Callbacks = None, **kwargs) \
  100. -> Optional[SensitiveWordAvoidanceChain]:
  101. """
  102. Convert app sensitive word avoidance config to chain
  103. :param kwargs:
  104. :return:
  105. """
  106. if not self.app_model_config.sensitive_word_avoidance_dict:
  107. return None
  108. sensitive_word_avoidance_config = self.app_model_config.sensitive_word_avoidance_dict
  109. sensitive_words = sensitive_word_avoidance_config.get("words", "")
  110. if sensitive_word_avoidance_config.get("enabled", False) and sensitive_words:
  111. return SensitiveWordAvoidanceChain(
  112. sensitive_words=sensitive_words.split(","),
  113. canned_response=sensitive_word_avoidance_config.get("canned_response", ''),
  114. output_key="sensitive_word_avoidance_output",
  115. callbacks=callbacks,
  116. **kwargs
  117. )
  118. return None
  119. def to_tools(self, tool_configs: list, conversation_message_task: ConversationMessageTask,
  120. rest_tokens: int, callbacks: Callbacks = None) -> list[BaseTool]:
  121. """
  122. Convert app agent tool configs to tools
  123. :param rest_tokens:
  124. :param tool_configs: app agent tool configs
  125. :param conversation_message_task:
  126. :param callbacks:
  127. :return:
  128. """
  129. tools = []
  130. for tool_config in tool_configs:
  131. tool_type = list(tool_config.keys())[0]
  132. tool_val = list(tool_config.values())[0]
  133. if not tool_val.get("enabled") or tool_val.get("enabled") is not True:
  134. continue
  135. tool = None
  136. if tool_type == "dataset":
  137. tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens)
  138. elif tool_type == "web_reader":
  139. tool = self.to_web_reader_tool()
  140. elif tool_type == "google_search":
  141. tool = self.to_google_search_tool()
  142. elif tool_type == "wikipedia":
  143. tool = self.to_wikipedia_tool()
  144. elif tool_type == "current_datetime":
  145. tool = self.to_current_datetime_tool()
  146. if tool:
  147. tool.callbacks.extend(callbacks)
  148. tools.append(tool)
  149. return tools
  150. def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
  151. rest_tokens: int) \
  152. -> Optional[BaseTool]:
  153. """
  154. A dataset tool is a tool that can be used to retrieve information from a dataset
  155. :param rest_tokens:
  156. :param tool_config:
  157. :param conversation_message_task:
  158. :return:
  159. """
  160. # get dataset from dataset id
  161. dataset = db.session.query(Dataset).filter(
  162. Dataset.tenant_id == self.tenant_id,
  163. Dataset.id == tool_config.get("id")
  164. ).first()
  165. if not dataset:
  166. return None
  167. if dataset and dataset.available_document_count == 0 and dataset.available_document_count == 0:
  168. return None
  169. k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
  170. tool = DatasetRetrieverTool.from_dataset(
  171. dataset=dataset,
  172. k=k,
  173. callbacks=[DatasetToolCallbackHandler(conversation_message_task)]
  174. )
  175. return tool
  176. def to_web_reader_tool(self) -> Optional[BaseTool]:
  177. """
  178. A tool for reading web pages
  179. :return:
  180. """
  181. summary_model_instance = ModelFactory.get_text_generation_model(
  182. tenant_id=self.tenant_id,
  183. model_kwargs=ModelKwargs(
  184. temperature=0,
  185. max_tokens=500
  186. )
  187. )
  188. summary_llm = summary_model_instance.client
  189. tool = WebReaderTool(
  190. llm=summary_llm,
  191. max_chunk_length=4000,
  192. continue_reading=True,
  193. callbacks=[DifyStdOutCallbackHandler()]
  194. )
  195. return tool
  196. def to_google_search_tool(self) -> Optional[BaseTool]:
  197. tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
  198. func_kwargs = tool_provider.credentials_to_func_kwargs()
  199. if not func_kwargs:
  200. return None
  201. tool = Tool(
  202. name="google_search",
  203. description="A tool for performing a Google search and extracting snippets and webpages "
  204. "when you need to search for something you don't know or when your information "
  205. "is not up to date. "
  206. "Input should be a search query.",
  207. func=OptimizedSerpAPIWrapper(**func_kwargs).run,
  208. args_schema=OptimizedSerpAPIInput,
  209. callbacks=[DifyStdOutCallbackHandler()]
  210. )
  211. return tool
  212. def to_current_datetime_tool(self) -> Optional[BaseTool]:
  213. tool = Tool(
  214. name="current_datetime",
  215. description="A tool when you want to get the current date, time, week, month or year, "
  216. "and the time zone is UTC. Result is \"<date> <time> <timezone> <week>\".",
  217. func=helper.get_current_datetime,
  218. callbacks=[DifyStdOutCallbackHandler()]
  219. )
  220. return tool
  221. def to_wikipedia_tool(self) -> Optional[BaseTool]:
  222. class WikipediaInput(BaseModel):
  223. query: str = Field(..., description="search query.")
  224. return WikipediaQueryRun(
  225. name="wikipedia",
  226. api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
  227. args_schema=WikipediaInput,
  228. callbacks=[DifyStdOutCallbackHandler()]
  229. )
  230. @classmethod
  231. def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
  232. DEFAULT_K = 2
  233. CONTEXT_TOKENS_PERCENT = 0.3
  234. if rest_tokens == -1:
  235. return DEFAULT_K
  236. processing_rule = dataset.latest_process_rule
  237. if not processing_rule:
  238. return DEFAULT_K
  239. if processing_rule.mode == "custom":
  240. rules = processing_rule.rules_dict
  241. if not rules:
  242. return DEFAULT_K
  243. segmentation = rules["segmentation"]
  244. segment_max_tokens = segmentation["max_tokens"]
  245. else:
  246. segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
  247. # when rest_tokens is less than default context tokens
  248. if rest_tokens < segment_max_tokens * DEFAULT_K:
  249. return rest_tokens // segment_max_tokens
  250. context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
  251. # when context_limit_tokens is less than default context tokens, use default_k
  252. if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
  253. return DEFAULT_K
  254. # Expand the k value when there's still some room left in the 30% rest tokens space
  255. return context_limit_tokens // segment_max_tokens