瀏覽代碼

feat: support agent log event

Yeuoly 6 月之前
父節點
當前提交
9a6f120e5c

+ 5 - 1
api/core/app/apps/advanced_chat/generate_task_pipeline.py

@@ -13,6 +13,7 @@ from core.app.entities.app_invoke_entities import (
 )
 from core.app.entities.queue_entities import (
     QueueAdvancedChatMessageEndEvent,
+    QueueAgentLogEvent,
     QueueAnnotationReplyEvent,
     QueueErrorEvent,
     QueueIterationCompletedEvent,
@@ -124,6 +125,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
 
         self._task_state = WorkflowTaskState()
         self._wip_workflow_node_executions = {}
+        self._wip_workflow_agent_logs = {}
 
         self._conversation_name_generate_thread = None
         self._recorded_files: list[Mapping[str, Any]] = []
@@ -244,7 +246,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 else:
                     start_listener_time = time.time()
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
-            except Exception as e:
+            except Exception:
                 logger.exception(f"Failed to listen audio message, task_id: {task_id}")
                 break
         if tts_publisher:
@@ -493,6 +495,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
                 self._save_message(graph_runtime_state=graph_runtime_state)
 
                 yield self._message_end_to_stream_response()
+            elif isinstance(event, QueueAgentLogEvent):
+                yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
             else:
                 continue
 

+ 5 - 1
api/core/app/apps/workflow/generate_task_pipeline.py

@@ -11,6 +11,7 @@ from core.app.entities.app_invoke_entities import (
     WorkflowAppGenerateEntity,
 )
 from core.app.entities.queue_entities import (
+    QueueAgentLogEvent,
     QueueErrorEvent,
     QueueIterationCompletedEvent,
     QueueIterationNextEvent,
@@ -106,6 +107,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
 
         self._task_state = WorkflowTaskState()
         self._wip_workflow_node_executions = {}
+        self._wip_workflow_agent_logs = {}
         self.total_tokens: int = 0
 
     def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
@@ -216,7 +218,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                     break
                 else:
                     yield MessageAudioStreamResponse(audio=audio_trunk.audio, task_id=task_id)
-            except Exception as e:
+            except Exception:
                 logger.exception(f"Fails to get audio trunk, task_id: {task_id}")
                 break
         if tts_publisher:
@@ -387,6 +389,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
                 yield self._text_chunk_to_stream_response(
                     delta_text, from_variable_selector=event.from_variable_selector
                 )
+            elif isinstance(event, QueueAgentLogEvent):
+                yield self._handle_agent_log(task_id=self._application_generate_entity.task_id, event=event)
             else:
                 continue
 

+ 13 - 0
api/core/app/apps/workflow_app_runner.py

@@ -5,6 +5,7 @@ from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
 from core.app.apps.base_app_runner import AppRunner
 from core.app.entities.queue_entities import (
     AppQueueEvent,
+    QueueAgentLogEvent,
     QueueIterationCompletedEvent,
     QueueIterationNextEvent,
     QueueIterationStartEvent,
@@ -23,6 +24,7 @@ from core.app.entities.queue_entities import (
 )
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.graph_engine.entities.event import (
+    AgentLogEvent,
     GraphEngineEvent,
     GraphRunFailedEvent,
     GraphRunStartedEvent,
@@ -295,6 +297,17 @@ class WorkflowBasedAppRunner(AppRunner):
                     retriever_resources=event.retriever_resources, in_iteration_id=event.in_iteration_id
                 )
             )
+        elif isinstance(event, AgentLogEvent):
+            self._publish_event(
+                QueueAgentLogEvent(
+                    id=event.id,
+                    node_execution_id=event.node_execution_id,
+                    parent_id=event.parent_id,
+                    error=event.error,
+                    status=event.status,
+                    data=event.data,
+                )
+            )
         elif isinstance(event, ParallelBranchRunStartedEvent):
             self._publish_event(
                 QueueParallelBranchRunStartedEvent(

+ 16 - 1
api/core/app/entities/queue_entities.py

@@ -1,6 +1,6 @@
 from datetime import datetime
 from enum import Enum, StrEnum
-from typing import Any, Optional
+from typing import Any, Mapping, Optional
 
 from pydantic import BaseModel
 
@@ -38,6 +38,7 @@ class QueueEvent(StrEnum):
     PARALLEL_BRANCH_RUN_STARTED = "parallel_branch_run_started"
     PARALLEL_BRANCH_RUN_SUCCEEDED = "parallel_branch_run_succeeded"
     PARALLEL_BRANCH_RUN_FAILED = "parallel_branch_run_failed"
+    AGENT_LOG = "agent_log"
     ERROR = "error"
     PING = "ping"
     STOP = "stop"
@@ -300,6 +301,20 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     iteration_duration_map: Optional[dict[str, float]] = None
 
 
+class QueueAgentLogEvent(AppQueueEvent):
+    """
+    QueueAgentLogEvent entity
+    """
+
+    event: QueueEvent = QueueEvent.AGENT_LOG
+    id: str
+    node_execution_id: str
+    parent_id: str | None
+    error: str | None
+    status: str
+    data: Mapping[str, Any]
+
+
 class QueueNodeInIterationFailedEvent(AppQueueEvent):
     """
     QueueNodeInIterationFailedEvent entity

+ 22 - 0
api/core/app/entities/task_entities.py

@@ -59,6 +59,7 @@ class StreamEvent(Enum):
     ITERATION_COMPLETED = "iteration_completed"
     TEXT_CHUNK = "text_chunk"
     TEXT_REPLACE = "text_replace"
+    AGENT_LOG = "agent_log"
 
 
 class StreamResponse(BaseModel):
@@ -625,3 +626,24 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
 
     workflow_run_id: str
     data: Data
+
+
+class AgentLogStreamResponse(StreamResponse):
+    """
+    AgentLogStreamResponse entity
+    """
+
+    class Data(BaseModel):
+        """
+        Data entity
+        """
+
+        node_execution_id: str
+        id: str
+        parent_id: str | None
+        error: str | None
+        status: str
+        data: Mapping[str, Any]
+
+    event: StreamEvent = StreamEvent.AGENT_LOG
+    data: Data

+ 72 - 6
api/core/app/task_pipeline/workflow_cycle_manage.py

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
 
 from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.queue_entities import (
+    QueueAgentLogEvent,
     QueueIterationCompletedEvent,
     QueueIterationNextEvent,
     QueueIterationStartEvent,
@@ -21,6 +22,7 @@ from core.app.entities.queue_entities import (
     QueueParallelBranchRunSucceededEvent,
 )
 from core.app.entities.task_entities import (
+    AgentLogStreamResponse,
     IterationNodeCompletedStreamResponse,
     IterationNodeNextStreamResponse,
     IterationNodeStartStreamResponse,
@@ -63,6 +65,7 @@ class WorkflowCycleManage:
     _task_state: WorkflowTaskState
     _workflow_system_variables: dict[SystemVariableKey, Any]
     _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
+    _wip_workflow_agent_logs: dict[str, list[AgentLogStreamResponse.Data]]
 
     def _handle_workflow_run_start(self) -> WorkflowRun:
         max_sequence = (
@@ -283,9 +286,16 @@ class WorkflowCycleManage:
         inputs = WorkflowEntry.handle_special_values(event.inputs)
         process_data = WorkflowEntry.handle_special_values(event.process_data)
         outputs = WorkflowEntry.handle_special_values(event.outputs)
-        execution_metadata = (
-            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
-        )
+        execution_metadata_dict = event.execution_metadata
+        if self._wip_workflow_agent_logs.get(event.node_execution_id):
+            if not execution_metadata_dict:
+                execution_metadata_dict = {}
+
+            execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
+                event.node_execution_id, []
+            )
+
+        execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
         finished_at = datetime.now(UTC).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
 
@@ -332,9 +342,16 @@ class WorkflowCycleManage:
         outputs = WorkflowEntry.handle_special_values(event.outputs)
         finished_at = datetime.now(UTC).replace(tzinfo=None)
         elapsed_time = (finished_at - event.start_at).total_seconds()
-        execution_metadata = (
-            json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
-        )
+        execution_metadata_dict = event.execution_metadata
+        if self._wip_workflow_agent_logs.get(event.node_execution_id):
+            if not execution_metadata_dict:
+                execution_metadata_dict = {}
+
+            execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
+                event.node_execution_id, []
+            )
+
+        execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
         db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
             {
                 WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value,
@@ -746,3 +763,52 @@ class WorkflowCycleManage:
             raise Exception(f"Workflow node execution not found: {node_execution_id}")
 
         return workflow_node_execution
+
+    def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
+        """
+        Handle agent log
+        :param task_id: task id
+        :param event: agent log event
+        :return:
+        """
+        node_execution = self._wip_workflow_node_executions.get(event.node_execution_id)
+        if not node_execution:
+            raise Exception(f"Workflow node execution not found: {event.node_execution_id}")
+
+        node_execution_id = node_execution.id
+        original_agent_logs = self._wip_workflow_agent_logs.get(node_execution_id, [])
+
+        # try to find the log with the same id
+        for log in original_agent_logs:
+            if log.id == event.id:
+                # update the log
+                log.status = event.status
+                log.error = event.error
+                log.data = event.data
+                break
+        else:
+            # append the log
+            original_agent_logs.append(
+                AgentLogStreamResponse.Data(
+                    id=event.id,
+                    parent_id=event.parent_id,
+                    node_execution_id=node_execution_id,
+                    error=event.error,
+                    status=event.status,
+                    data=event.data,
+                )
+            )
+
+        self._wip_workflow_agent_logs[node_execution_id] = original_agent_logs
+
+        return AgentLogStreamResponse(
+            task_id=task_id,
+            data=AgentLogStreamResponse.Data(
+                node_execution_id=node_execution_id,
+                id=event.id,
+                parent_id=event.parent_id,
+                error=event.error,
+                status=event.status,
+                data=event.data,
+            ),
+        )

+ 15 - 2
api/core/tools/entities/tool_entities.py

@@ -1,7 +1,7 @@
 import base64
 import enum
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any, Mapping, Optional, Union
 
 from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator
 
@@ -150,6 +150,18 @@ class ToolInvokeMessage(BaseModel):
                 raise ValueError(f"The variable name '{value}' is reserved.")
             return value
 
+    class LogMessage(BaseModel):
+        class LogStatus(Enum):
+            START = "start"
+            ERROR = "error"
+            SUCCESS = "success"
+
+        id: str
+        parent_id: Optional[str] = Field(default=None, description="Leave empty for root log")
+        error: Optional[str] = Field(default=None, description="The error message")
+        status: LogStatus = Field(..., description="The status of the log")
+        data: Mapping[str, Any] = Field(..., description="Detailed log data")
+
     class MessageType(Enum):
         TEXT = "text"
         IMAGE = "image"
@@ -160,12 +172,13 @@ class ToolInvokeMessage(BaseModel):
         BINARY_LINK = "binary_link"
         VARIABLE = "variable"
         FILE = "file"
+        LOG = "log"
 
     type: MessageType = MessageType.TEXT
     """
         plain text, image url or link url
     """
-    message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | None
+    message: JsonMessage | TextMessage | BlobMessage | VariableMessage | FileMessage | LogMessage | None
     meta: dict[str, Any] | None = None
 
     @field_validator("message", mode="before")

+ 1 - 0
api/core/workflow/entities/node_entities.py

@@ -17,6 +17,7 @@ class NodeRunMetadataKey(StrEnum):
     TOTAL_PRICE = "total_price"
     CURRENCY = "currency"
     TOOL_INFO = "tool_info"
+    AGENT_LOG = "agent_log"
     ITERATION_ID = "iteration_id"
     ITERATION_INDEX = "iteration_index"
     PARALLEL_ID = "parallel_id"

+ 19 - 1
api/core/workflow/graph_engine/entities/event.py

@@ -170,4 +170,22 @@ class IterationRunFailedEvent(BaseIterationEvent):
     error: str = Field(..., description="failed reason")
 
 
-InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
+###########################################
+# Agent Events
+###########################################
+
+
+class BaseAgentEvent(GraphEngineEvent):
+    pass
+
+
+class AgentLogEvent(BaseAgentEvent):
+    id: str = Field(..., description="id")
+    node_execution_id: str = Field(..., description="node execution id")
+    parent_id: str | None = Field(..., description="parent id")
+    error: str | None = Field(..., description="error")
+    status: str = Field(..., description="status")
+    data: Mapping[str, Any] = Field(..., description="data")
+
+
+InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent | BaseAgentEvent

+ 1 - 1
api/core/workflow/graph_engine/graph_engine.py

@@ -152,7 +152,7 @@ class GraphEngine:
                     elif isinstance(item, NodeRunSucceededEvent):
                         if item.node_type == NodeType.END:
                             self.graph_runtime_state.outputs = (
-                                item.route_node_state.node_run_result.outputs
+                                dict(item.route_node_state.node_run_result.outputs)
                                 if item.route_node_state.node_run_result
                                 and item.route_node_state.node_run_result.outputs
                                 else {}

+ 12 - 0
api/core/workflow/nodes/tool/tool_node.py

@@ -16,6 +16,7 @@ from core.variables.variables import ArrayAnyVariable
 from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
 from core.workflow.entities.variable_pool import VariablePool
 from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import AgentLogEvent
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
 from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
@@ -55,6 +56,17 @@ class ToolNode(BaseNode[ToolNodeData]):
             "plugin_unique_identifier": node_data.plugin_unique_identifier,
         }
 
+        yield AgentLogEvent(
+            id=self.node_id,
+            node_execution_id=self.id,
+            parent_id=None,
+            error=None,
+            status="running",
+            data={
+                "tool_info": tool_info,
+            },
+        )
+
         # get tool runtime
         try:
             tool_runtime = ToolManager.get_workflow_tool_runtime(