|
@@ -1,3 +1,4 @@
|
|
|
+import json
|
|
|
import logging
|
|
|
from typing import Optional, Union, cast
|
|
|
|
|
@@ -62,13 +63,20 @@ class QuestionClassifierNode(LLMNode):
|
|
|
prompt_messages=prompt_messages,
|
|
|
stop=stop
|
|
|
)
|
|
|
- categories = [_class.name for _class in node_data.classes]
|
|
|
+ category_name = node_data.classes[0].name
|
|
|
+ category_id = node_data.classes[0].id
|
|
|
try:
|
|
|
result_text_json = parse_and_check_json_markdown(result_text, [])
|
|
|
- #result_text_json = json.loads(result_text.strip('```JSON\n'))
|
|
|
- categories_result = result_text_json.get('categories', [])
|
|
|
- if categories_result:
|
|
|
- categories = categories_result
|
|
|
+ # result_text_json = json.loads(result_text.strip('```JSON\n'))
|
|
|
+ if 'category_name' in result_text_json and 'category_id' in result_text_json:
|
|
|
+ category_id_result = result_text_json['category_id']
|
|
|
+ classes = node_data.classes
|
|
|
+ classes_map = {class_.id: class_.name for class_ in classes}
|
|
|
+ category_ids = [_class.id for _class in classes]
|
|
|
+ if category_id_result in category_ids:
|
|
|
+ category_name = classes_map[category_id_result]
|
|
|
+ category_id = category_id_result
|
|
|
+
|
|
|
except Exception:
|
|
|
logging.error(f"Failed to parse result text: {result_text}")
|
|
|
try:
|
|
@@ -81,17 +89,15 @@ class QuestionClassifierNode(LLMNode):
|
|
|
'usage': jsonable_encoder(usage),
|
|
|
}
|
|
|
outputs = {
|
|
|
- 'class_name': categories[0] if categories else ''
|
|
|
+ 'class_name': category_name
|
|
|
}
|
|
|
- classes = node_data.classes
|
|
|
- classes_map = {class_.name: class_.id for class_ in classes}
|
|
|
|
|
|
return NodeRunResult(
|
|
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
|
|
inputs=variables,
|
|
|
process_data=process_data,
|
|
|
outputs=outputs,
|
|
|
- edge_source_handle=classes_map.get(categories[0], None),
|
|
|
+ edge_source_handle=category_id,
|
|
|
metadata={
|
|
|
NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
|
|
NodeRunMetadataKey.TOTAL_PRICE: usage.total_price,
|
|
@@ -210,8 +216,13 @@ class QuestionClassifierNode(LLMNode):
|
|
|
-> Union[list[ChatModelMessage], CompletionModelPromptTemplate]:
|
|
|
model_mode = ModelMode.value_of(node_data.model.mode)
|
|
|
classes = node_data.classes
|
|
|
- class_names = [class_.name for class_ in classes]
|
|
|
- class_names_str = ','.join(f'"{name}"' for name in class_names)
|
|
|
+ categories = []
|
|
|
+ for class_ in classes:
|
|
|
+ category = {
|
|
|
+ 'category_id': class_.id,
|
|
|
+ 'category_name': class_.name
|
|
|
+ }
|
|
|
+ categories.append(category)
|
|
|
instruction = node_data.instruction if node_data.instruction else ''
|
|
|
input_text = query
|
|
|
memory_str = ''
|
|
@@ -248,7 +259,7 @@ class QuestionClassifierNode(LLMNode):
|
|
|
user_prompt_message_3 = ChatModelMessage(
|
|
|
role=PromptMessageRole.USER,
|
|
|
text=QUESTION_CLASSIFIER_USER_PROMPT_3.format(input_text=input_text,
|
|
|
- categories=class_names_str,
|
|
|
+ categories=json.dumps(categories),
|
|
|
classification_instructions=instruction)
|
|
|
)
|
|
|
prompt_messages.append(user_prompt_message_3)
|
|
@@ -257,7 +268,7 @@ class QuestionClassifierNode(LLMNode):
|
|
|
return CompletionModelPromptTemplate(
|
|
|
text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format(histories=memory_str,
|
|
|
input_text=input_text,
|
|
|
- categories=class_names_str,
|
|
|
+ categories=json.dumps(categories),
|
|
|
classification_instructions=instruction)
|
|
|
)
|
|
|
|