Ver código fonte

feat: enhance comfyui workflow (#10085)

非法操作 1 ano atrás
pai
commit
bd6175157c

+ 17 - 14
api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py

@@ -1,5 +1,3 @@
-import base64
-import io
 import json
 import random
 import uuid
@@ -8,7 +6,7 @@ import httpx
 from websocket import WebSocket
 from yarl import URL
 
-from core.file.file_manager import _get_encoded_string
+from core.file.file_manager import download
 from core.file.models import File
 
 
@@ -29,8 +27,7 @@ class ComfyUiClient:
         return response.content
 
     def upload_image(self, image_file: File) -> dict:
-        image_content = base64.b64decode(_get_encoded_string(image_file))
-        file = io.BytesIO(image_content)
+        file = download(image_file)
         files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
         res = httpx.post(str(self.base_url / "upload/image"), files=files)
         return res.json()
@@ -47,12 +44,7 @@ class ComfyUiClient:
         ws.connect(ws_address)
         return ws, client_id
 
-    def set_prompt(
-        self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = ""
-    ) -> dict:
-        """
-        find the first KSampler, then can find the prompt node through it.
-        """
+    def set_prompt_by_ksampler(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "") -> dict:
         prompt = origin_prompt.copy()
         id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
         k_sampler = [key for key, value in id_to_class_type.items() if value == "KSampler"][0]
@@ -64,9 +56,20 @@ class ComfyUiClient:
             negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
             prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
 
-        if image_name != "":
-            image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0]
-            prompt.get(image_loader)["inputs"]["image"] = image_name
+        return prompt
+
+    def set_prompt_images_by_ids(self, origin_prompt: dict, image_names: list[str], image_ids: list[str]) -> dict:
+        prompt = origin_prompt.copy()
+        for index, image_node_id in enumerate(image_ids):
+            prompt[image_node_id]["inputs"]["image"] = image_names[index]
+        return prompt
+
+    def set_prompt_images_by_default(self, origin_prompt: dict, image_names: list[str]) -> dict:
+        prompt = origin_prompt.copy()
+        id_to_class_type = {id: details["class_type"] for id, details in prompt.items()}
+        load_image_nodes = [key for key, value in id_to_class_type.items() if value == "LoadImage"]
+        for load_image, image_name in zip(load_image_nodes, image_names):
+            prompt.get(load_image)["inputs"]["image"] = image_name
         return prompt
 
     def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):

+ 35 - 6
api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py

@@ -1,7 +1,9 @@
 import json
 from typing import Any
 
+from core.file import FileType
 from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.errors import ToolParameterValidationError
 from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient
 from core.tools.tool.builtin_tool import BuiltinTool
 
@@ -10,19 +12,46 @@ class ComfyUIWorkflowTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
         comfyui = ComfyUiClient(self.runtime.credentials["base_url"])
 
-        positive_prompt = tool_parameters.get("positive_prompt")
-        negative_prompt = tool_parameters.get("negative_prompt")
+        positive_prompt = tool_parameters.get("positive_prompt", "")
+        negative_prompt = tool_parameters.get("negative_prompt", "")
+        images = tool_parameters.get("images") or []
         workflow = tool_parameters.get("workflow_json")
-        image_name = ""
-        if image := tool_parameters.get("image"):
+        image_names = []
+        for image in images:
+            if image.type != FileType.IMAGE:
+                continue
             image_name = comfyui.upload_image(image).get("name")
+            image_names.append(image_name)
+
+        set_prompt_with_ksampler = True
+        if "{{positive_prompt}}" in workflow:
+            set_prompt_with_ksampler = False
+            workflow = workflow.replace("{{positive_prompt}}", positive_prompt)
+            workflow = workflow.replace("{{negative_prompt}}", negative_prompt)
 
         try:
-            origin_prompt = json.loads(workflow)
+            prompt = json.loads(workflow)
         except:
             return self.create_text_message("the Workflow JSON is not correct")
 
-        prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name)
+        if set_prompt_with_ksampler:
+            try:
+                prompt = comfyui.set_prompt_by_ksampler(prompt, positive_prompt, negative_prompt)
+            except:
+                raise ToolParameterValidationError(
+                    "Failed set prompt with KSampler, try replace prompt to {{positive_prompt}} in your workflow json"
+                )
+
+        if image_names:
+            if image_ids := tool_parameters.get("image_ids"):
+                image_ids = image_ids.split(",")
+                try:
+                    prompt = comfyui.set_prompt_images_by_ids(prompt, image_names, image_ids)
+                except:
+                    raise ToolParameterValidationError("the Image Node ID List not match your upload image files.")
+            else:
+                prompt = comfyui.set_prompt_images_by_default(prompt, image_names)
+
         images = comfyui.generate_image_by_prompt(prompt)
         result = []
         for img in images:

+ 16 - 4
api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml

@@ -24,12 +24,12 @@ parameters:
       zh_Hans: 负面提示词
     llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
     form: llm
-  - name: image
-    type: file
+  - name: images
+    type: files
     label:
-      en_US: Input Image
+      en_US: Input Images
       zh_Hans: 输入的图片
-    llm_description: The input image, used to transfer to the comfyui workflow to generate another image.
+    llm_description: The input images, used to transfer to the comfyui workflow to generate another image.
     form: llm
   - name: workflow_json
     type: string
@@ -40,3 +40,15 @@ parameters:
       en_US: exported from ComfyUI workflow
       zh_Hans: 从ComfyUI的工作流中导出
     form: form
+  - name: image_ids
+    type: string
+    label:
+      en_US: Image Node ID List
+      zh_Hans: 图片节点ID列表
+    placeholder:
+      en_US: Use commas to separate multiple node ID
+      zh_Hans: 多个节点ID时使用半角逗号分隔
+    human_description:
+      en_US: When the workflow has multiple image nodes, enter the ID list of these nodes, and the images will be passed to ComfyUI in the order of the list.
+      zh_Hans: 当工作流有多个图片节点时,输入这些节点的ID列表,图片将按列表顺序传给ComfyUI
+    form: form