소스 검색

feat: support comfyui workflow tool image generate image (#9871)

非法操作 5 달 전
부모
커밋
ace7ffab5f

+ 23 - 14
api/core/tools/provider/builtin/comfyui/tools/comfyui_client.py

@@ -1,3 +1,5 @@
+import base64
+import io
 import json
 import json
 import random
 import random
 import uuid
 import uuid
@@ -6,45 +8,48 @@ import httpx
 from websocket import WebSocket
 from websocket import WebSocket
 from yarl import URL
 from yarl import URL
 
 
+from core.file.file_manager import _get_encoded_string
+from core.file.models import File
+
 
 
 class ComfyUiClient:
 class ComfyUiClient:
     def __init__(self, base_url: str):
     def __init__(self, base_url: str):
         self.base_url = URL(base_url)
         self.base_url = URL(base_url)
 
 
-    def get_history(self, prompt_id: str):
+    def get_history(self, prompt_id: str) -> dict:
         res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id})
         res = httpx.get(str(self.base_url / "history"), params={"prompt_id": prompt_id})
         history = res.json()[prompt_id]
         history = res.json()[prompt_id]
         return history
         return history
 
 
-    def get_image(self, filename: str, subfolder: str, folder_type: str):
+    def get_image(self, filename: str, subfolder: str, folder_type: str) -> bytes:
         response = httpx.get(
         response = httpx.get(
             str(self.base_url / "view"),
             str(self.base_url / "view"),
             params={"filename": filename, "subfolder": subfolder, "type": folder_type},
             params={"filename": filename, "subfolder": subfolder, "type": folder_type},
         )
         )
         return response.content
         return response.content
 
 
-    def upload_image(self, input_path: str, name: str, image_type: str = "input", overwrite: bool = False):
-        # plan to support img2img in dify 0.10.0
-        with open(input_path, "rb") as file:
-            files = {"image": (name, file, "image/png")}
-            data = {"type": image_type, "overwrite": str(overwrite).lower()}
-
-        res = httpx.post(str(self.base_url / "upload/image"), data=data, files=files)
-        return res
+    def upload_image(self, image_file: File) -> dict:
+        image_content = base64.b64decode(_get_encoded_string(image_file))
+        file = io.BytesIO(image_content)
+        files = {"image": (image_file.filename, file, image_file.mime_type), "overwrite": "true"}
+        res = httpx.post(str(self.base_url / "upload/image"), files=files)
+        return res.json()
 
 
-    def queue_prompt(self, client_id: str, prompt: dict):
+    def queue_prompt(self, client_id: str, prompt: dict) -> str:
         res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt})
         res = httpx.post(str(self.base_url / "prompt"), json={"client_id": client_id, "prompt": prompt})
         prompt_id = res.json()["prompt_id"]
         prompt_id = res.json()["prompt_id"]
         return prompt_id
         return prompt_id
 
 
-    def open_websocket_connection(self):
+    def open_websocket_connection(self) -> tuple[WebSocket, str]:
         client_id = str(uuid.uuid4())
         client_id = str(uuid.uuid4())
         ws = WebSocket()
         ws = WebSocket()
         ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
         ws_address = f"ws://{self.base_url.authority}/ws?clientId={client_id}"
         ws.connect(ws_address)
         ws.connect(ws_address)
         return ws, client_id
         return ws, client_id
 
 
-    def set_prompt(self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = ""):
+    def set_prompt(
+        self, origin_prompt: dict, positive_prompt: str, negative_prompt: str = "", image_name: str = ""
+    ) -> dict:
         """
         """
         find the first KSampler, then can find the prompt node through it.
         find the first KSampler, then can find the prompt node through it.
         """
         """
@@ -58,6 +63,10 @@ class ComfyUiClient:
         if negative_prompt != "":
         if negative_prompt != "":
             negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
             negative_input_id = prompt.get(k_sampler)["inputs"]["negative"][0]
             prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
             prompt.get(negative_input_id)["inputs"]["text"] = negative_prompt
+
+        if image_name != "":
+            image_loader = [key for key, value in id_to_class_type.items() if value == "LoadImage"][0]
+            prompt.get(image_loader)["inputs"]["image"] = image_name
         return prompt
         return prompt
 
 
     def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
     def track_progress(self, prompt: dict, ws: WebSocket, prompt_id: str):
@@ -89,7 +98,7 @@ class ComfyUiClient:
             else:
             else:
                 continue
                 continue
 
 
-    def generate_image_by_prompt(self, prompt: dict):
+    def generate_image_by_prompt(self, prompt: dict) -> list[bytes]:
         try:
         try:
             ws, client_id = self.open_websocket_connection()
             ws, client_id = self.open_websocket_connection()
             prompt_id = self.queue_prompt(client_id, prompt)
             prompt_id = self.queue_prompt(client_id, prompt)

+ 5 - 3
api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.py

@@ -2,10 +2,9 @@ import json
 from typing import Any
 from typing import Any
 
 
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.provider.builtin.comfyui.tools.comfyui_client import ComfyUiClient
 from core.tools.tool.builtin_tool import BuiltinTool
 from core.tools.tool.builtin_tool import BuiltinTool
 
 
-from .comfyui_client import ComfyUiClient
-
 
 
 class ComfyUIWorkflowTool(BuiltinTool):
 class ComfyUIWorkflowTool(BuiltinTool):
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
     def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
@@ -14,13 +13,16 @@ class ComfyUIWorkflowTool(BuiltinTool):
         positive_prompt = tool_parameters.get("positive_prompt")
         positive_prompt = tool_parameters.get("positive_prompt")
         negative_prompt = tool_parameters.get("negative_prompt")
         negative_prompt = tool_parameters.get("negative_prompt")
         workflow = tool_parameters.get("workflow_json")
         workflow = tool_parameters.get("workflow_json")
+        image_name = ""
+        if image := tool_parameters.get("image"):
+            image_name = comfyui.upload_image(image).get("name")
 
 
         try:
         try:
             origin_prompt = json.loads(workflow)
             origin_prompt = json.loads(workflow)
         except:
         except:
             return self.create_text_message("the Workflow JSON is not correct")
             return self.create_text_message("the Workflow JSON is not correct")
 
 
-        prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt)
+        prompt = comfyui.set_prompt(origin_prompt, positive_prompt, negative_prompt, image_name)
         images = comfyui.generate_image_by_prompt(prompt)
         images = comfyui.generate_image_by_prompt(prompt)
         result = []
         result = []
         for img in images:
         for img in images:

+ 7 - 0
api/core/tools/provider/builtin/comfyui/tools/comfyui_workflow.yaml

@@ -24,6 +24,13 @@ parameters:
       zh_Hans: 负面提示词
       zh_Hans: 负面提示词
     llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
     llm_description: Negative prompt, you should describe the image you don't want to generate as a list of words as possible as detailed, the prompt must be written in English.
     form: llm
     form: llm
+  - name: image
+    type: file
+    label:
+      en_US: Input Image
+      zh_Hans: 输入的图片
+    llm_description: The input image, used to transfer to the comfyui workflow to generate another image.
+    form: llm
   - name: workflow_json
   - name: workflow_json
     type: string
     type: string
     required: true
     required: true