Explorar o código

feat: add baichuan prompt (#985)

takatost hai 1 ano
pai
achega
2c30d19cbe

+ 10 - 120
api/core/completion.py

@@ -130,13 +130,12 @@ class Completion:
             fake_response = agent_execute_result.output
 
         # get llm prompt
-        prompt_messages, stop_words = cls.get_main_llm_prompt(
+        prompt_messages, stop_words = model_instance.get_prompt(
             mode=mode,
-            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
-            query=query,
             inputs=inputs,
-            agent_execute_result=agent_execute_result,
+            query=query,
+            context=agent_execute_result.output if agent_execute_result else None,
             memory=memory
         )
 
@@ -155,113 +154,6 @@ class Completion:
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, model: dict,
-                            pre_prompt: str, query: str, inputs: dict,
-                            agent_execute_result: Optional[AgentExecuteResult],
-                            memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
-            Tuple[List[PromptMessage], Optional[List[str]]]:
-        if mode == 'completion':
-            prompt_template = JinjaPromptTemplate.from_template(
-                template=("""Use the following context as your learned knowledge, inside <context></context> XML tags.
-
-<context>
-{{context}}
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification. 
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-""" if agent_execute_result else "")
-                         + (pre_prompt + "\n" if pre_prompt else "")
-                         + "{{query}}\n"
-            )
-
-            if agent_execute_result:
-                inputs['context'] = agent_execute_result.output
-
-            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
-            prompt_content = prompt_template.format(
-                query=query,
-                **prompt_inputs
-            )
-
-            return [PromptMessage(content=prompt_content)], None
-        else:
-            messages: List[BaseMessage] = []
-
-            human_inputs = {
-                "query": query
-            }
-
-            human_message_prompt = ""
-
-            if pre_prompt:
-                pre_prompt_inputs = {k: inputs[k] for k in
-                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
-                                     if k in inputs}
-
-                if pre_prompt_inputs:
-                    human_inputs.update(pre_prompt_inputs)
-
-            if agent_execute_result:
-                human_inputs['context'] = agent_execute_result.output
-                human_message_prompt += """Use the following context as your learned knowledge, inside <context></context> XML tags.
-
-<context>
-{{context}}
-</context>
-
-When answer to user:
-- If you don't know, just say that you don't know.
-- If you don't know when you are not sure, ask for clarification. 
-Avoid mentioning that you obtained the information from the context.
-And answer according to the language of the user's question.
-"""
-
-            if pre_prompt:
-                human_message_prompt += pre_prompt
-
-            query_prompt = "\n\nHuman: {{query}}\n\nAssistant: "
-
-            if memory:
-                # append chat histories
-                tmp_human_message = PromptBuilder.to_human_message(
-                    prompt_content=human_message_prompt + query_prompt,
-                    inputs=human_inputs
-                )
-
-                if memory.model_instance.model_rules.max_tokens.max:
-                    curr_message_tokens = memory.model_instance.get_num_tokens(to_prompt_messages([tmp_human_message]))
-                    max_tokens = model.get("completion_params").get('max_tokens')
-                    rest_tokens = memory.model_instance.model_rules.max_tokens.max - max_tokens - curr_message_tokens
-                    rest_tokens = max(rest_tokens, 0)
-                else:
-                    rest_tokens = 2000
-
-                histories = cls.get_history_messages_from_memory(memory, rest_tokens)
-                human_message_prompt += "\n\n" if human_message_prompt else ""
-                human_message_prompt += "Here is the chat histories between human and assistant, " \
-                                        "inside <histories></histories> XML tags.\n\n<histories>\n"
-                human_message_prompt += histories + "\n</histories>"
-
-            human_message_prompt += query_prompt
-
-            # construct main prompt
-            human_message = PromptBuilder.to_human_message(
-                prompt_content=human_message_prompt,
-                inputs=human_inputs
-            )
-
-            messages.append(human_message)
-
-            for message in messages:
-                message.content = re.sub(r'<\|.*?\|>', '', message.content)
-
-            return to_prompt_messages(messages), ['\nHuman:', '</histories>']
-
-    @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> str:
         """Get memory messages."""
@@ -307,13 +199,12 @@ And answer according to the language of the user's question.
             max_tokens = 0
 
         # get prompt without memory and context
-        prompt_messages, _ = cls.get_main_llm_prompt(
+        prompt_messages, _ = model_instance.get_prompt(
             mode=mode,
-            model=app_model_config.model_dict,
             pre_prompt=app_model_config.pre_prompt,
-            query=query,
             inputs=inputs,
-            agent_execute_result=None,
+            query=query,
+            context=None,
             memory=None
         )
 
@@ -358,13 +249,12 @@ And answer according to the language of the user's question.
         )
 
         # get llm prompt
-        old_prompt_messages, _ = cls.get_main_llm_prompt(
-            mode="completion",
-            model=app_model_config.model_dict,
+        old_prompt_messages, _ = final_model_instance.get_prompt(
+            mode='completion',
             pre_prompt=pre_prompt,
-            query=message.query,
             inputs=message.inputs,
-            agent_execute_result=None,
+            query=message.query,
+            context=None,
             memory=None
         )
 

+ 132 - 10
api/core/model_providers/models/llm/base.py

@@ -1,17 +1,24 @@
+import json
+import os
+import re
 from abc import abstractmethod
-from typing import List, Optional, Any, Union
+from typing import List, Optional, Any, Union, Tuple
 import decimal
 
 from langchain.callbacks.manager import Callbacks
+from langchain.memory.chat_memory import BaseChatMemory
 from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
 
 from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
 from core.model_providers.models.base import BaseProviderModel
-from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult
+from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
 from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
 from core.model_providers.providers.base import BaseModelProvider
+from core.prompt.prompt_builder import PromptBuilder
+from core.prompt.prompt_template import JinjaPromptTemplate
 from core.third_party.langchain.llms.fake import FakeLLM
 import logging
+
 logger = logging.getLogger(__name__)
 
 
@@ -76,13 +83,14 @@ class BaseLLM(BaseProviderModel):
     def price_config(self) -> dict:
         def get_or_default():
             default_price_config = {
-                    'prompt': decimal.Decimal('0'),
-                    'completion': decimal.Decimal('0'),
-                    'unit': decimal.Decimal('0'),
-                    'currency': 'USD'
-                }
+                'prompt': decimal.Decimal('0'),
+                'completion': decimal.Decimal('0'),
+                'unit': decimal.Decimal('0'),
+                'currency': 'USD'
+            }
             rules = self.model_provider.get_rules()
-            price_config = rules['price_config'][self.base_model_name] if 'price_config' in rules else default_price_config
+            price_config = rules['price_config'][
+                self.base_model_name] if 'price_config' in rules else default_price_config
             price_config = {
                 'prompt': decimal.Decimal(price_config['prompt']),
                 'completion': decimal.Decimal(price_config['completion']),
@@ -90,7 +98,7 @@ class BaseLLM(BaseProviderModel):
                 'currency': price_config['currency']
             }
             return price_config
-        
+
         self._price_config = self._price_config if hasattr(self, '_price_config') else get_or_default()
 
         logger.debug(f"model: {self.name} price_config: {self._price_config}")
@@ -158,7 +166,8 @@ class BaseLLM(BaseProviderModel):
             total_tokens = result.llm_output['token_usage']['total_tokens']
         else:
             prompt_tokens = self.get_num_tokens(messages)
-            completion_tokens = self.get_num_tokens([PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
+            completion_tokens = self.get_num_tokens(
+                [PromptMessage(content=completion_content, type=MessageType.ASSISTANT)])
             total_tokens = prompt_tokens + completion_tokens
 
         self.model_provider.update_last_used()
@@ -293,6 +302,119 @@ class BaseLLM(BaseProviderModel):
     def support_streaming(cls):
         return False
 
+    def get_prompt(self, mode: str,
+                   pre_prompt: str, inputs: dict,
+                   query: str,
+                   context: Optional[str],
+                   memory: Optional[BaseChatMemory]) -> \
+            Tuple[List[PromptMessage], Optional[List[str]]]:
+        prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
+        prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
+        return [PromptMessage(content=prompt)], stops
+
+    def prompt_file_name(self, mode: str) -> str:
+        if mode == 'completion':
+            return 'common_completion'
+        else:
+            return 'common_chat'
+
+    def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
+                             query: str,
+                             context: Optional[str],
+                             memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
+        context_prompt_content = ''
+        if context and 'context_prompt' in prompt_rules:
+            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
+            context_prompt_content = prompt_template.format(
+                context=context
+            )
+
+        pre_prompt_content = ''
+        if pre_prompt:
+            prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
+            prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
+            pre_prompt_content = prompt_template.format(
+                **prompt_inputs
+            )
+
+        prompt = ''
+        for order in prompt_rules['system_prompt_orders']:
+            if order == 'context_prompt':
+                prompt += context_prompt_content
+            elif order == 'pre_prompt':
+                prompt += (pre_prompt_content + '\n\n') if pre_prompt_content else ''
+
+        query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
+
+        if memory and 'histories_prompt' in prompt_rules:
+            # append chat histories
+            tmp_human_message = PromptBuilder.to_human_message(
+                prompt_content=prompt + query_prompt,
+                inputs={
+                    'query': query
+                }
+            )
+
+            if self.model_rules.max_tokens.max:
+                curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
+                max_tokens = self.model_kwargs.max_tokens
+                rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
+                rest_tokens = max(rest_tokens, 0)
+            else:
+                rest_tokens = 2000
+
+            memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
+            memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
+
+            histories = self._get_history_messages_from_memory(memory, rest_tokens)
+            prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
+            histories_prompt_content = prompt_template.format(
+                histories=histories
+            )
+
+            prompt = ''
+            for order in prompt_rules['system_prompt_orders']:
+                if order == 'context_prompt':
+                    prompt += context_prompt_content
+                elif order == 'pre_prompt':
+                    prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
+                elif order == 'histories_prompt':
+                    prompt += histories_prompt_content
+
+        prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
+        query_prompt_content = prompt_template.format(
+            query=query
+        )
+
+        prompt += query_prompt_content
+
+        prompt = re.sub(r'<\|.*?\|>', '', prompt)
+
+        stops = prompt_rules.get('stops')
+        if stops is not None and len(stops) == 0:
+            stops = None
+
+        return prompt, stops
+
+    def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
+        # Get the absolute path of the subdirectory
+        prompt_path = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
+            'prompt/generate_prompts')
+
+        json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
+        # Open the JSON file and read its content
+        with open(json_file_path, 'r') as json_file:
+            return json.load(json_file)
+
+    def _get_history_messages_from_memory(self, memory: BaseChatMemory,
+                                          max_token_limit: int) -> str:
+        """Get memory messages."""
+        memory.max_token_limit = max_token_limit
+        memory_key = memory.memory_variables[0]
+        external_context = memory.load_memory_variables({})
+        return external_context[memory_key]
+
     def _get_prompt_from_messages(self, messages: List[PromptMessage],
                                   model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
         if not model_mode:

+ 9 - 0
api/core/model_providers/models/llm/huggingface_hub_model.py

@@ -60,6 +60,15 @@ class HuggingfaceHubModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return self._client.get_num_tokens(prompts)
 
+    def prompt_file_name(self, mode: str) -> str:
+        if 'baichuan' in self.name.lower():
+            if mode == 'completion':
+                return 'baichuan_completion'
+            else:
+                return 'baichuan_chat'
+        else:
+            return super().prompt_file_name(mode)
+
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
         self.client.model_kwargs = provider_model_kwargs

+ 9 - 0
api/core/model_providers/models/llm/openllm_model.py

@@ -49,6 +49,15 @@ class OpenLLMModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)
 
+    def prompt_file_name(self, mode: str) -> str:
+        if 'baichuan' in self.name.lower():
+            if mode == 'completion':
+                return 'baichuan_completion'
+            else:
+                return 'baichuan_chat'
+        else:
+            return super().prompt_file_name(mode)
+
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         pass
 

+ 9 - 0
api/core/model_providers/models/llm/xinference_model.py

@@ -59,6 +59,15 @@ class XinferenceModel(BaseLLM):
         prompts = self._get_prompt_from_messages(messages)
         return max(self._client.get_num_tokens(prompts), 0)
 
+    def prompt_file_name(self, mode: str) -> str:
+        if 'baichuan' in self.name.lower():
+            if mode == 'completion':
+                return 'baichuan_completion'
+            else:
+                return 'baichuan_chat'
+        else:
+            return super().prompt_file_name(mode)
+
     def _set_model_kwargs(self, model_kwargs: ModelKwargs):
         pass
 

+ 13 - 0
api/core/prompt/generate_prompts/baichuan_chat.json

@@ -0,0 +1,13 @@
+{
+  "human_prefix": "用户",
+  "assistant_prefix": "助手",
+  "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{{context}}\n```\n\n",
+  "histories_prompt": "用户和助手的历史对话内容如下:\n```\n{{histories}}\n```\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt",
+    "histories_prompt"
+  ],
+  "query_prompt": "用户:{{query}}\n助手:",
+  "stops": ["用户:"]
+}

+ 9 - 0
api/core/prompt/generate_prompts/baichuan_completion.json

@@ -0,0 +1,9 @@
+{
+  "context_prompt": "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n引用材料\n{{context}}\n```\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt"
+  ],
+  "query_prompt": "{{query}}",
+  "stops": null
+}

+ 13 - 0
api/core/prompt/generate_prompts/common_chat.json

@@ -0,0 +1,13 @@
+{
+  "human_prefix": "Human",
+  "assistant_prefix": "Assistant",
+  "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+  "histories_prompt": "Here is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{histories}}\n</histories>\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt",
+    "histories_prompt"
+  ],
+  "query_prompt": "Human: {{query}}\n\nAssistant: ",
+  "stops": ["\nHuman:", "</histories>"]
+}

+ 9 - 0
api/core/prompt/generate_prompts/common_completion.json

@@ -0,0 +1,9 @@
+{
+  "context_prompt": "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{context}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n\n",
+  "system_prompt_orders": [
+    "context_prompt",
+    "pre_prompt"
+  ],
+  "query_prompt": "{{query}}",
+  "stops": null
+}