from typing import List, Optional, Tuple

from langchain.callbacks import SharedCallbackManager
from langchain.chains import SequentialChain
from langchain.chains.base import Chain
from langchain.memory.chat_memory import BaseChatMemory

from core.agent.agent_builder import AgentBuilder
from core.callback_handler.agent_loop_gather_callback_handler import AgentLoopGatherCallbackHandler
from core.callback_handler.dataset_tool_callback_handler import DatasetToolCallbackHandler
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.chain.chain_builder import ChainBuilder
from core.constant import llm_constant
from core.conversation_message_task import ConversationMessageTask
from core.tool.dataset_tool_builder import DatasetToolBuilder


class MainChainBuilder:
    """Assemble the main LangChain ``SequentialChain`` from an agent-mode config."""

    @classmethod
    def to_langchain_components(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
                                conversation_message_task: ConversationMessageTask) -> Optional[SequentialChain]:
        """Build the overall sequential chain for one conversation message.

        Args:
            tenant_id: Tenant whose datasets back any dataset tools.
            agent_mode: Agent-mode config; only used when its ``enabled`` key is truthy.
            memory: Optional chat memory, attached to the sequential chain and
                forwarded to the agent chain when one is built.
            conversation_message_task: Task used to construct the gather callback handlers.

        Returns:
            A ``SequentialChain`` taking ``"input"`` and producing the last
            sub-chain's first output key (``"output"`` if none could be
            determined), or ``None`` when agent mode yields no chains.
        """
        first_input_key = "input"
        chain_callback_handler = MainChainGatherCallbackHandler(conversation_message_task)

        # agent mode
        chains, chains_output_key = cls.get_agent_chains(
            tenant_id=tenant_id,
            agent_mode=agent_mode,
            memory=memory,
            dataset_tool_callback_handler=DatasetToolCallbackHandler(conversation_message_task),
            agent_loop_gather_callback_handler=chain_callback_handler.agent_loop_gather_callback_handler
        )

        if not chains:
            return None

        # Fall back to the default "output" key when the last chain exposes none.
        final_output_key = chains_output_key or "output"

        for chain in chains:
            # Do not add the handler to the singleton (shared) callback manager,
            # otherwise it would leak into every other chain using it.
            if not isinstance(chain.callback_manager, SharedCallbackManager):
                chain.callback_manager.add_handler(chain_callback_handler)

        # build main chain
        return SequentialChain(
            chains=chains,
            input_variables=[first_input_key],
            output_variables=[final_output_key],
            memory=memory,  # only used so the memory's prompt input key is available
        )

    @classmethod
    def get_agent_chains(cls, tenant_id: str, agent_mode: dict, memory: Optional[BaseChatMemory],
                         dataset_tool_callback_handler: DatasetToolCallbackHandler,
                         agent_loop_gather_callback_handler: AgentLoopGatherCallbackHandler
                         ) -> Tuple[List[Chain], Optional[str]]:
        """Translate the agent-mode tool list into an ordered list of chains.

        Sensitive-word-avoidance entries become pre-fixed chains that run first.
        Dataset entries become tools: a single tool is wrapped directly as a
        tool chain (output key ``'tool_output'``), several tools are combined
        into one agent chain.

        Args:
            tenant_id: Tenant whose datasets back any dataset tools.
            agent_mode: Agent-mode config; ignored unless ``enabled`` is truthy.
            memory: Optional chat memory forwarded to the agent chain.
            dataset_tool_callback_handler: Handler attached to dataset tools.
            agent_loop_gather_callback_handler: Handler attached to the agent loop.

        Returns:
            ``(chains, output_key)`` where ``output_key`` is the first output key
            of the last chain, or ``None`` when ``chains`` is empty.
        """
        chains: List[Chain] = []
        if not agent_mode or not agent_mode.get('enabled'):
            return chains, cls.get_chains_output_key(chains)

        pre_fixed_chains = []
        agent_tools = []
        # `or []` also covers an explicit `"tools": None` in the config.
        for tool in agent_mode.get('tools') or []:
            # Skip empty entries instead of failing on a missing tool type.
            if not tool:
                continue
            # Only the first key/value pair of each entry is used: {tool_type: tool_config}.
            tool_type, tool_config = next(iter(tool.items()))

            if tool_type == 'sensitive-word-avoidance':
                chain = ChainBuilder.to_sensitive_word_avoidance_chain(tool_config)
                if chain:
                    pre_fixed_chains.append(chain)
            elif tool_type == "dataset":
                dataset_tool = DatasetToolBuilder.build_dataset_tool(
                    tenant_id=tenant_id,
                    dataset_id=tool_config.get("id"),
                    response_mode='no_synthesizer',  # "compact"
                    callback_handler=dataset_tool_callback_handler
                )
                if dataset_tool:
                    agent_tools.append(dataset_tool)

        # pre-fixed chains always run before any tool/agent chain
        chains += pre_fixed_chains

        if len(agent_tools) == 1:
            # A single tool needs no agent: wrap it directly as a chain.
            chains.append(ChainBuilder.to_tool_chain(tool=agent_tools[0], output_key='tool_output'))
        elif len(agent_tools) > 1:
            # Several tools: let an agent choose among them.
            chains.append(AgentBuilder.to_agent_chain(
                tenant_id=tenant_id,
                tools=agent_tools,
                memory=memory,
                dataset_tool_callback_handler=dataset_tool_callback_handler,
                agent_loop_gather_callback_handler=agent_loop_gather_callback_handler
            ))

        return chains, cls.get_chains_output_key(chains)

    @classmethod
    def get_chains_output_key(cls, chains: List[Chain]) -> Optional[str]:
        """Return the first output key of the last chain, or ``None`` if ``chains`` is empty."""
        if chains:
            return chains[-1].output_keys[0]
        return None