瀏覽代碼

feat: stable diffusion 3 (#3599)

Yeuoly 1 年之前
父節點
當前提交
d9f1a8ce9f

+ 1 - 0
api/core/tools/provider/_position.yaml

@@ -4,6 +4,7 @@
 - searxng
 - dalle
 - azuredalle
+- stability
 - wikipedia
 - model.openai
 - model.google

文件差異過大導致無法顯示
+ 10 - 0
api/core/tools/provider/builtin/stability/_assets/icon.svg


+ 15 - 0
api/core/tools/provider/builtin/stability/stability.py

@@ -0,0 +1,15 @@
+from typing import Any
+
+from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class StabilityToolProvider(BuiltinToolProviderController, BaseStabilityAuthorization):
+    """
+    This class is responsible for providing the stability tool.
+    """
+    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
+        """
+        This method is responsible for validating the credentials.
+        """
+        self.sd_validate_credentials(credentials)

+ 29 - 0
api/core/tools/provider/builtin/stability/stability.yaml

@@ -0,0 +1,29 @@
+identity:
+  author: Dify
+  name: stability
+  label:
+    en_US: Stability
+    zh_Hans: Stability
+    pt_BR: Stability
+  description:
+    en_US: Activating humanity's potential through generative AI
+    zh_Hans: 通过生成式 AI 激活人类的潜力
+    pt_BR: Activating humanity's potential through generative AI
+  icon: icon.svg
+credentials_for_provider:
+  api_key:
+    type: secret-input
+    required: true
+    label:
+      en_US: API key
+      zh_Hans: API key
+      pt_BR: API key
+    placeholder:
+      en_US: Please input your API key
+      zh_Hans: 请输入你的 API key
+      pt_BR: Please input your API key
+    help:
+      en_US: Get your API key from Stability
+      zh_Hans: 从 Stability 获取你的 API key
+      pt_BR: Get your API key from Stability
+    url: https://platform.stability.ai/account/keys

+ 34 - 0
api/core/tools/provider/builtin/stability/tools/base.py

@@ -0,0 +1,34 @@
+import requests
+from yarl import URL
+
+from core.tools.errors import ToolProviderCredentialValidationError
+
+
+class BaseStabilityAuthorization:
+    def sd_validate_credentials(self, credentials: dict):
+        """
+        This method is responsible for validating the credentials.
+        """
+        api_key = credentials.get('api_key', '')
+        if not api_key:
+            raise ToolProviderCredentialValidationError('API key is required.')
+        
+        response = requests.get(
+            URL('https://api.stability.ai') / 'v1' / 'user' / 'account', 
+            headers=self.generate_authorization_headers(credentials),
+            timeout=(5, 30)
+        )
+
+        if not response.ok:
+            raise ToolProviderCredentialValidationError('Invalid API key.')
+
+        return True
+    
+    def generate_authorization_headers(self, credentials: dict) -> dict[str, str]:
+        """
+        This method is responsible for generating the authorization headers.
+        """
+        return {
+            'Authorization': f'Bearer {credentials.get("api_key", "")}'
+        }
+    

+ 60 - 0
api/core/tools/provider/builtin/stability/tools/text2image.py

@@ -0,0 +1,60 @@
+from typing import Any
+
+from httpx import post
+
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.provider.builtin.stability.tools.base import BaseStabilityAuthorization
+from core.tools.tool.builtin_tool import BuiltinTool
+
+
+class StableDiffusionTool(BuiltinTool, BaseStabilityAuthorization):
+    """
+    This class is responsible for providing the stable diffusion tool.
+    """
+    model_endpoint_map = {
+        'sd3': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
+        'sd3-turbo': 'https://api.stability.ai/v2beta/stable-image/generate/sd3',
+        'core': 'https://api.stability.ai/v2beta/stable-image/generate/core',
+    }
+
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
+        """
+        Invoke the tool.
+        """
+        payload = {
+            'prompt': tool_parameters.get('prompt', ''),
+            'aspect_radio': tool_parameters.get('aspect_radio', '16:9'),
+            'mode': 'text-to-image',
+            'seed': tool_parameters.get('seed', 0),
+            'output_format': 'png',
+        }
+
+        model = tool_parameters.get('model', 'core')
+
+        if model in ['sd3', 'sd3-turbo']:
+            payload['model'] = tool_parameters.get('model')
+
+        if not model == 'sd3-turbo':
+            payload['negative_prompt'] = tool_parameters.get('negative_prompt', '')
+
+        response = post(
+            self.model_endpoint_map[tool_parameters.get('model', 'core')],
+            headers={
+                'accept': 'image/*',
+                **self.generate_authorization_headers(self.runtime.credentials),
+            },
+            files={
+                key: (None, str(value)) for key, value in payload.items()
+            },
+            timeout=(5, 30)
+        )
+
+        if not response.status_code == 200:
+            raise Exception(response.text)
+        
+        return self.create_blob_message(
+            blob=response.content, meta={
+                'mime_type': 'image/png'
+            },
+            save_as=self.VARIABLE_KEY.IMAGE.value
+        )

+ 142 - 0
api/core/tools/provider/builtin/stability/tools/text2image.yaml

@@ -0,0 +1,142 @@
+identity:
+  name: stability_text2image
+  author: Dify
+  label:
+    en_US: StableDiffusion
+    zh_Hans: 稳定扩散
+    pt_BR: StableDiffusion
+description:
+  human:
+    en_US: A tool for generate images based on the text input
+    zh_Hans: 一个基于文本输入生成图像的工具
+    pt_BR: A tool for generate images based on the text input
+  llm: A tool for generate images based on the text input
+parameters:
+  - name: prompt
+    type: string
+    required: true
+    label:
+      en_US: Prompt
+      zh_Hans: 提示词
+      pt_BR: Prompt
+    human_description:
+      en_US: used for generating images
+      zh_Hans: 用于生成图像
+      pt_BR: used for generating images
+    llm_description: key words for generating images
+    form: llm
+  - name: model
+    type: select
+    default: sd3-turbo
+    required: true
+    label:
+      en_US: Model
+      zh_Hans: 模型
+      pt_BR: Model
+    options:
+      - value: core
+        label:
+          en_US: Core
+          zh_Hans: Core
+          pt_BR: Core
+      - value: sd3
+        label:
+          en_US: Stable Diffusion 3
+          zh_Hans: Stable Diffusion 3
+          pt_BR: Stable Diffusion 3
+      - value: sd3-turbo
+        label:
+          en_US: Stable Diffusion 3 Turbo
+          zh_Hans: Stable Diffusion 3 Turbo
+          pt_BR: Stable Diffusion 3 Turbo
+    human_description:
+      en_US: Model for generating images
+      zh_Hans: 用于生成图像的模型
+      pt_BR: Model for generating images
+    llm_description: Model for generating images
+    form: form
+  - name: negative_prompt
+    type: string
+    default: bad art, ugly, deformed, watermark, duplicated, discontinuous lines
+    required: false
+    label:
+      en_US: Negative Prompt
+      zh_Hans: 负面提示
+      pt_BR: Negative Prompt
+    human_description:
+      en_US: Negative Prompt
+      zh_Hans: 负面提示
+      pt_BR: Negative Prompt
+    llm_description: Negative Prompt
+    form: form
+  - name: seeds
+    type: number
+    default: 0
+    required: false
+    label:
+      en_US: Seeds
+      zh_Hans: 种子
+      pt_BR: Seeds
+    human_description:
+      en_US: Seeds
+      zh_Hans: 种子
+      pt_BR: Seeds
+    llm_description: Seeds
+    min: 0
+    max: 4294967294
+    form: form
+  - name: aspect_radio
+    type: select
+    default: '16:9'
+    options:
+      - value: '16:9'
+        label:
+          en_US: '16:9'
+          zh_Hans: '16:9'
+          pt_BR: '16:9'
+      - value: '1:1'
+        label:
+          en_US: '1:1'
+          zh_Hans: '1:1'
+          pt_BR: '1:1'
+      - value: '21:9'
+        label:
+          en_US: '21:9'
+          zh_Hans: '21:9'
+          pt_BR: '21:9'
+      - value: '2:3'
+        label:
+          en_US: '2:3'
+          zh_Hans: '2:3'
+          pt_BR: '2:3'
+      - value: '4:5'
+        label:
+          en_US: '4:5'
+          zh_Hans: '4:5'
+          pt_BR: '4:5'
+      - value: '5:4'
+        label:
+          en_US: '5:4'
+          zh_Hans: '5:4'
+          pt_BR: '5:4'
+      - value: '9:16'
+        label:
+          en_US: '9:16'
+          zh_Hans: '9:16'
+          pt_BR: '9:16'
+      - value: '9:21'
+        label:
+          en_US: '9:21'
+          zh_Hans: '9:21'
+          pt_BR: '9:21'
+    required: false
+    label:
+      en_US: Aspect Radio
+      zh_Hans: 长宽比
+      pt_BR: Aspect Radio
+    human_description:
+      en_US: Aspect Radio
+      zh_Hans: 长宽比
+      pt_BR: Aspect Radio
+    llm_description: Aspect Radio
+    form: form