Browse Source

question classifier optimize (#4147)

Jyong 1 year ago
parent
commit
e353809680

+ 24 - 13
api/core/workflow/nodes/question_classifier/question_classifier_node.py

@@ -1,3 +1,4 @@
+import json
 import logging
 from typing import Optional, Union, cast
 
@@ -62,13 +63,20 @@ class QuestionClassifierNode(LLMNode):
             prompt_messages=prompt_messages,
             stop=stop
         )
-        categories = [_class.name for _class in node_data.classes]
+        category_name = node_data.classes[0].name
+        category_id = node_data.classes[0].id
         try:
             result_text_json = parse_and_check_json_markdown(result_text, [])
-            #result_text_json = json.loads(result_text.strip('```JSON\n'))
-            categories_result = result_text_json.get('categories', [])
-            if categories_result:
-                categories = categories_result
+            # result_text_json = json.loads(result_text.strip('```JSON\n'))
+            if 'category_name' in result_text_json and 'category_id' in result_text_json:
+                category_id_result = result_text_json['category_id']
+                classes = node_data.classes
+                classes_map = {class_.id: class_.name for class_ in classes}
+                category_ids = [_class.id for _class in classes]
+                if category_id_result in category_ids:
+                    category_name = classes_map[category_id_result]
+                    category_id = category_id_result
+
         except Exception:
             logging.error(f"Failed to parse result text: {result_text}")
         try:
@@ -81,17 +89,15 @@ class QuestionClassifierNode(LLMNode):
                 'usage': jsonable_encoder(usage),
             }
             outputs = {
-                'class_name': categories[0] if categories else ''
+                'class_name': category_name
             }
-            classes = node_data.classes
-            classes_map = {class_.name: class_.id for class_ in classes}
 
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.SUCCEEDED,
                 inputs=variables,
                 process_data=process_data,
                 outputs=outputs,
-                edge_source_handle=classes_map.get(categories[0], None),
+                edge_source_handle=category_id,
                 metadata={
                     NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
                     NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
@@ -210,8 +216,13 @@ class QuestionClassifierNode(LLMNode):
             -> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
         model_mode = ModelMode.value_of(node_data.model.mode)
         classes = node_data.classes
-        class_names = [class_.name for class_ in classes]
-        class_names_str = ','.join(f'"{name}"' for name in class_names)
+        categories = []
+        for class_ in classes:
+            category = {
+                'category_id': class_.id,
+                'category_name': class_.name
+            }
+            categories.append(category)
         instruction = node_data.instruction if node_data.instruction else ''
         input_text = query
         memory_str = ''
@@ -248,7 +259,7 @@ class QuestionClassifierNode(LLMNode):
             user_prompt_message_3 = ChatModelMessage(
                 role=PromptMessageRole.USER,
                 text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text,
-                                                              categories=class_names_str,
+                                                              categories=json.dumps(categories),
                                                               classification_instructions=instruction)
             )
             prompt_messages.append(user_prompt_message_3)
@@ -257,7 +268,7 @@ class QuestionClassifierNode(LLMNode):
             return CompletionModelPromptTemplate(
                 text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
                                                                   input_text=input_text,
-                                                                  categories=class_names_str,
+                                                                  categories=json.dumps(categories),
                                                                   classification_instructions=instruction)
             )
 

File diff suppressed because it is too large
+ 14 - 14
api/core/workflow/nodes/question_classifier/template_prompts.py