Ver código fonte

feat: add the audio tool (#10695)

非法操作 6 meses atrás
pai
commit
15f341b655

Diferenças do arquivo suprimidas por serem muito extensas
+ 3 - 0
api/core/tools/provider/builtin/audio/_assets/icon.svg


+ 6 - 0
api/core/tools/provider/builtin/audio/audio.py

@@ -0,0 +1,6 @@
+from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController
+
+
+class AudioToolProvider(BuiltinToolProviderController):
+    def _validate_credentials(self, credentials: dict) -> None:
+        pass

+ 11 - 0
api/core/tools/provider/builtin/audio/audio.yaml

@@ -0,0 +1,11 @@
+identity:
+  author: hjlarry
+  name: audio
+  label:
+    en_US: Audio
+  description:
+    en_US: A tool for tts and asr.
+    zh_Hans: 一个用于文本转语音和语音转文本的工具。
+  icon: icon.svg
+  tags:
+    - utilities

+ 70 - 0
api/core/tools/provider/builtin/audio/tools/asr.py

@@ -0,0 +1,70 @@
+import io
+from typing import Any
+
+from core.file.enums import FileType
+from core.file.file_manager import download
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelType
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
+from core.tools.tool.builtin_tool import BuiltinTool
+from services.model_provider_service import ModelProviderService
+
+
+class ASRTool(BuiltinTool):
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
+        file = tool_parameters.get("audio_file")
+        if file.type != FileType.AUDIO:
+            return [self.create_text_message("not a valid audio file")]
+        audio_binary = io.BytesIO(download(file))
+        audio_binary.name = "temp.mp3"
+        provider, model = tool_parameters.get("model").split("#")
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=self.runtime.tenant_id,
+            provider=provider,
+            model_type=ModelType.SPEECH2TEXT,
+            model=model,
+        )
+        text = model_instance.invoke_speech2text(
+            file=audio_binary,
+            user=user_id,
+        )
+        return [self.create_text_message(text)]
+
+    def get_available_models(self) -> list[tuple[str, str]]:
+        model_provider_service = ModelProviderService()
+        models = model_provider_service.get_models_by_model_type(
+            tenant_id=self.runtime.tenant_id, model_type="speech2text"
+        )
+        items = []
+        for provider_model in models:
+            provider = provider_model.provider
+            for model in provider_model.models:
+                items.append((provider, model.model))
+        return items
+
+    def get_runtime_parameters(self) -> list[ToolParameter]:
+        parameters = []
+
+        options = []
+        for provider, model in self.get_available_models():
+            option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
+            options.append(option)
+
+        parameters.append(
+            ToolParameter(
+                name="model",
+                label=I18nObject(en_US="Model", zh_Hans="Model"),
+                human_description=I18nObject(
+                    en_US="All available ASR models",
+                    zh_Hans="所有可用的 ASR 模型",
+                ),
+                type=ToolParameter.ToolParameterType.SELECT,
+                form=ToolParameter.ToolParameterForm.FORM,
+                required=True,
+                default=options[0].value,
+                options=options,
+            )
+        )
+        return parameters

+ 22 - 0
api/core/tools/provider/builtin/audio/tools/asr.yaml

@@ -0,0 +1,22 @@
+identity:
+  name: asr
+  author: hjlarry
+  label:
+    en_US: Speech To Text
+description:
+  human:
+    en_US: Convert audio file to text.
+    zh_Hans: 将音频文件转换为文本。
+  llm: Convert audio file to text.
+parameters:
+  - name: audio_file
+    type: file
+    required: true
+    label:
+      en_US: Audio File
+      zh_Hans: 音频文件
+    human_description:
+      en_US: The audio file to be converted.
+      zh_Hans: 要转换的音频文件。
+    llm_description: The audio file to be converted.
+    form: llm

+ 90 - 0
api/core/tools/provider/builtin/audio/tools/tts.py

@@ -0,0 +1,90 @@
+import io
+from typing import Any
+
+from core.model_manager import ModelManager
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolParameterOption
+from core.tools.tool.builtin_tool import BuiltinTool
+from services.model_provider_service import ModelProviderService
+
+
+class TTSTool(BuiltinTool):
+    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
+        provider, model = tool_parameters.get("model").split("#")
+        voice = tool_parameters.get(f"voice#{provider}#{model}")
+        model_manager = ModelManager()
+        model_instance = model_manager.get_model_instance(
+            tenant_id=self.runtime.tenant_id,
+            provider=provider,
+            model_type=ModelType.TTS,
+            model=model,
+        )
+        tts = model_instance.invoke_tts(
+            content_text=tool_parameters.get("text"),
+            user=user_id,
+            tenant_id=self.runtime.tenant_id,
+            voice=voice,
+        )
+        buffer = io.BytesIO()
+        for chunk in tts:
+            buffer.write(chunk)
+
+        wav_bytes = buffer.getvalue()
+        return [
+            self.create_text_message("Audio generated successfully"),
+            self.create_blob_message(
+                blob=wav_bytes,
+                meta={"mime_type": "audio/x-wav"},
+                save_as=self.VariableKey.AUDIO,
+            ),
+        ]
+
+    def get_available_models(self) -> list[tuple[str, str, list[Any]]]:
+        model_provider_service = ModelProviderService()
+        models = model_provider_service.get_models_by_model_type(tenant_id=self.runtime.tenant_id, model_type="tts")
+        items = []
+        for provider_model in models:
+            provider = provider_model.provider
+            for model in provider_model.models:
+                voices = model.model_properties.get(ModelPropertyKey.VOICES, [])
+                items.append((provider, model.model, voices))
+        return items
+
+    def get_runtime_parameters(self) -> list[ToolParameter]:
+        parameters = []
+
+        options = []
+        for provider, model, voices in self.get_available_models():
+            option = ToolParameterOption(value=f"{provider}#{model}", label=I18nObject(en_US=f"{model}({provider})"))
+            options.append(option)
+            parameters.append(
+                ToolParameter(
+                    name=f"voice#{provider}#{model}",
+                    label=I18nObject(en_US=f"Voice of {model}({provider})"),
+                    type=ToolParameter.ToolParameterType.SELECT,
+                    form=ToolParameter.ToolParameterForm.FORM,
+                    options=[
+                        ToolParameterOption(value=voice.get("mode"), label=I18nObject(en_US=voice.get("name")))
+                        for voice in voices
+                    ],
+                )
+            )
+
+        parameters.insert(
+            0,
+            ToolParameter(
+                name="model",
+                label=I18nObject(en_US="Model", zh_Hans="Model"),
+                human_description=I18nObject(
+                    en_US="All available TTS models",
+                    zh_Hans="所有可用的 TTS 模型",
+                ),
+                type=ToolParameter.ToolParameterType.SELECT,
+                form=ToolParameter.ToolParameterForm.FORM,
+                required=True,
+                default=options[0].value,
+                options=options,
+            ),
+        )
+        return parameters

+ 22 - 0
api/core/tools/provider/builtin/audio/tools/tts.yaml

@@ -0,0 +1,22 @@
+identity:
+  name: tts
+  author: hjlarry
+  label:
+    en_US: Text To Speech
+description:
+  human:
+    en_US: Convert text to audio file.
+    zh_Hans: 将文本转换为音频文件。
+  llm: Convert text to audio file.
+parameters:
+  - name: text
+    type: string
+    required: true
+    label:
+      en_US: Text
+      zh_Hans: 文本
+    human_description:
+      en_US: The text to be converted.
+      zh_Hans: 要转换的文本。
+    llm_description: The text to be converted.
+    form: llm