tts.py

from collections.abc import Generator
import concurrent.futures
from functools import reduce
from io import BytesIO
from typing import Optional

from openai import OpenAI
from pydub import AudioSegment

from dify_plugin import TTSModel
from dify_plugin.errors.model import (
    CredentialsValidateFailedError,
    InvokeBadRequestError,
)

from ..common_openai import _CommonOpenAI


class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel):
    """
    Model class for the OpenAI text-to-speech model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        content_text: str,
        voice: str,
        user: Optional[str] = None,
    ) -> bytes | Generator[bytes, None, None]:
        """
        Invoke the text2speech model.

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be converted to speech
        :param voice: model timbre
        :param user: unique user id
        :return: streamed audio chunks
        """
        voices = self.get_tts_model_voices(model=model, credentials=credentials)
        if not voices:
            raise InvokeBadRequestError("No voices found for the model")
        if not voice or voice not in [d["value"] for d in voices]:
            voice = self._get_model_default_voice(model, credentials)
        # Audio is always returned as a stream.
        return self._tts_invoke_streaming(
            model=model, credentials=credentials, content_text=content_text, voice=voice
        )
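
    # A hedged note on the voice metadata shape assumed above: the membership
    # check reads d["value"], so get_tts_model_voices is expected to yield
    # entries shaped roughly like
    #     [{"name": "Alloy", "value": "alloy"}, ...]
    # The exact keys are defined by the dify_plugin TTSModel base class and
    # may differ across versions.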

    def validate_credentials(
        self, model: str, credentials: dict, user: Optional[str] = None
    ) -> None:
        """
        Validate credentials for the text2speech model.

        :param model: model name
        :param credentials: model credentials
        :param user: unique user id
        :raises CredentialsValidateFailedError: if the credentials are invalid
        """
        try:
            self._tts_invoke(
                model=model,
                credentials=credentials,
                content_text="Hello Dify!",
                voice=self._get_model_default_voice(model, credentials),
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _tts_invoke(
        self, model: str, credentials: dict, content_text: str, voice: str
    ) -> bytes:
        """
        Invoke the text2speech model and return the combined audio.

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be converted to speech
        :param voice: model timbre
        :return: audio bytes
        """
        audio_type = self._get_model_audio_type(model, credentials)
        word_limit = self._get_model_word_limit(model, credentials) or 500
        max_workers = self._get_model_workers_limit(model, credentials)
        try:
            sentences = list(
                self._split_text_into_sentences(
                    org_text=content_text, max_length=word_limit
                )
            )
            audio_bytes_list = []
            # Create a thread pool and synthesize the sentences concurrently.
            with concurrent.futures.ThreadPoolExecutor(
                max_workers=max_workers
            ) as executor:
                futures = [
                    executor.submit(
                        self._process_sentence,
                        sentence=sentence,
                        model=model,
                        voice=voice,
                        credentials=credentials,
                    )
                    for sentence in sentences
                ]
                for future in futures:
                    try:
                        # Read the result once; calling future.result() twice
                        # only worked by accident in the original code.
                        result = future.result()
                        if result:
                            audio_bytes_list.append(result)
                    except Exception as ex:
                        raise InvokeBadRequestError(str(ex))
            if audio_bytes_list:
                audio_segments = [
                    AudioSegment.from_file(BytesIO(audio_bytes), format=audio_type)
                    for audio_bytes in audio_bytes_list
                    if audio_bytes
                ]
                combined_segment = reduce(lambda x, y: x + y, audio_segments)
                buffer: BytesIO = BytesIO()
                combined_segment.export(buffer, format=audio_type)
                buffer.seek(0)
                return buffer.read()
            else:
                raise InvokeBadRequestError("No audio bytes found")
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))
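
    # Note: AudioSegment.from_file and export above go through pydub, which
    # shells out to ffmpeg (or libav) for compressed formats such as mp3, so
    # that binary must be available on PATH for the concatenation to work.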

    def _tts_invoke_streaming(
        self, model: str, credentials: dict, content_text: str, voice: str
    ) -> Generator[bytes, None, None]:
        """
        Invoke the text2speech model and stream the audio.

        :param model: model name
        :param credentials: model credentials
        :param content_text: text content to be converted to speech
        :param voice: model timbre
        :return: generator of audio chunks
        """
        try:
            # doc: https://platform.openai.com/docs/guides/text-to-speech
            credentials_kwargs = self._to_credential_kwargs(credentials)
            client = OpenAI(**credentials_kwargs)
            voices = self.get_tts_model_voices(model=model, credentials=credentials)
            if not voices:
                raise InvokeBadRequestError("No voices found for the model")
            # Compare against the voice values, not the raw metadata dicts;
            # comparing against the dicts always fell back to the default.
            if not voice or voice not in [d["value"] for d in voices]:
                voice = self._get_model_default_voice(model, credentials)
            word_limit = self._get_model_word_limit(model, credentials) or 500
            if len(content_text) > word_limit:
                # Materialize the sentences so len() works below.
                sentences = list(
                    self._split_text_into_sentences(
                        org_text=content_text, max_length=word_limit
                    )
                )
                with concurrent.futures.ThreadPoolExecutor(
                    max_workers=min(3, len(sentences))
                ) as executor:
                    futures = [
                        executor.submit(
                            client.audio.speech.with_streaming_response.create,
                            model=model,
                            response_format="mp3",
                            input=sentence,
                            voice=voice,  # type: ignore
                        )
                        for sentence in sentences
                    ]
                    for future in futures:
                        # Enter the streaming context so the connection is
                        # closed once its chunks have been consumed.
                        with future.result() as response:
                            yield from response.iter_bytes(1024)
            else:
                with client.audio.speech.with_streaming_response.create(
                    model=model,
                    voice=voice,  # type: ignore
                    response_format="mp3",
                    input=content_text.strip(),
                ) as response:
                    yield from response.iter_bytes(1024)
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))

    def _process_sentence(
        self, sentence: str, model: str, voice: str, credentials: dict
    ) -> Optional[bytes]:
        """
        Invoke the OpenAI text2speech API for a single sentence.

        :param sentence: text content to be converted to speech
        :param model: model name
        :param voice: model timbre
        :param credentials: model credentials
        :return: audio bytes for the sentence
        """
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = OpenAI(**credentials_kwargs)
        response = client.audio.speech.create(
            model=model, voice=voice, input=sentence.strip()
        )
        # Read the body once instead of twice, as the original code did.
        data = response.read()
        if isinstance(data, bytes):
            return data
        return None
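

# A minimal usage sketch, not part of the plugin runtime. It assumes the
# class can be constructed directly (the real Dify plugin host may construct
# it differently) and that the credentials dict accepted by
# _to_credential_kwargs carries the API key under "openai_api_key"; that key
# name is a hypothetical stand-in for whatever _CommonOpenAI actually expects.
if __name__ == "__main__":
    model = OpenAIText2SpeechModel()
    chunks = model._invoke(
        model="tts-1",
        credentials={"openai_api_key": "sk-..."},  # hypothetical key name
        content_text="Hello from the TTS plugin!",
        voice="alloy",
    )
    # _invoke always returns a generator of mp3 chunks; write them to a file.
    with open("out.mp3", "wb") as f:
        for chunk in chunks:
            f.write(chunk)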