1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import os
- from collections.abc import Callable
- from typing import Literal
- import httpx
- import pytest
- from _pytest.monkeypatch import MonkeyPatch
def mock_get(*args, **kwargs):
    """
    Stand-in for ``httpx.get`` that serves a fixed model listing.

    Positional arguments are accepted and ignored; only the ``headers``
    keyword argument is inspected.

    :param args: ignored
    :param kwargs: only ``headers`` is read; it may be missing or ``None``
    :return: a 200 ``httpx.Response`` whose JSON body has an ``items`` list
        with two entries
    :raises httpx.HTTPStatusError: with a 401 response attached when the
        ``Authorization`` header is not exactly ``"Bearer test"``
    """
    # ``headers`` can be passed explicitly as None; a plain
    # ``kwargs.get("headers", {})`` would then hand back None and the
    # following ``.get`` would fail with AttributeError instead of the 401.
    headers = kwargs.get("headers") or {}
    if headers.get("Authorization") != "Bearer test":
        raise httpx.HTTPStatusError(
            "Invalid API key",
            request=httpx.Request("GET", ""),
            response=httpx.Response(401),
        )

    return httpx.Response(
        200,
        json={
            "items": [
                {"title": "Model 1", "_id": "model1"},
                {"title": "Model 2", "_id": "model2"},
            ]
        },
        request=httpx.Request("GET", ""),
    )
def mock_stream(*args, **kwargs):
    """
    Stand-in for ``httpx.stream`` that yields a fixed audio payload.

    All arguments are accepted and ignored.

    :return: a new object usable as a context manager, exposing
        ``status_code`` (200) and ``iter_bytes()``
    """

    class MockStreamResponse:
        # Bare-bones response: only the attributes set below are provided.

        def __init__(self):
            self.status_code = 200

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # Returning None lets any exception from the ``with`` body propagate.
            return None

        def iter_bytes(self):
            # Single-chunk stream.
            yield from (b"Mocked audio data",)

    response = MockStreamResponse()
    return response
def mock_fishaudio(
    monkeypatch: MonkeyPatch,
    methods: list[Literal["list-models", "tts"]],
) -> Callable[[], None]:
    """
    Patch the httpx entry points selected by ``methods``.

    :param monkeypatch: pytest monkeypatch fixture used to apply the patches
    :param methods: which calls to mock; ``"list-models"`` replaces
        ``httpx.get`` with ``mock_get`` and ``"tts"`` replaces
        ``httpx.stream`` with ``mock_stream``
    :return: zero-argument function that calls ``monkeypatch.undo()``
    """
    # Method name -> (httpx attribute to replace, replacement callable).
    replacements = {
        "list-models": ("get", mock_get),
        "tts": ("stream", mock_stream),
    }
    for method_name, (attribute, replacement) in replacements.items():
        if method_name in methods:
            monkeypatch.setattr(httpx, attribute, replacement)

    def unpatch() -> None:
        # NOTE(review): undo() is invoked on the monkeypatch object as a whole,
        # so it presumably also reverts patches made through it elsewhere -
        # confirm that is acceptable for callers.
        monkeypatch.undo()

    return unpatch
# Mocking is opt-in: enabled only when MOCK_SWITCH is "true" (case-insensitive).
MOCK = os.environ.get("MOCK_SWITCH", "false").lower() == "true"
@pytest.fixture
def setup_fishaudio_mock(request, monkeypatch):
    """
    Install the fishaudio httpx mocks around a test when ``MOCK`` is true.

    The methods to mock are read from ``request.param``; when the request has
    no ``param`` attribute an empty list is used. When ``MOCK`` is false the
    fixture yields without patching anything.
    """
    requested_methods = getattr(request, "param", [])

    if not MOCK:
        yield
        return

    restore = mock_fishaudio(monkeypatch, methods=requested_methods)
    yield
    restore()
|