from collections.abc import Generator
import concurrent.futures
from functools import reduce
from io import BytesIO
from typing import Optional

from openai import OpenAI
from pydub import AudioSegment

from dify_plugin import TTSModel
from dify_plugin.errors.model import (
    CredentialsValidateFailedError,
    InvokeBadRequestError,
)

from ..common_openai import _CommonOpenAI


class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
    """
    Model class for OpenAI text-to-speech model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> bytes | Generator[bytes, None, None]:
        """
        _invoke text2speech model

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be converted to speech
        :param voice: model timbre
        :param user: unique user id
        :return: audio bytes or a generator yielding audio chunks
        """
        # fall back to the model's default voice if the requested one is unknown
        voices = self.get_tts_model_voices(model=model, credentials=credentials)
        if not voices:
            raise InvokeBadRequestError("No voices found for the model")
        if not voice or voice not in [d["value"] for d in voices]:
            voice = self._get_model_default_voice(model, credentials)

        # streaming is used for all invocations; audio is yielded chunk by chunk
        return self._tts_invoke_streaming(
            model=model,
            credentials=credentials,
            content_text=content_text,
            voice=voice,
        )

    def validate_credentials(
        self, model: str, credentials: dict, user: Optional[str] = None
    ) -> None:
        """
        validate credentials for text2speech model

        :param model: model name
        :param credentials: model credentials
        :param user: unique user id
        :return: None; raises CredentialsValidateFailedError if validation fails
        """
        try:
            self._tts_invoke(
                model=model,
                credentials=credentials,
                content_text="Hello Dify!",
                voice=self._get_model_default_voice(model, credentials),
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _tts_invoke(
        self, model: str, credentials: dict, content_text: str, voice: str
    ) -> bytes:
        """
        _tts_invoke text2speech model

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be converted to speech
        :param voice: model timbre
        :return: combined audio bytes
        """
        audio_type = self._get_model_audio_type(model, credentials)
        word_limit = self._get_model_word_limit(model, credentials) or 500
        max_workers = self._get_model_workers_limit(model, credentials)
        try:
            sentences = list(
                self._split_text_into_sentences(
                    org_text=content_text, max_length=word_limit
                )
            )
            audio_bytes_list = []

            # synthesize each sentence concurrently in a thread pool
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=max_workers
            ) as executor:
                futures = [
                    executor.submit(
                        self._process_sentence,
                        sentence=sentence,
                        model=model,
                        voice=voice,
                        credentials=credentials,
                    )
                    for sentence in sentences
                ]
                for future in futures:
                    try:
                        result = future.result()
                        if result:
                            audio_bytes_list.append(result)
                    except Exception as ex:
                        raise InvokeBadRequestError(str(ex))

            if len(audio_bytes_list) > 0:
                # decode each chunk and concatenate the segments in order
                audio_segments = [
                    AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type)
                    for audio_bytes in audio_bytes_list
                    if audio_bytes
                ]
                combined_segment = reduce(lambda x, y: x + y, audio_segments)
                buffer: BytesIO = BytesIO()
                combined_segment.export(buffer, format=audio_type)
                buffer.seek(0)
                return buffer.read()
            else:
                raise InvokeBadRequestError("No audio bytes found")
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    def _tts_invoke_streaming(
        self, model: str, credentials: dict, content_text: str, voice: str
    ) -> Generator[bytes, None, None]:
        """
        _tts_invoke_streaming text2speech model

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be converted to speech
        :param voice: model timbre
        :return: generator yielding audio chunks
        """
        try:
            # doc: https://platform.openai.com/docs/guides/text-to-speech
            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)

            voices = self.get_tts_model_voices(model=model, credentials=credentials)
            if not voices:
                raise InvokeBadRequestError("No voices found for the model")
            if not voice or voice not in [d["value"] for d in voices]:
                voice = self._get_model_default_voice(model, credentials)

            word_limit = self._get_model_word_limit(model, credentials) or 500
            if len(content_text) > word_limit:
                # split long text into sentences and stream each part in order
                sentences = list(
                    self._split_text_into_sentences(
                        content_text, max_length=word_limit
                    )
                )
                with concurrent.futures.ThreadPoolExecutor(
                    max_workers=min(3, len(sentences))
                ) as executor:
                    futures = [
                        executor.submit(
                            client.audio.speech.with_streaming_response.create,
                            model=model,
                            response_format="mp3",
                            input=sentence,
                            voice=voice,  # type: ignore
                        )
                        for sentence in sentences
                    ]
                    for future in futures:
                        with future.result() as stream_response:
                            yield from stream_response.iter_bytes(1024)
            else:
                with client.audio.speech.with_streaming_response.create(
                    model=model,
                    voice=voice,  # type: ignore
                    response_format="mp3",
                    input=content_text.strip(),
                ) as response:
                    yield from response.iter_bytes(1024)
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    def _process_sentence(
        self, sentence: str, model: str, voice: str, credentials: dict
    ):
        """
        _process_sentence openai text2speech model api

        :param model: model name
        :param credentials: model credentials
        :param voice: model timbre
        :param sentence: text content to be converted to speech
        :return: audio bytes for the sentence
        """
        # transform credentials to kwargs for the client instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = OpenAI(**credentials_kwargs)
        response = client.audio.speech.create(
            model=model, voice=voice, input=sentence.strip()
        )
        audio_bytes = response.read()
        if isinstance(audio_bytes, bytes):
            return audio_bytes