Browse Source

Feat/optimize chat prompt (#158)

John Wang 1 year ago
parent
commit
90150a6ca9
1 changed files with 38 additions and 33 deletions
  1. 38 33
      api/core/completion.py

+ 38 - 33
api/core/completion.py

@@ -39,7 +39,8 @@ class Completion:
             memory = cls.get_memory_from_conversation(
                 tenant_id=app.tenant_id,
                 app_model_config=app_model_config,
-                conversation=conversation
+                conversation=conversation,
+                return_messages=False
             )
 
             inputs = conversation.inputs
@@ -119,7 +120,8 @@ class Completion:
         return response
 
     @classmethod
-    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict, chain_output: Optional[str],
+    def get_main_llm_prompt(cls, mode: str, llm: BaseLanguageModel, pre_prompt: str, query: str, inputs: dict,
+                            chain_output: Optional[str],
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Union[str | List[BaseMessage]]:
         pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
@@ -161,11 +163,19 @@ And answer according to the language of the user's question.
                 "query": query
             }
 
-            human_message_prompt = "{query}"
+            human_message_prompt = ""
+
+            if pre_prompt:
+                pre_prompt_inputs = {k: inputs[k] for k in
+                                     OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                     if k in inputs}
+
+                if pre_prompt_inputs:
+                    human_inputs.update(pre_prompt_inputs)
 
             if chain_output:
                 human_inputs['context'] = chain_output
-                human_message_instruction = """Use the following CONTEXT as your learned knowledge.
+                human_message_prompt += """Use the following CONTEXT as your learned knowledge.
 [CONTEXT]
 {context}
 [END CONTEXT]
@@ -176,39 +186,33 @@ When answer to user:
 Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """
-                if pre_prompt:
-                    extra_inputs = {k: inputs[k] for k in
-                                    OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                    if k in inputs}
-                    if extra_inputs:
-                        human_inputs.update(extra_inputs)
-                    human_message_instruction += pre_prompt + "\n"
-
-                human_message_prompt = human_message_instruction + "Q:{query}\nA:"
-            else:
-                if pre_prompt:
-                    extra_inputs = {k: inputs[k] for k in
-                                    OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
-                                    if k in inputs}
-                    if extra_inputs:
-                        human_inputs.update(extra_inputs)
-                    human_message_prompt = pre_prompt + "\n" + human_message_prompt
 
-            # construct main prompt
-            human_message = PromptBuilder.to_human_message(
-                prompt_content=human_message_prompt,
-                inputs=human_inputs
-            )
+            if pre_prompt:
+                human_message_prompt += pre_prompt
+
+            query_prompt = "\nHuman: {query}\nAI: "
 
             if memory:
                 # append chat histories
-                tmp_messages = messages.copy() + [human_message]
-                curr_message_tokens = memory.llm.get_messages_tokens(tmp_messages)
-                rest_tokens = llm_constant.max_context_token_length[
-                                  memory.llm.model_name] - memory.llm.max_tokens - curr_message_tokens
+                tmp_human_message = PromptBuilder.to_human_message(
+                    prompt_content=human_message_prompt + query_prompt,
+                    inputs=human_inputs
+                )
+
+                curr_message_tokens = memory.llm.get_messages_tokens([tmp_human_message])
+                rest_tokens = llm_constant.max_context_token_length[memory.llm.model_name] \
+                              - memory.llm.max_tokens - curr_message_tokens
                 rest_tokens = max(rest_tokens, 0)
                 history_messages = cls.get_history_messages_from_memory(memory, rest_tokens)
-                messages += history_messages
+                human_message_prompt += "\n\n" + history_messages
+
+            human_message_prompt += query_prompt
+
+            # construct main prompt
+            human_message = PromptBuilder.to_human_message(
+                prompt_content=human_message_prompt,
+                inputs=human_inputs
+            )
 
             messages.append(human_message)
 
@@ -216,7 +220,8 @@ And answer according to the language of the user's question.
 
     @classmethod
     def get_llm_callback_manager(cls, llm: Union[StreamableOpenAI, StreamableChatOpenAI],
-                                 streaming: bool, conversation_message_task: ConversationMessageTask) -> CallbackManager:
+                                 streaming: bool,
+                                 conversation_message_task: ConversationMessageTask) -> CallbackManager:
         llm_callback_handler = LLMCallbackHandler(llm, conversation_message_task)
         if streaming:
             callback_handlers = [llm_callback_handler, DifyStreamingStdOutCallbackHandler()]
@@ -228,7 +233,7 @@ And answer according to the language of the user's question.
     @classmethod
     def get_history_messages_from_memory(cls, memory: ReadOnlyConversationTokenDBBufferSharedMemory,
                                          max_token_limit: int) -> \
-            List[BaseMessage]:
+            str:
         """Get memory messages."""
         memory.max_token_limit = max_token_limit
         memory_key = memory.memory_variables[0]