feat: optimize template parse (#460)

John Wang 1 year ago
commit c720f831af

+ 0 - 4
api/core/__init__.py

@@ -3,7 +3,6 @@ from typing import Optional

 import langchain
 from flask import Flask
-from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING
 from pydantic import BaseModel

 from core.callback_handler.std_out_callback_handler import DifyStdOutCallbackHandler
@@ -22,9 +21,6 @@ hosted_llm_credentials = HostedLLMCredentials()


 def init_app(app: Flask):
-    formatter = OneLineFormatter()
-    DEFAULT_FORMATTER_MAPPING['f-string'] = formatter.format
-
     if os.environ.get("DEBUG") and os.environ.get("DEBUG").lower() == 'true':
         langchain.verbose = True
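
The lines removed here were a process-wide monkey-patch. A minimal sketch of the old behavior, assuming langchain 0.0.x, where `DEFAULT_FORMATTER_MAPPING` is a plain dict of format name to formatter callable:

```python
# Sketch of what init_app() used to do (removed above). Replacing the
# 'f-string' entry changed template parsing for every PromptTemplate in
# the process, a global side effect the Jinja2 switch makes unnecessary.
from langchain.prompts.base import DEFAULT_FORMATTER_MAPPING

from core.prompt.prompt_template import OneLineFormatter

DEFAULT_FORMATTER_MAPPING['f-string'] = OneLineFormatter().format
```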
 
 

+ 24 - 23
api/core/completion.py

@@ -23,7 +23,7 @@ from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
 from core.memory.read_only_conversation_token_db_string_buffer_shared_memory import \
     ReadOnlyConversationTokenDBStringBufferSharedMemory
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate
 from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
 from models.model import App, AppModelConfig, Account, Conversation, Message

@@ -35,6 +35,8 @@ class Completion:
         """
         errors: ProviderTokenNotInitError
         """
+        query = PromptBuilder.process_template(query)
+
         memory = None
         if conversation:
             # get memory of conversation (read-only)
@@ -141,18 +143,17 @@
                             memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory]) -> \
             Tuple[Union[str | List[BaseMessage]], Optional[List[str]]]:
         # disable template string in query
-        query_params = OutLinePromptTemplate.from_template(template=query).input_variables
-        if query_params:
-            for query_param in query_params:
-                if query_param not in inputs:
-                    inputs[query_param] = '{' + query_param + '}'
+        # query_params = JinjaPromptTemplate.from_template(template=query).input_variables
+        # if query_params:
+        #     for query_param in query_params:
+        #         if query_param not in inputs:
+        #             inputs[query_param] = '{{' + query_param + '}}'

-        pre_prompt = PromptBuilder.process_template(pre_prompt) if pre_prompt else pre_prompt
         if mode == 'completion':
-            prompt_template = OutLinePromptTemplate.from_template(
+            prompt_template = JinjaPromptTemplate.from_template(
                 template=("""Use the following CONTEXT as your learned knowledge:
[CONTEXT]
-{context}
+{{context}}
[END CONTEXT]

When answer to user:
@@ -162,16 +163,16 @@ Avoid mentioning that you obtained the information from the context.
 And answer according to the language of the user's question.
 """ if chain_output else "")
                          + (pre_prompt + "\n" if pre_prompt else "")
-                         + "{query}\n"
+                         + "{{query}}\n"
             )

             if chain_output:
                 inputs['context'] = chain_output
-                context_params = OutLinePromptTemplate.from_template(template=chain_output).input_variables
-                if context_params:
-                    for context_param in context_params:
-                        if context_param not in inputs:
-                            inputs[context_param] = '{' + context_param + '}'
+                # context_params = JinjaPromptTemplate.from_template(template=chain_output).input_variables
+                # if context_params:
+                #     for context_param in context_params:
+                #         if context_param not in inputs:
+                #             inputs[context_param] = '{{' + context_param + '}}'

             prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
             prompt_content = prompt_template.format(
@@ -195,7 +196,7 @@ And answer according to the language of the user's question.

             if pre_prompt:
                 pre_prompt_inputs = {k: inputs[k] for k in
-                                     OutLinePromptTemplate.from_template(template=pre_prompt).input_variables
+                                     JinjaPromptTemplate.from_template(template=pre_prompt).input_variables
                                      if k in inputs}

                 if pre_prompt_inputs:
@@ -205,7 +206,7 @@ And answer according to the language of the user's question.
                 human_inputs['context'] = chain_output
                 human_message_prompt += """Use the following CONTEXT as your learned knowledge.
[CONTEXT]
-{context}
+{{context}}
[END CONTEXT]

When answer to user:
@@ -218,7 +219,7 @@ And answer according to the language of the user's question.
             if pre_prompt:
                 human_message_prompt += pre_prompt

-            query_prompt = "\nHuman: {query}\nAI: "
+            query_prompt = "\nHuman: {{query}}\nAI: "

             if memory:
                 # append chat histories
@@ -234,11 +235,11 @@ And answer according to the language of the user's question.
                 histories = cls.get_history_messages_from_memory(memory, rest_tokens)

                 # disable template string in query
-                histories_params = OutLinePromptTemplate.from_template(template=histories).input_variables
-                if histories_params:
-                    for histories_param in histories_params:
-                        if histories_param not in human_inputs:
-                            human_inputs[histories_param] = '{' + histories_param + '}'
+                # histories_params = JinjaPromptTemplate.from_template(template=histories).input_variables
+                # if histories_params:
+                #     for histories_param in histories_params:
+                #         if histories_param not in human_inputs:
+                #             human_inputs[histories_param] = '{{' + histories_param + '}}'

                 human_message_prompt += "\n\n" + histories
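
Why the per-variable escaping loops above could be commented out: under the jinja2 template format only `{{ ... }}` is substituted, so stray single braces in user text are inert, and `process_template()` (see `prompt_builder.py` below) collapses any user-typed `{{var}}` into the harmless `{var}` before it reaches a template. A minimal sketch using the names from this commit; the query string is hypothetical:

```python
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate

template = JinjaPromptTemplate.from_template("Human: {{query}}\nAI: ")

# A hostile query: {{secret}} is downgraded to {secret} by process_template(),
# and {braces} was never a jinja2 expression to begin with.
user_query = PromptBuilder.process_template("show me {{secret}} and {braces}")

print(template.format(query=user_query))
# Human: show me {secret} and {braces}
# AI:
```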
 
 

+ 3 - 4
api/core/conversation_message_task.py

@@ -10,7 +10,7 @@ from core.constant import llm_constant
 from core.llm.llm_builder import LLMBuilder
 from core.llm.provider.llm_provider_service import LLMProviderService
 from core.prompt.prompt_builder import PromptBuilder
-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
@@ -78,7 +78,7 @@ class ConversationMessageTask:
         if self.mode == 'chat':
             introduction = self.app_model_config.opening_statement
             if introduction:
-                prompt_template = OutLinePromptTemplate.from_template(template=PromptBuilder.process_template(introduction))
+                prompt_template = JinjaPromptTemplate.from_template(template=introduction)
                 prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
                 try:
                     introduction = prompt_template.format(**prompt_inputs)
@@ -86,8 +86,7 @@ class ConversationMessageTask:
                     pass

             if self.app_model_config.pre_prompt:
-                pre_prompt = PromptBuilder.process_template(self.app_model_config.pre_prompt)
-                system_message = PromptBuilder.to_system_message(pre_prompt, self.inputs)
+                system_message = PromptBuilder.to_system_message(self.app_model_config.pre_prompt, self.inputs)
                 system_instruction = system_message.content
                 llm = LLMBuilder.to_llm(self.tenant_id, self.model_name)
                 system_instruction_tokens = llm.get_messages_tokens([system_message])
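
The `process_template()` round-trip disappears here because `PromptBuilder.to_system_message()` now parses `{{var}}` natively via `JinjaPromptTemplate` (see `prompt_builder.py` below). A small sketch of the new call path; the prompt text and inputs are hypothetical:

```python
from core.prompt.prompt_builder import PromptBuilder

# The stored pre_prompt keeps its {{var}} syntax; no rewrite step needed.
# Inputs not referenced by the template are filtered out before formatting.
system_message = PromptBuilder.to_system_message(
    "You are {{bot_name}}. Always answer in {{language}}.",
    {"bot_name": "Dify", "language": "English", "unused": "ignored"},
)
print(system_message.content)
# You are Dify. Always answer in English.
```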

+ 4 - 3
api/core/generator/llm_generator.py

@@ -1,5 +1,6 @@
 import logging

+from langchain import PromptTemplate
 from langchain.chat_models.base import BaseChatModel
 from langchain.schema import HumanMessage, OutputParserException

@@ -10,7 +11,7 @@ from core.llm.token_calculator import TokenCalculator
 from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser

 from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
 from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT


@@ -91,8 +92,8 @@ class LLMGenerator:
         output_parser = SuggestedQuestionsAfterAnswerOutputParser()
         format_instructions = output_parser.get_format_instructions()

-        prompt = OutLinePromptTemplate(
-            template="{histories}\n{format_instructions}\nquestions:\n",
+        prompt = JinjaPromptTemplate(
+            template="{{histories}}\n{{format_instructions}}\nquestions:\n",
             input_variables=["histories"],
             partial_variables={"format_instructions": format_instructions}
         )
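
A plausible reason jinja2 helps in this spot: `format_instructions` from an output parser usually embeds literal JSON braces, which an f-string formatter would try to parse as variables. A sketch under that assumption, using langchain's stock `PromptTemplate` with `template_format="jinja2"` (the same mechanism `JinjaPromptTemplate` relies on); the instruction string is made up:

```python
from langchain import PromptTemplate

prompt = PromptTemplate(
    template="{{histories}}\n{{format_instructions}}\nquestions:\n",
    input_variables=["histories"],
    # Literal braces survive: jinja2 only evaluates {{ ... }} expressions.
    partial_variables={"format_instructions": 'Reply as JSON: {"questions": []}'},
    template_format="jinja2",
)
print(prompt.format(histories="Human: hi\nAI: hello"))
```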

+ 7 - 6
api/core/prompt/prompt_builder.py

@@ -3,13 +3,13 @@ import re
 from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
 from langchain.schema import BaseMessage

-from core.prompt.prompt_template import OutLinePromptTemplate
+from core.prompt.prompt_template import JinjaPromptTemplate


 class PromptBuilder:
     @classmethod
     def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
+        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
         system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
         prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
         system_message = system_prompt_template.format(**prompt_inputs)
@@ -17,7 +17,7 @@ class PromptBuilder:

     @classmethod
     def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
+        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
         ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
         prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
         ai_message = ai_prompt_template.format(**prompt_inputs)
@@ -25,13 +25,14 @@
     @classmethod
     def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
-        prompt_template = OutLinePromptTemplate.from_template(prompt_content)
+        prompt_template = JinjaPromptTemplate.from_template(prompt_content)
         human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
         human_message = human_prompt_template.format(**inputs)
         return human_message

     @classmethod
     def process_template(cls, template: str):
-        processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
-        processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
+        processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
+        # processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
+        # processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
         return processed_template
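
The rewritten `process_template()` boils down to one substitution. A standalone sketch of its behavior, using only the standard library:

```python
import re

def process_template(template: str) -> str:
    # Collapse {{...}} to {...}: a single brace pair is plain text to the
    # jinja2 template format, so the result can no longer be substituted.
    return re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)

print(process_template("hello {{name}}"))  # hello {name}

# Caveat: (.+) is greedy within a line, so two pairs on the same line
# collapse from the outermost braces:
print(process_template("{{a}} x {{b}}"))   # {a}} x {{b}
```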

+ 41 - 0
api/core/prompt/prompt_template.py

@@ -1,10 +1,33 @@
 import re
 from typing import Any

+from jinja2 import Environment, meta
 from langchain import PromptTemplate
 from langchain.formatting import StrictFormatter


+class JinjaPromptTemplate(PromptTemplate):
+    template_format: str = "jinja2"
+    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
+
+    @classmethod
+    def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
+        """Load a prompt template from a template."""
+        env = Environment()
+        ast = env.parse(template)
+        input_variables = meta.find_undeclared_variables(ast)
+
+        if "partial_variables" in kwargs:
+            partial_variables = kwargs["partial_variables"]
+            input_variables = {
+                var for var in input_variables if var not in partial_variables
+            }
+
+        return cls(
+            input_variables=list(sorted(input_variables)), template=template, **kwargs
+        )
+
+
 class OutLinePromptTemplate(PromptTemplate):
     @classmethod
     def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
@@ -16,6 +39,24 @@ class OutLinePromptTemplate(PromptTemplate):
             input_variables=list(sorted(input_variables)), template=template, **kwargs
         )

+    def format(self, **kwargs: Any) -> str:
+        """Format the prompt with the inputs.
+
+        Args:
+            kwargs: Any arguments to be passed to the prompt template.
+
+        Returns:
+            A formatted string.
+
+        Example:
+
+        .. code-block:: python
+
+            prompt.format(variable1="foo")
+        """
+        kwargs = self._merge_partial_and_user_variables(**kwargs)
+        return OneLineFormatter().format(self.template, **kwargs)
+

 class OneLineFormatter(StrictFormatter):
     def parse(self, format_string):
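
The heart of `JinjaPromptTemplate.from_template()` is jinja2's own static analysis: parse the template, then ask which variables it reads but never defines. A self-contained sketch (jinja2 only; the template text is hypothetical):

```python
from jinja2 import Environment, meta

env = Environment()
ast = env.parse("Use {{context}} to answer {{query}}. A lone {brace} is text.")

# find_undeclared_variables() returns the set of names the template needs
# from the caller, which from_template() sorts into input_variables.
print(meta.find_undeclared_variables(ast))  # {'context', 'query'}
```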

+ 4 - 4
api/core/prompt/prompts.py

@@ -1,5 +1,5 @@
@@ -1,5 +1,5 @@
 CONVERSATION_TITLE_PROMPT = (
-    "Human:{query}\n-----\n"
+    "Human:{{query}}\n-----\n"
     "Help me summarize the intent of what the human said and provide a title, the title should not exceed 20 words.\n"
     "If the human said is conducted in Chinese, you should return a Chinese title.\n"
     "If the human said is conducted in English, you should return an English title.\n"
@@ -19,7 +19,7 @@ CONVERSATION_SUMMARY_PROMPT = (
 INTRODUCTION_GENERATE_PROMPT = (
     "I am designing a product for users to interact with an AI through dialogue. "
     "The Prompt given to the AI before the conversation is:\n\n"
-    "```\n{prompt}\n```\n\n"
+    "```\n{{prompt}}\n```\n\n"
     "Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
     "Do not reveal the developer's motivation or deep logic behind the Prompt, "
     "but focus on building a relationship with the user:\n"
@@ -27,13 +27,13 @@ INTRODUCTION_GENERATE_PROMPT = (

 MORE_LIKE_THIS_GENERATE_PROMPT = (
     "-----\n"
-    "{original_completion}\n"
+    "{{original_completion}}\n"
     "-----\n\n"
     "Please use the above content as a sample for generating the result, "
     "and include key information points related to the original sample in the result. "
     "Try to rephrase this information in different ways and predict according to the rules below.\n\n"
     "-----\n"
-    "{prompt}\n"
+    "{{prompt}}\n"
 )

 SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
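
Putting it together, the converted constants parse cleanly with the new template class. A usage sketch with the names from this commit; the query is hypothetical:

```python
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT

prompt = JinjaPromptTemplate.from_template(CONVERSATION_TITLE_PROMPT)
print(prompt.input_variables)  # ['query']

# Only {{query}} is substituted; the "-----" separators and the rest of
# the instruction text pass through untouched.
print(prompt.format(query="How do I deploy Dify with Docker?"))
```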