import concurrent.futures
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Union, Tuple

from flask import current_app, Flask
from requests.exceptions import ChunkedEncodingError

from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
    ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory


class Completion:
    @classmethod
    def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
                 is_override: bool = False, retriever_from: str = 'dev'):
        """
        errors: ProviderTokenNotInitError
        """
        query = PromptTemplateParser.remove_template_variables(query)

        memory = None
        if conversation:
            # get memory of conversation (read-only)
            memory = cls.get_memory_from_conversation(
                tenant_id=app.tenant_id,
                app_model_config=app_model_config,
                conversation=conversation,
                return_messages=False
            )

            inputs = conversation.inputs

        final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=app.tenant_id,
            model_config=app_model_config.model_dict,
            streaming=streaming
        )

        conversation_message_task = ConversationMessageTask(
            task_id=task_id,
            app=app,
            app_model_config=app_model_config,
            user=user,
            conversation=conversation,
            is_override=is_override,
            inputs=inputs,
            query=query,
            streaming=streaming,
            model_instance=final_model_instance
        )

        rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
            mode=app.mode,
            model_instance=final_model_instance,
            app_model_config=app_model_config,
            query=query,
            inputs=inputs
        )

        # init orchestrator rule parser
        orchestrator_rule_parser = OrchestratorRuleParser(
            tenant_id=app.tenant_id,
            app_model_config=app_model_config
        )

        try:
            chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
            try:
                # process sensitive_word_avoidance
                inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
            except ModerationException as e:
                cls.run_final_llm(
                    model_instance=final_model_instance,
                    mode=app.mode,
                    app_model_config=app_model_config,
                    query=query,
                    inputs=inputs,
                    agent_execute_result=None,
                    conversation_message_task=conversation_message_task,
                    memory=memory,
                    fake_response=str(e)
                )
                return

            # fill in variable inputs from external data tools if they exist
            external_data_tools = app_model_config.external_data_tools_list
            if external_data_tools:
                inputs = cls.fill_in_inputs_from_external_data_tools(
                    tenant_id=app.tenant_id,
                    app_id=app.id,
                    external_data_tools=external_data_tools,
                    inputs=inputs,
                    query=query
                )

            # get agent executor
            agent_executor = orchestrator_rule_parser.to_agent_executor(
                conversation_message_task=conversation_message_task,
                memory=memory,
                rest_tokens=rest_tokens_for_context_and_memory,
                chain_callback=chain_callback,
                retriever_from=retriever_from
            )

            query_for_agent = cls.get_query_for_agent(app, app_model_config, query, inputs)

            # run agent executor
            agent_execute_result = None
            if query_for_agent and agent_executor:
                should_use_agent = agent_executor.should_use_agent(query_for_agent)
                if should_use_agent:
                    agent_execute_result = agent_executor.run(query_for_agent)

            # When no extra pre prompt is specified, the agent output can be used
            # directly as the main output content without calling the LLM again
            fake_response = None
            if not app_model_config.pre_prompt and agent_execute_result and agent_execute_result.output \
                    and agent_execute_result.strategy not in [PlanningStrategy.ROUTER,
                                                              PlanningStrategy.REACT_ROUTER]:
                fake_response = agent_execute_result.output

            # run the final llm
            cls.run_final_llm(
                model_instance=final_model_instance,
                mode=app.mode,
                app_model_config=app_model_config,
                query=query,
                inputs=inputs,
                agent_execute_result=agent_execute_result,
                conversation_message_task=conversation_message_task,
                memory=memory,
                fake_response=fake_response
            )
        except (ConversationTaskInterruptException, ConversationTaskStoppedException):
            return
        except ChunkedEncodingError as e:
            # The streaming response was interrupted by the LLM provider (e.g. OpenAI);
            # end the conversation task gracefully instead of propagating the error.
            logging.warning(f'ChunkedEncodingError: {e}')
            conversation_message_task.end()
            return
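    # Illustrative only (not part of the original module): a hypothetical caller
    # of `generate`, showing the expected argument shapes. Names such as
    # `end_user` and the literal values are assumptions for this sketch.
    #
    #     Completion.generate(
    #         task_id=str(uuid.uuid4()),   # any unique task id
    #         app=app,
    #         app_model_config=app_model_config,
    #         query="Summarize this document",
    #         inputs={},                   # template variable values
    #         user=end_user,               # Account or EndUser
    #         conversation=None,           # None starts a fresh conversation
    #         streaming=True,
    #     )
    #
    # ProviderTokenNotInitError propagates to the caller, as the docstring notes.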
    @classmethod
    def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict,
                              query: str):
        if not app_model_config.sensitive_word_avoidance_dict['enabled']:
            return inputs, query

        moderation_type = app_model_config.sensitive_word_avoidance_dict['type']

        moderation = ModerationFactory(moderation_type, app_id, tenant_id,
                                       app_model_config.sensitive_word_avoidance_dict['config'])
        moderation_result = moderation.moderation_for_inputs(inputs, query)

        if not moderation_result.flagged:
            return inputs, query

        if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
            raise ModerationException(moderation_result.preset_response)
        elif moderation_result.action == ModerationAction.OVERRIDED:
            inputs = moderation_result.inputs
            query = moderation_result.query

        return inputs, query
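    # For illustration: the sensitive_word_avoidance_dict read above is assumed
    # to look roughly like this (the keys come from the lookups in the method;
    # the concrete values are hypothetical):
    #
    #     {"enabled": True, "type": "keywords", "config": {...}}
    #
    # DIRECT_OUTPUT short-circuits `generate` via ModerationException, whose
    # message becomes the fake_response; OVERRIDED replaces inputs/query and
    # lets the pipeline continue as normal.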
    @classmethod
    def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
                                                inputs: dict, query: str) -> dict:
        """
        Fill in variable inputs from external data tools if they exist.

        :param tenant_id: workspace id
        :param app_id: app id
        :param external_data_tools: external data tools configs
        :param inputs: the inputs
        :param query: the query
        :return: the filled inputs
        """
        results = {}
        with ThreadPoolExecutor() as executor:
            futures = {}
            for tool in external_data_tools:
                if not tool.get("enabled"):
                    continue

                # query each enabled tool in a worker thread; pass the concrete
                # Flask app object so the worker can push its own app context
                future = executor.submit(
                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
                    inputs, query
                )

                futures[future] = tool

            for future in concurrent.futures.as_completed(futures):
                tool_variable, result = future.result()
                results[tool_variable] = result

        inputs.update(results)
        return inputs
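    # For illustration: each entry of external_data_tools is expected to carry
    # at least the keys read by query_external_data_tool ("enabled", "variable",
    # "type", "config"); the values in this sketch are hypothetical:
    #
    #     [{"enabled": True, "variable": "weather", "type": "api", "config": {...}}]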
    @classmethod
    def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
                                 inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
        with flask_app.app_context():
            tool_variable = external_data_tool.get("variable")
            tool_type = external_data_tool.get("type")
            tool_config = external_data_tool.get("config")

            external_data_tool_factory = ExternalDataToolFactory(
                name=tool_type,
                tenant_id=tenant_id,
                app_id=app_id,
                variable=tool_variable,
                config=tool_config
            )

            # query external data tool
            result = external_data_tool_factory.query(
                inputs=inputs,
                query=query
            )

            return tool_variable, result

    @classmethod
    def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
        if app.mode != 'completion':
            return query

        return inputs.get(app_model_config.dataset_query_variable, "")

    @classmethod
    def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                      inputs: dict,
                      agent_execute_result: Optional[AgentExecuteResult],
                      conversation_message_task: ConversationMessageTask,
                      memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
                      fake_response: Optional[str]):
        prompt_transform = PromptTransform()

        # get llm prompt
        if app_model_config.prompt_type == 'simple':
            prompt_messages, stop_words = prompt_transform.get_prompt(
                mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                context=agent_execute_result.output if agent_execute_result else None,
                memory=memory,
                model_instance=model_instance
            )

            model_config = app_model_config.model_dict
            completion_params = model_config.get("completion_params", {})
            stop_words = completion_params.get("stop", [])

        cls.recale_llm_max_tokens(
            model_instance=model_instance,
            prompt_messages=prompt_messages,
        )

        response = model_instance.run(
            messages=prompt_messages,
            stop=stop_words if stop_words else None,
            callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
            fake_response=fake_response
        )

        return response
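    # Note on the two prompt paths above: the 'simple' prompt type lets
    # PromptTransform derive both the messages and the stop words, while the
    # advanced path builds the messages and reads stop words from the model
    # config, e.g. (hypothetical value):
    #
    #     {"completion_params": {"stop": ["\nHuman:"]}}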
    @classmethod
    def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                         max_token_limit: int) -> str:
        """Get memory messages."""
        memory.max_token_limit = max_token_limit
        memory_key = memory.memory_variables[0]
        external_context = memory.load_memory_variables({})
        return external_context[memory_key]

    @classmethod
    def get_memory_from_conversation(cls, tenant_id: str, app_model_config: AppModelConfig,
                                     conversation: Conversation,
                                     **kwargs) -> ReadOnlyConversationTokenDBBufferSharedMemory:
        # only used for counting tokens in memory
        memory_model_instance = ModelFactory.get_text_generation_model_from_model_config(
            tenant_id=tenant_id,
            model_config=app_model_config.model_dict
        )

        # use llm config from conversation
        memory = ReadOnlyConversationTokenDBBufferSharedMemory(
            conversation=conversation,
            model_instance=memory_model_instance,
            max_token_limit=kwargs.get("max_token_limit", 2048),
            memory_key=kwargs.get("memory_key", "chat_history"),
            return_messages=kwargs.get("return_messages", True),
            input_key=kwargs.get("input_key", "input"),
            output_key=kwargs.get("output_key", "output"),
            message_limit=kwargs.get("message_limit", 10),
        )

        return memory

    @classmethod
    def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
                                 query: str, inputs: dict) -> int:
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return -1

        if max_tokens is None:
            max_tokens = 0

        prompt_transform = PromptTransform()
        prompt_messages = []

        # get prompt without memory and context
        if app_model_config.prompt_type == 'simple':
            prompt_messages, _ = prompt_transform.get_prompt(
                mode=mode,
                pre_prompt=app_model_config.pre_prompt,
                inputs=inputs,
                query=query,
                context=None,
                memory=None,
                model_instance=model_instance
            )
        else:
            prompt_messages = prompt_transform.get_advanced_prompt(
                app_mode=mode,
                app_model_config=app_model_config,
                inputs=inputs,
                query=query,
                context=None,
                memory=None,
                model_instance=model_instance
            )

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)
        rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
        if rest_tokens < 0:
            raise LLMBadRequestError("Query or prefix prompt is too long, you can reduce the prefix prompt, "
                                     "shrink the max token, or switch to an LLM with a larger token limit.")

        return rest_tokens
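    # Worked example of the budget above (hypothetical numbers): with a model
    # limit of 4096 tokens, max_tokens=512 reserved for the completion, and a
    # prompt measuring 1024 tokens,
    #
    #     rest_tokens = 4096 - 512 - 1024 = 2560
    #
    # leaving 2560 tokens for retrieved context and conversation memory. A
    # negative result means even the bare prompt cannot fit, hence the
    # LLMBadRequestError.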
    @classmethod
    def recale_llm_max_tokens(cls, model_instance: BaseLLM, prompt_messages: List[PromptMessage]):
        # recalculate max_tokens if prompt_tokens + max_tokens exceeds the model token limit
        model_limited_tokens = model_instance.model_rules.max_tokens.max
        max_tokens = model_instance.get_model_kwargs().max_tokens

        if model_limited_tokens is None:
            return

        if max_tokens is None:
            max_tokens = 0

        prompt_tokens = model_instance.get_num_tokens(prompt_messages)

        if prompt_tokens + max_tokens > model_limited_tokens:
            max_tokens = max(model_limited_tokens - prompt_tokens, 16)

            # update model instance max tokens
            model_kwargs = model_instance.get_model_kwargs()
            model_kwargs.max_tokens = max_tokens
            model_instance.set_model_kwargs(model_kwargs)
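# Clamping example for recale_llm_max_tokens (hypothetical numbers): with a
# model limit of 4096 tokens and a 4000-token prompt, a configured max_tokens
# of 512 would overflow, so it is lowered to max(4096 - 4000, 16) = 96 before
# the final run.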