瀏覽代碼

support variable

Yeuoly 7 月之前
父節點
當前提交
70c001436e
共有 2 個文件被更改,包括 64 次插入18 次删除
  1. 42 2
      api/core/tools/entities/tool_entities.py
  2. 22 16
      api/core/workflow/nodes/tool/tool_node.py

+ 42 - 2
api/core/tools/entities/tool_entities.py

@@ -1,7 +1,8 @@
+import base64
 from enum import Enum
 from typing import Any, Optional, Union, cast
 
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, field_serializer, field_validator
 
 from core.entities.parameter_entities import AppSelectorScope, CommonParameterType, ModelConfigScope
 from core.tools.entities.common_entities import I18nObject
@@ -100,6 +101,26 @@ class ToolInvokeMessage(BaseModel):
     class BlobMessage(BaseModel):
         blob: bytes
 
+    class VariableMessage(BaseModel):
+        variable_name: str = Field(..., description="The name of the variable")
+        variable_value: str = Field(..., description="The value of the variable")
+        stream: bool = Field(default=False, description="Whether the variable is streamed")
+
+        @field_validator("variable_value", mode="before")
+        def transform_variable_value(cls, value, values) -> Any:
+            """
+            Only basic types and lists are allowed.
+            """
+            if not isinstance(value, dict | list | str | int | float | bool):
+                raise ValueError("Only basic types and lists are allowed.")
+            
+            # if stream is true, the value must be a string
+            if values.get('stream'):
+                if not isinstance(value, str):
+                    raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
+
+            return value
+
     class MessageType(Enum):
         TEXT = "text"
         IMAGE = "image"
@@ -108,15 +129,34 @@ class ToolInvokeMessage(BaseModel):
         JSON = "json"
         IMAGE_LINK = "image_link"
         FILE_VAR = "file_var"
+        VARIABLE = "variable"
 
     type: MessageType = MessageType.TEXT
     """
         plain text, image url or link url
     """
-    message: JsonMessage | TextMessage | BlobMessage | None
+    message: JsonMessage | TextMessage | BlobMessage | VariableMessage | None
     meta: dict[str, Any] | None = None
     save_as: str = ''
 
+    @field_validator('message', mode='before')
+    @classmethod
+    def decode_blob_message(cls, v):
+        if isinstance(v, dict) and 'blob' in v:
+            try:
+                v['blob'] = base64.b64decode(v['blob'])
+            except Exception:
+                pass
+        return v
+
+    @field_serializer('message')
+    def serialize_message(self, v):
+        if isinstance(v, self.BlobMessage):
+            return {
+                'blob': base64.b64encode(v.blob).decode('utf-8')
+            }
+        return v
+
 class ToolInvokeMessageBinary(BaseModel):
     mimetype: str = Field(..., description="The mimetype of the binary")
     url: str = Field(..., description="The url of the binary")

+ 22 - 16
api/core/workflow/nodes/tool/tool_node.py

@@ -1,7 +1,6 @@
-from collections.abc import Generator, Iterable, Mapping, Sequence
+from collections.abc import Generator, Mapping, Sequence
 from os import path
 from typing import Any, cast
-from urllib import response
 
 from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@@ -98,19 +97,6 @@ class ToolNode(BaseNode):
         # convert tool messages
         yield from self._transform_message(message_stream, tool_info, parameters_for_log)
 
-        # return NodeRunResult(
-        #     status=WorkflowNodeExecutionStatus.SUCCEEDED,
-        #     outputs={
-        #         'text': plain_text,
-        #         'files': files,
-        #         'json': json
-        #     },
-        #     metadata={
-        #         NodeRunMetadataKey.TOOL_INFO: tool_info
-        #     },
-        #     inputs=parameters_for_log
-        # )
-
     def _generate_parameters(
         self,
         *,
@@ -183,6 +169,8 @@ class ToolNode(BaseNode):
         files: list[FileVar] = []
         text = ""
         json: list[dict] = []
+        
+        variables: dict[str, Any] = {}
 
         for message in message_stream:
             if message.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
@@ -241,6 +229,23 @@ class ToolNode(BaseNode):
                     chunk_content=stream_text,
                     from_variable_selector=[self.node_id, 'text']
                 )
+            elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
+                assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+                variable_name = message.message.variable_name
+                variable_value = message.message.variable_value
+                if message.message.stream:
+                    if not isinstance(variable_value, str):
+                        raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
+                    if variable_name not in variables:
+                        variables[variable_name] = ""
+                    variables[variable_name] += variable_value
+
+                    yield RunStreamChunkEvent(
+                        chunk_content=variable_value,
+                        from_variable_selector=[self.node_id, variable_name]
+                    )
+                else:
+                    variables[variable_name] = variable_value
 
         yield RunCompletedEvent(
             run_result=NodeRunResult(
@@ -248,7 +253,8 @@ class ToolNode(BaseNode):
                 outputs={
                     'text': text,
                     'files': files,
-                    'json': json
+                    'json': json,
+                    **variables
                 },
                 metadata={
                     NodeRunMetadataKey.TOOL_INFO: tool_info