
feat: update prompt generate (#6516)

Joe 9 months ago
commit 8123a00e97

+ 7 - 5
api/controllers/console/app/generator.py

@@ -22,17 +22,19 @@ class RuleGenerateApi(Resource):
     @account_initialization_required
     def post(self):
         parser = reqparse.RequestParser()
-        parser.add_argument('audiences', type=str, required=True, nullable=False, location='json')
-        parser.add_argument('hoping_to_solve', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('instruction', type=str, required=True, nullable=False, location='json')
+        parser.add_argument('model_config', type=dict, required=True, nullable=False, location='json')
+        parser.add_argument('no_variable', type=bool, required=True, default=False, location='json')
         args = parser.parse_args()
 
         account = current_user
 
         try:
             rules = LLMGenerator.generate_rule_config(
-                account.current_tenant_id,
-                args['audiences'],
-                args['hoping_to_solve']
+                tenant_id=account.current_tenant_id,
+                instruction=args['instruction'],
+                model_config=args['model_config'],
+                no_variable=args['no_variable']
             )
         except ProviderTokenNotInitError as ex:
             raise ProviderNotInitializeError(ex.description)
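
For reference, a minimal client-side sketch of the request body this endpoint now expects. The URL, and the Authorization header are assumptions for illustration; only the JSON fields (instruction, model_config, no_variable) come from the parser arguments above.

    # Hedged sketch of a console client calling RuleGenerateApi after this change.
    # The endpoint URL and auth header are hypothetical; payload keys mirror the reqparse setup above.
    import requests

    payload = {
        "instruction": "Turn raw meeting notes into a structured list of action items",
        "model_config": {"provider": "openai", "name": "gpt-4o"},  # read via model_config.get("provider"/"name")
        "no_variable": False,
    }

    resp = requests.post(
        "https://dify.example.com/console/api/rule-generate",        # hypothetical path
        json=payload,
        headers={"Authorization": "Bearer <console-session-token>"},  # hypothetical auth
        timeout=60,
    )
    print(resp.json())  # expected keys: prompt, variables, opening_statement, error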

+ 136 - 30
api/core/llm_generator/llm_generator.py

@@ -3,10 +3,13 @@ import logging
 import re
 from typing import Optional
 
-from core.llm_generator.output_parser.errors import OutputParserException
 from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
 from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
-from core.llm_generator.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
+from core.llm_generator.prompts import (
+    CONVERSATION_TITLE_PROMPT,
+    GENERATOR_QA_PROMPT,
+    WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
+)
 from core.model_manager import ModelManager
 from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
 from core.model_runtime.entities.model_entities import ModelType
@@ -115,55 +118,158 @@ class LLMGenerator:
         return questions
 
     @classmethod
-    def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
+    def generate_rule_config(cls, tenant_id: str, instruction: str, model_config: dict, no_variable: bool) -> dict:
         output_parser = RuleConfigGeneratorOutputParser()
 
+        error = ""
+        error_step = ""
+        rule_config = {
+            "prompt": "",
+            "variables": [],
+            "opening_statement": "",
+            "error": ""
+        }
+        model_parameters = {
+            "max_tokens": 512,
+            "temperature": 0.01
+        }
+
+        if no_variable:
+            prompt_template = PromptTemplateParser(
+                WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE
+            )
+
+            prompt_generate = prompt_template.format(
+                inputs={
+                    "TASK_DESCRIPTION": instruction,
+                },
+                remove_template_variables=False
+            )
+
+            prompt_messages = [UserPromptMessage(content=prompt_generate)]
+
+            model_manager = ModelManager()
+
+            model_instance = model_manager.get_default_model_instance(
+                tenant_id=tenant_id,
+                model_type=ModelType.LLM,
+            )
+
+            try:
+                response = model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    stream=False
+                )
+
+                rule_config["prompt"] = response.message.content
+                
+            except InvokeError as e:
+                error = str(e)
+                error_step = "generate rule config"
+            except Exception as e:
+                logging.exception(e)
+                rule_config["error"] = str(e)
+
+            rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
+
+            return rule_config
+
+        # get rule config prompt, parameter and statement
+        prompt_generate, parameter_generate, statement_generate = output_parser.get_format_instructions()
+
         prompt_template = PromptTemplateParser(
-            template=output_parser.get_format_instructions()
+            prompt_generate
+        )
+
+        parameter_template = PromptTemplateParser(
+            parameter_generate
+        )
+
+        statement_template = PromptTemplateParser(
+            statement_generate
         )
 
-        prompt = prompt_template.format(
+        # format the prompt_generate_prompt
+        prompt_generate_prompt = prompt_template.format(
             inputs={
-                "audiences": audiences,
-                "hoping_to_solve": hoping_to_solve,
-                "variable": "{{variable}}",
-                "lanA": "{{lanA}}",
-                "lanB": "{{lanB}}",
-                "topic": "{{topic}}"
+                "TASK_DESCRIPTION": instruction,
             },
             remove_template_variables=False
         )
+        prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)]
 
+        # get model instance
         model_manager = ModelManager()
-        model_instance = model_manager.get_default_model_instance(
+        model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
+            provider=model_config.get("provider") if model_config else None,
+            model=model_config.get("name") if model_config else None,
         )
 
-        prompt_messages = [UserPromptMessage(content=prompt)]
-
         try:
-            response = model_instance.invoke_llm(
-                prompt_messages=prompt_messages,
-                model_parameters={
-                    "max_tokens": 512,
-                    "temperature": 0
+            try:
+                # the first step to generate the task prompt
+                prompt_content = model_instance.invoke_llm(
+                    prompt_messages=prompt_messages,
+                    model_parameters=model_parameters,
+                    stream=False
+                )
+            except InvokeError as e:
+                error = str(e)
+                error_step = "generate prefix prompt"
+                rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
+
+                return rule_config
+
+            rule_config["prompt"] = prompt_content.message.content
+
+            parameter_generate_prompt = parameter_template.format(
+                inputs={
+                    "INPUT_TEXT": prompt_content.message.content,
                 },
-                stream=False
+                remove_template_variables=False
             )
+            parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
+
+            # the second step to generate the task_parameter and task_statement
+            statement_generate_prompt = statement_template.format(
+                inputs={
+                    "TASK_DESCRIPTION": instruction,
+                    "INPUT_TEXT": prompt_content.message.content,
+                },
+                remove_template_variables=False
+            )
+            statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
+
+            try:
+                parameter_content = model_instance.invoke_llm(
+                    prompt_messages=parameter_messages,
+                    model_parameters=model_parameters,
+                    stream=False
+                )
+                rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.content)
+            except InvokeError as e:
+                error = str(e)
+                error_step = "generate variables"
+
+            try:
+                statement_content = model_instance.invoke_llm(
+                    prompt_messages=statement_messages,
+                    model_parameters=model_parameters,
+                    stream=False
+                )
+                rule_config["opening_statement"] = statement_content.message.content
+            except InvokeError as e:
+                error = str(e)
+                error_step = "generate conversation opener"
 
-            rule_config = output_parser.parse(response.message.content)
-        except InvokeError as e:
-            raise e
-        except OutputParserException:
-            raise ValueError('Please give a valid input for intended audience or hoping to solve problems.')
         except Exception as e:
             logging.exception(e)
-            rule_config = {
-                "prompt": "",
-                "variables": [],
-                "opening_statement": ""
-            }
+            rule_config["error"] = str(e)
+
+        rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
 
         return rule_config
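
Taken together, a hedged usage sketch of the new generate_rule_config signature; the tenant id and model_config values are placeholders, not taken from this commit.

    from core.llm_generator.llm_generator import LLMGenerator

    rule_config = LLMGenerator.generate_rule_config(
        tenant_id="tenant-123",                                 # hypothetical tenant
        instruction="Write a prompt that summarizes support tickets by urgency",
        model_config={"provider": "openai", "name": "gpt-4o"},  # keys read via model_config.get(...)
        no_variable=False,                                       # False runs the three-step flow above
    )

    if rule_config["error"]:
        print("generation failed:", rule_config["error"])
    else:
        print(rule_config["prompt"])             # first call: task prompt
        print(rule_config["variables"])          # second call: quoted names pulled out with re.findall
        print(rule_config["opening_statement"])  # third call: conversation opener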
 

+ 7 - 3
api/core/llm_generator/output_parser/rule_config_generator.py

@@ -1,14 +1,18 @@
 from typing import Any
 
 from core.llm_generator.output_parser.errors import OutputParserException
-from core.llm_generator.prompts import RULE_CONFIG_GENERATE_TEMPLATE
+from core.llm_generator.prompts import (
+    RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE,
+    RULE_CONFIG_PROMPT_GENERATE_TEMPLATE,
+    RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE,
+)
 from libs.json_in_md_parser import parse_and_check_json_markdown
 
 
 class RuleConfigGeneratorOutputParser:
 
-    def get_format_instructions(self) -> str:
-        return RULE_CONFIG_GENERATE_TEMPLATE
+    def get_format_instructions(self) -> tuple[str, str, str]:
+        return RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE
 
     def parse(self, text: str) -> Any:
         try:
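
A short sketch of how the three-template tuple returned above is consumed, mirroring the unpacking in llm_generator.py; the PromptTemplateParser import path is assumed, as it is not part of this diff.

    from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
    from core.prompt.utils.prompt_template_parser import PromptTemplateParser  # import path assumed

    output_parser = RuleConfigGeneratorOutputParser()
    prompt_generate, parameter_generate, statement_generate = output_parser.get_format_instructions()

    prompt_template = PromptTemplateParser(prompt_generate)        # step 1: task prompt
    parameter_template = PromptTemplateParser(parameter_generate)  # step 2: variable extraction
    statement_template = PromptTemplateParser(statement_generate)  # step 3: opening statement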

File diff suppressed because it is too large
+ 67 - 59
api/core/llm_generator/prompts.py


+ 4 - 0
api/core/model_runtime/model_providers/tongyi/llm/llm.py

@@ -262,6 +262,10 @@ You should also complete the text started with ``` but not tell ``` directly.
         :param prompt_messages: prompt messages
         :return: llm response
         """
+        if response.status_code != 200 and response.status_code != HTTPStatus.OK:
+            raise ServiceUnavailableError(
+                response.message
+            )
         # transform assistant message to prompt message
         assistant_prompt_message = AssistantPromptMessage(
             content=response.output.choices[0].message.content,