
add finish_reason to the LLM node output (#7498)

orangeclk · 8 months ago · commit f53454f81d
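
For orientation, a rough sketch of the LLM node output shape after this commit, assuming a typical streamed completion (the field names come from the diff below; the concrete values are invented for illustration):

    # Illustrative only: outputs produced by LLMNode once finish_reason is included.
    outputs = {
        'text': 'Hello!',  # concatenated streamed content
        'usage': {'prompt_tokens': 12, 'completion_tokens': 3, 'total_tokens': 15},
        'finish_reason': 'stop',  # stays None if the provider never sends one
    }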

+ 3 - 1
api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py

@@ -428,7 +428,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 if new_tool_call.function.arguments:
                     tool_call.function.arguments += new_tool_call.function.arguments
 
-        finish_reason = 'Unknown'
+        finish_reason = None  # The default value of finish_reason is None
 
         for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
             chunk = chunk.strip()
@@ -437,6 +437,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                 if chunk.startswith(':'):
                     continue
                 decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
+                if decoded_chunk == '[DONE]':  # Some providers return "data: [DONE]"
+                    continue
 
                 try:
                     chunk_json = json.loads(decoded_chunk)

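The two lines added above guard against providers that terminate their SSE stream with a literal "data: [DONE]" line, which is not JSON and would otherwise reach json.loads(). A self-contained sketch of that handling, under the assumption that the stream ends with such a sentinel:

    import json

    # Assumed provider behaviour: the final streamed line is a bare "[DONE]" sentinel.
    for chunk in ('data: {"choices": []}', 'data: [DONE]'):
        decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
        if decoded_chunk == '[DONE]':
            continue  # skip the sentinel rather than trying to parse it
        chunk_json = json.loads(decoded_chunk)
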
+ 10 - 5
api/core/workflow/nodes/llm/llm_node.py

@@ -113,7 +113,7 @@ class LLMNode(BaseNode):
             }
 
             # handle invoke result
-            result_text, usage = self._invoke_llm(
+            result_text, usage, finish_reason = self._invoke_llm(
                 node_data_model=node_data.model,
                 model_instance=model_instance,
                 prompt_messages=prompt_messages,
@@ -129,7 +129,8 @@ class LLMNode(BaseNode):
 
         outputs = {
             'text': result_text,
-            'usage': jsonable_encoder(usage)
+            'usage': jsonable_encoder(usage),
+            'finish_reason': finish_reason
         }
 
         return NodeRunResult(
@@ -167,14 +168,14 @@ class LLMNode(BaseNode):
         )
 
         # handle invoke result
-        text, usage = self._handle_invoke_result(
+        text, usage, finish_reason = self._handle_invoke_result(
             invoke_result=invoke_result
         )
 
         # deduct quota
         self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
 
-        return text, usage
+        return text, usage, finish_reason
 
     def _handle_invoke_result(self, invoke_result: Generator) -> tuple[str, LLMUsage]:
         """
@@ -186,6 +187,7 @@ class LLMNode(BaseNode):
         prompt_messages = []
         full_text = ''
         usage = None
+        finish_reason = None
         for result in invoke_result:
             text = result.delta.message.content
             full_text += text
@@ -201,10 +203,13 @@ class LLMNode(BaseNode):
             if not usage and result.delta.usage:
                 usage = result.delta.usage
 
+            if not finish_reason and result.delta.finish_reason:
+                finish_reason = result.delta.finish_reason
+
         if not usage:
             usage = LLMUsage.empty_usage()
 
-        return full_text, usage
+        return full_text, usage, finish_reason
 
     def _transform_chat_messages(self,
         messages: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate

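The accumulation loop above keeps the first non-empty finish_reason seen across the streamed deltas (providers typically send it only on the final chunk) and still falls back to None when nothing arrives. A standalone sketch of that pattern, using stand-in types rather than the real Dify result classes:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class _Delta:  # stand-in for result.delta in the loop above
        content: str
        finish_reason: Optional[str] = None

    def collect(deltas: list[_Delta]) -> tuple[str, Optional[str]]:
        full_text, finish_reason = '', None
        for d in deltas:
            full_text += d.content
            if not finish_reason and d.finish_reason:
                finish_reason = d.finish_reason
        return full_text, finish_reason

    print(collect([_Delta('Hel'), _Delta('lo', finish_reason='stop')]))  # ('Hello', 'stop')
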
+ 2 - 1
api/core/workflow/nodes/question_classifier/question_classifier_node.py

@@ -63,7 +63,7 @@ class QuestionClassifierNode(LLMNode):
         )
 
         # handle invoke result
-        result_text, usage = self._invoke_llm(
+        result_text, usage, finish_reason = self._invoke_llm(
             node_data_model=node_data.model,
             model_instance=model_instance,
             prompt_messages=prompt_messages,
@@ -93,6 +93,7 @@ class QuestionClassifierNode(LLMNode):
                     prompt_messages=prompt_messages
                 ),
                 'usage': jsonable_encoder(usage),
+                'finish_reason': finish_reason
             }
             outputs = {
                 'class_name': category_name
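
In the question classifier node the new value lands next to usage in the intermediate data shown above rather than in outputs, which still carries only the predicted class. Roughly, with the dict name and any fields not visible in the hunk assumed for illustration:

    # Assumed shape; only the 'usage' and 'finish_reason' keys are confirmed by the diff above.
    process_data = {
        'usage': {'total_tokens': 15},
        'finish_reason': 'stop',
    }
    outputs = {'class_name': 'refund request'}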