import enum
import json
import os
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast

from core.app.app_config.entities import PromptTemplateEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import file_manager
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageContent,
    SystemPromptMessage,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from models.model import AppMode

if TYPE_CHECKING:
    from core.file.models import File


class ModelMode(enum.StrEnum):
    """Model invocation mode; selects how SimplePromptTransform assembles the prompt."""

    COMPLETION = "completion"
    CHAT = "chat"

    @classmethod
    def value_of(cls, value: str) -> "ModelMode":
        """
        Get value of given mode.

        :param value: mode value
        :return: mode
        :raises ValueError: if ``value`` matches no member
        """
        for mode in cls:
            if mode.value == value:
                return mode
        raise ValueError(f"invalid mode value {value}")


# Module-level cache of parsed prompt-rule JSON files, keyed by file name
# (without the ".json" suffix). Filled lazily by SimplePromptTransform._get_prompt_rule.
prompt_file_contents: dict[str, Any] = {}


class SimplePromptTransform(PromptTransform):
    """
    Simple Prompt Transform for Chatbot App Basic Mode.
    """

    def get_prompt(
        self,
        app_mode: AppMode,
        prompt_template_entity: PromptTemplateEntity,
        inputs: Mapping[str, str],
        query: str,
        files: Sequence["File"],
        context: Optional[str],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Build the prompt messages for a basic-mode app.

        :param app_mode: app mode; used to choose the prompt-rule file
        :param prompt_template_entity: supplies ``simple_prompt_template`` (the pre-prompt)
        :param inputs: template variables; every value is coerced to ``str``
        :param query: the end user's query
        :param files: files attached to the final user message
        :param context: optional text substituted for ``#context#``
        :param memory: optional conversation memory used to add history
        :param model_config: model configuration; ``mode`` selects chat vs. completion layout
        :param image_detail_config: optional detail level passed to file-to-content conversion
        :return: (prompt messages, stop words or None)
        :raises ValueError: if ``model_config.mode`` is not a known ModelMode value
        """
        inputs = {key: str(value) for key, value in inputs.items()}
        model_mode = ModelMode.value_of(model_config.mode)

        # Both builders take identical arguments; only the message layout differs.
        if model_mode == ModelMode.CHAT:
            build_messages = self._get_chat_model_prompt_messages
        else:
            build_messages = self._get_completion_model_prompt_messages

        return build_messages(
            app_mode=app_mode,
            pre_prompt=prompt_template_entity.simple_prompt_template or "",
            inputs=inputs,
            query=query,
            files=files,
            context=context,
            memory=memory,
            model_config=model_config,
            image_detail_config=image_detail_config,
        )

    def _get_prompt_str_and_rules(
        self,
        app_mode: AppMode,
        model_config: ModelConfigWithCredentialsEntity,
        pre_prompt: str,
        inputs: dict,
        query: Optional[str] = None,
        context: Optional[str] = None,
        histories: Optional[str] = None,
    ) -> tuple[str, dict]:
        """
        Render the prompt template for this model and return it with its prompt rules.

        Passing ``None`` for query/context/histories omits that section from the
        template entirely; passing an empty string keeps the section with empty text.

        :return: (rendered prompt text, prompt-rule dict loaded from JSON)
        """
        # get prompt template
        prompt_template_config = self.get_prompt_template(
            app_mode=app_mode,
            provider=model_config.provider,
            model=model_config.model,
            pre_prompt=pre_prompt,
            has_context=context is not None,
            query_in_prompt=query is not None,
            with_memory_prompt=histories is not None,
        )

        # Only user variables actually referenced by the pre-prompt are passed through.
        variables = {k: inputs[k] for k in prompt_template_config["custom_variable_keys"] if k in inputs}

        for v in prompt_template_config["special_variable_keys"]:
            # support #context#, #query# and #histories#
            if v == "#context#":
                variables["#context#"] = context or ""
            elif v == "#query#":
                variables["#query#"] = query or ""
            elif v == "#histories#":
                variables["#histories#"] = histories or ""

        prompt_template = prompt_template_config["prompt_template"]
        prompt = prompt_template.format(variables)

        return prompt, prompt_template_config["prompt_rules"]

    def get_prompt_template(
        self,
        app_mode: AppMode,
        provider: str,
        model: str,
        pre_prompt: str,
        has_context: bool,
        query_in_prompt: bool,
        with_memory_prompt: bool = False,
    ) -> dict:
        """
        Assemble the raw prompt template from the prompt rules for this app/model.

        Sections are appended in the order given by ``system_prompt_orders``; the
        query section (if requested) always comes last.

        :return: dict with ``prompt_template`` (PromptTemplateParser),
                 ``custom_variable_keys`` (variables parsed from the pre-prompt),
                 ``special_variable_keys`` (``#context#``/``#histories#``/``#query#`` in use),
                 and ``prompt_rules`` (the loaded rule dict)
        """
        prompt_rules = self._get_prompt_rule(app_mode=app_mode, provider=provider, model=model)

        custom_variable_keys = []
        special_variable_keys = []

        prompt = ""
        for order in prompt_rules["system_prompt_orders"]:
            if order == "context_prompt" and has_context:
                prompt += prompt_rules["context_prompt"]
                special_variable_keys.append("#context#")
            elif order == "pre_prompt" and pre_prompt:
                prompt += pre_prompt + "\n"
                pre_prompt_template = PromptTemplateParser(template=pre_prompt)
                custom_variable_keys = pre_prompt_template.variable_keys
            elif order == "histories_prompt" and with_memory_prompt:
                prompt += prompt_rules["histories_prompt"]
                special_variable_keys.append("#histories#")

        if query_in_prompt:
            # Rule files may omit "query_prompt"; fall back to the bare query placeholder.
            prompt += prompt_rules.get("query_prompt", "{{#query#}}")
            special_variable_keys.append("#query#")

        return {
            "prompt_template": PromptTemplateParser(template=prompt),
            "custom_variable_keys": custom_variable_keys,
            "special_variable_keys": special_variable_keys,
            "prompt_rules": prompt_rules,
        }

    def _get_chat_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: Optional[str],
        files: Sequence["File"],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Build a chat-style message list: [system prompt], [history...], user message.

        With a query, the rendered prompt becomes a system message and the query the
        final user message. Without a query, no system message is added and the
        rendered prompt itself is sent as the user message.

        :return: (prompt messages, None) -- chat mode never sets stop words
        """
        prompt_messages: list[PromptMessage] = []

        # get prompt; the query is deliberately not rendered into the template here
        prompt, _ = self._get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=None,
            context=context,
        )

        if prompt and query:
            prompt_messages.append(SystemPromptMessage(content=prompt))

        if memory:
            # History handling is delegated to the PromptTransform base class helper.
            prompt_messages = self._append_chat_histories(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                prompt_messages=prompt_messages,
                model_config=model_config,
            )

        if query:
            prompt_messages.append(self._get_last_user_message(query, files, image_detail_config))
        else:
            prompt_messages.append(self._get_last_user_message(prompt, files, image_detail_config))

        return prompt_messages, None

    def _get_completion_model_prompt_messages(
        self,
        app_mode: AppMode,
        pre_prompt: str,
        inputs: dict,
        query: str,
        context: Optional[str],
        files: Sequence["File"],
        memory: Optional[TokenBufferMemory],
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
    ) -> tuple[list[PromptMessage], Optional[list[str]]]:
        """
        Build a single user message containing the fully rendered completion prompt.

        :return: ([user message], stop words from the prompt rules, or None if absent/empty)
        """
        # get prompt
        prompt, prompt_rules = self._get_prompt_str_and_rules(
            app_mode=app_mode,
            model_config=model_config,
            pre_prompt=pre_prompt,
            inputs=inputs,
            query=query,
            context=context,
        )

        if memory:
            # Render once without histories to measure the remaining token budget,
            # fetch histories within that budget, then re-render with them included.
            tmp_human_message = UserPromptMessage(content=prompt)

            rest_tokens = self._calculate_rest_token([tmp_human_message], model_config)
            histories = self._get_history_messages_from_memory(
                memory=memory,
                memory_config=MemoryConfig(
                    window=MemoryConfig.WindowConfig(
                        enabled=False,
                    )
                ),
                max_token_limit=rest_tokens,
                human_prefix=prompt_rules.get("human_prefix", "Human"),
                ai_prefix=prompt_rules.get("assistant_prefix", "Assistant"),
            )

            # get prompt
            prompt, prompt_rules = self._get_prompt_str_and_rules(
                app_mode=app_mode,
                model_config=model_config,
                pre_prompt=pre_prompt,
                inputs=inputs,
                query=query,
                context=context,
                histories=histories,
            )

        # An empty stop list is normalized to None.
        stops = prompt_rules.get("stops")
        if stops is not None and len(stops) == 0:
            stops = None

        return [self._get_last_user_message(prompt, files, image_detail_config)], stops

    def _get_last_user_message(
        self,
        prompt: str,
        files: Sequence["File"],
        image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
    ) -> UserPromptMessage:
        """
        Wrap ``prompt`` in a user message, adding one content part per attached file.

        Without files the message content is the plain string; with files it is a
        list whose first element is the text followed by the converted files.
        """
        if files:
            prompt_message_contents: list[PromptMessageContent] = []
            prompt_message_contents.append(TextPromptMessageContent(data=prompt))
            for file in files:
                prompt_message_contents.append(
                    file_manager.to_prompt_message_content(file, image_detail_config=image_detail_config)
                )

            prompt_message = UserPromptMessage(content=prompt_message_contents)
        else:
            prompt_message = UserPromptMessage(content=prompt)

        return prompt_message

    def _get_prompt_rule(self, app_mode: AppMode, provider: str, model: str) -> dict:
        """
        Get simple prompt rule.

        Loads ``prompt_templates/<name>.json`` from this module's directory and
        caches the parsed result in ``prompt_file_contents``.

        NOTE: the returned dict is the shared cached object; callers must not mutate it.

        :param app_mode: app mode
        :param provider: model provider
        :param model: model name
        :return: parsed prompt-rule dict
        """
        prompt_file_name = self._prompt_file_name(app_mode=app_mode, provider=provider, model=model)

        # Check if the prompt file is already loaded
        if prompt_file_name in prompt_file_contents:
            return cast(dict, prompt_file_contents[prompt_file_name])

        # Get the absolute path of the subdirectory
        prompt_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "prompt_templates")
        json_file_path = os.path.join(prompt_path, f"{prompt_file_name}.json")

        # Open the JSON file and read its content
        with open(json_file_path, encoding="utf-8") as json_file:
            content = json.load(json_file)

        # Store the content of the prompt file so later calls skip the disk read
        prompt_file_contents[prompt_file_name] = content
        return cast(dict, content)

    def _prompt_file_name(self, app_mode: AppMode, provider: str, model: str) -> str:
        """
        Choose the prompt-rule file name (without ".json") for this app/model.

        Baichuan models get dedicated rules: either the "baichuan" provider itself,
        or a Baichuan model served through a generic hosting provider.

        :return: one of "baichuan_completion", "baichuan_chat",
                 "common_completion", "common_chat"
        """
        baichuan_supported_providers = ("huggingface_hub", "openllm", "xinference")
        is_baichuan = provider == "baichuan" or (
            provider in baichuan_supported_providers and "baichuan" in model.lower()
        )

        family = "baichuan" if is_baichuan else "common"
        layout = "completion" if app_mode == AppMode.COMPLETION else "chat"
        return f"{family}_{layout}"