Browse Source

fix: avoid nil in ToolProviderDeclaration

Yeuoly 9 months ago
parent
commit
67e67fdd2d

+ 211 - 0
cmd/commandline/init/templates/python/tts.py

@@ -0,0 +1,211 @@
+from collections.abc import Generator
+import concurrent.futures
+from functools import reduce
+from io import BytesIO
+from typing import Optional
+
+from openai import OpenAI
+from pydub import AudioSegment
+
+from dify_plugin import TTSModel
+from dify_plugin.errors.model import (
+    CredentialsValidateFailedError,
+    InvokeBadRequestError,
+)
+from ..common_openai import _CommonOpenAI
+
+
+class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
+    """
+    Model class for OpenAI Speech to text model.
+    """
+
+    def _invoke(
+        self,
+        model: str,
+        credentials: dict,
+        content_text: str,
+        voice: str,
+        user: Optional[str] = None,
+    ) -> bytes | Generator[bytes, None, None]:
+        """
+        _invoke text2speech model
+
+        :param model: model name
+        :param tenant_id: user tenant id
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: model timbre
+        :param user: unique user id
+        :return: text translated to audio file
+        """
+
+        voices = self.get_tts_model_voices(model=model, credentials=credentials)
+        if not voices:
+            raise InvokeBadRequestError("No voices found for the model")
+
+        if not voice or voice not in [d["value"] for d in voices]:
+            voice = self._get_model_default_voice(model, credentials)
+
+        # if streaming:
+        return self._tts_invoke_streaming(
+            model=model, credentials=credentials, content_text=content_text, voice=voice
+        )
+
+    def validate_credentials(
+        self, model: str, credentials: dict, user: Optional[str] = None
+    ) -> None:
+        """
+        validate credentials text2speech model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param user: unique user id
+        :return: text translated to audio file
+        """
+        try:
+            self._tts_invoke(
+                model=model,
+                credentials=credentials,
+                content_text="Hello Dify!",
+                voice=self._get_model_default_voice(model, credentials),
+            )
+        except Exception as ex:
+            raise CredentialsValidateFailedError(str(ex))
+
+    def _tts_invoke(
+        self, model: str, credentials: dict, content_text: str, voice: str
+    ) -> bytes:
+        """
+        _tts_invoke text2speech model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: model timbre
+        :return: text translated to audio file
+        """
+        audio_type = self._get_model_audio_type(model, credentials)
+        word_limit = self._get_model_word_limit(model, credentials) or 500
+        max_workers = self._get_model_workers_limit(model, credentials)
+        try:
+            sentences = list(
+                self._split_text_into_sentences(
+                    org_text=content_text, max_length=word_limit
+                )
+            )
+            audio_bytes_list = []
+
+            # Create a thread pool and map the function to the list of sentences
+            with concurrent.futures.ThreadPoolExecutor(
+                max_workers=max_workers
+            ) as executor:
+                futures = [
+                    executor.submit(
+                        self._process_sentence,
+                        sentence=sentence,
+                        model=model,
+                        voice=voice,
+                        credentials=credentials,
+                    )
+                    for sentence in sentences
+                ]
+                for future in futures:
+                    try:
+                        if future.result():
+                            audio_bytes_list.append(future.result())
+                    except Exception as ex:
+                        raise InvokeBadRequestError(str(ex))
+
+            if len(audio_bytes_list) > 0:
+                audio_segments = [
+                    AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type)
+                    for audio_bytes in audio_bytes_list
+                    if audio_bytes
+                ]
+                combined_segment = reduce(lambda x, y: x + y, audio_segments)
+                buffer: BytesIO = BytesIO()
+                combined_segment.export(buffer, format=audio_type)
+                buffer.seek(0)
+
+                return buffer.read()
+            else:
+                raise InvokeBadRequestError("No audio bytes found")
+        except Exception as ex:
+            raise InvokeBadRequestError(str(ex))
+
+    def _tts_invoke_streaming(
+        self, model: str, credentials: dict, content_text: str, voice: str
+    ) -> Generator[bytes, None, None]:
+        """
+        _tts_invoke_streaming text2speech model
+
+        :param model: model name
+        :param credentials: model credentials
+        :param content_text: text content to be translated
+        :param voice: model timbre
+        :return: text translated to audio file
+        """
+        try:
+            # doc: https://platform.openai.com/docs/guides/text-to-speech
+            credentials_kwargs = self._to_credential_kwargs(credentials)
+            client = OpenAI(**credentials_kwargs)
+
+            voices = self.get_tts_model_voices(model=model, credentials=credentials)
+            if not voices:
+                raise InvokeBadRequestError("No voices found for the model")
+
+            if not voice or voice not in voices:
+                voice = self._get_model_default_voice(model, credentials)
+
+            word_limit = self._get_model_word_limit(model, credentials) or 500
+            if len(content_text) > word_limit:
+                sentences = self._split_text_into_sentences(
+                    content_text, max_length=word_limit
+                )
+                executor = concurrent.futures.ThreadPoolExecutor(
+                    max_workers=min(3, len(sentences))
+                )
+                futures = [
+                    executor.submit(
+                        client.audio.speech.with_streaming_response.create,
+                        model=model,
+                        response_format="mp3",
+                        input=sentences[i],
+                        voice=voice,  # type: ignore
+                    )
+                    for i in range(len(sentences))
+                ]
+                for index, future in enumerate(futures):
+                    yield from future.result().__enter__().iter_bytes(1024)
+
+            else:
+                response = client.audio.speech.with_streaming_response.create(
+                    model=model,
+                    voice=voice,  # type: ignore
+                    response_format="mp3",
+                    input=content_text.strip(),
+                )
+
+                yield from response.__enter__().iter_bytes(1024)
+        except Exception as ex:
+            raise InvokeBadRequestError(str(ex))
+
+    def _process_sentence(self, sentence: str, model: str, voice, credentials: dict):
+        """
+        _tts_invoke openai text2speech model api
+
+        :param model: model name
+        :param credentials: model credentials
+        :param voice: model timbre
+        :param sentence: text content to be translated
+        :return: text translated to audio file
+        """
+        # transform credentials to kwargs for model instance
+        credentials_kwargs = self._to_credential_kwargs(credentials)
+        client = OpenAI(**credentials_kwargs)
+        response = client.audio.speech.create(
+            model=model, voice=voice, input=sentence.strip()
+        )
+        if isinstance(response.read(), bytes):
+            return response.read()

+ 28 - 0
internal/types/entities/plugin_entities/tool_declaration.go

@@ -130,6 +130,18 @@ type ToolProviderDeclaration struct {
 	ToolFiles         []string             `json:"-" yaml:"-"`
 }
 
+func (t *ToolProviderDeclaration) MarshalJSON() ([]byte, error) {
+	type alias ToolProviderDeclaration
+	p := alias(*t)
+	if p.CredentialsSchema == nil {
+		p.CredentialsSchema = []ProviderConfig{}
+	}
+	if p.Tools == nil {
+		p.Tools = []ToolDeclaration{}
+	}
+	return json.Marshal(p)
+}
+
 func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
 	type alias struct {
 		Identity               ToolProviderIdentity `yaml:"identity"`
@@ -196,6 +208,14 @@ func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
 		}
 	}
 
+	if t.CredentialsSchema == nil {
+		t.CredentialsSchema = []ProviderConfig{}
+	}
+
+	if t.Tools == nil {
+		t.Tools = []ToolDeclaration{}
+	}
+
 	return nil
 }
 
@@ -250,6 +270,14 @@ func (t *ToolProviderDeclaration) UnmarshalJSON(data []byte) error {
 		}
 	}
 
+	if t.CredentialsSchema == nil {
+		t.CredentialsSchema = []ProviderConfig{}
+	}
+
+	if t.Tools == nil {
+		t.Tools = []ToolDeclaration{}
+	}
+
 	return nil
 }