fishaudio.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import os
  2. from collections.abc import Callable
  3. from typing import Literal
  4. import httpx
  5. import pytest
  6. from _pytest.monkeypatch import MonkeyPatch
  7. def mock_get(*args, **kwargs):
  8. if kwargs.get("headers", {}).get("Authorization") != "Bearer test":
  9. raise httpx.HTTPStatusError(
  10. "Invalid API key",
  11. request=httpx.Request("GET", ""),
  12. response=httpx.Response(401),
  13. )
  14. return httpx.Response(
  15. 200,
  16. json={
  17. "items": [
  18. {"title": "Model 1", "_id": "model1"},
  19. {"title": "Model 2", "_id": "model2"},
  20. ]
  21. },
  22. request=httpx.Request("GET", ""),
  23. )
  24. def mock_stream(*args, **kwargs):
  25. class MockStreamResponse:
  26. def __init__(self):
  27. self.status_code = 200
  28. def __enter__(self):
  29. return self
  30. def __exit__(self, exc_type, exc_val, exc_tb):
  31. pass
  32. def iter_bytes(self):
  33. yield b"Mocked audio data"
  34. return MockStreamResponse()
  35. def mock_fishaudio(
  36. monkeypatch: MonkeyPatch,
  37. methods: list[Literal["list-models", "tts"]],
  38. ) -> Callable[[], None]:
  39. """
  40. mock fishaudio module
  41. :param monkeypatch: pytest monkeypatch fixture
  42. :return: unpatch function
  43. """
  44. def unpatch() -> None:
  45. monkeypatch.undo()
  46. if "list-models" in methods:
  47. monkeypatch.setattr(httpx, "get", mock_get)
  48. if "tts" in methods:
  49. monkeypatch.setattr(httpx, "stream", mock_stream)
  50. return unpatch
  51. MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
  52. @pytest.fixture
  53. def setup_fishaudio_mock(request, monkeypatch):
  54. methods = request.param if hasattr(request, "param") else []
  55. if MOCK:
  56. unpatch = mock_fishaudio(monkeypatch, methods=methods)
  57. yield
  58. if MOCK:
  59. unpatch()