openai.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import os
  2. from collections.abc import Callable
  3. from typing import Literal
  4. import pytest
  5. # import monkeypatch
  6. from _pytest.monkeypatch import MonkeyPatch
  7. from openai.resources.audio.transcriptions import Transcriptions
  8. from openai.resources.chat import Completions as ChatCompletions
  9. from openai.resources.completions import Completions
  10. from openai.resources.embeddings import Embeddings
  11. from openai.resources.models import Models
  12. from openai.resources.moderations import Moderations
  13. from tests.integration_tests.model_runtime.__mock.openai_chat import MockChatClass
  14. from tests.integration_tests.model_runtime.__mock.openai_completion import MockCompletionsClass
  15. from tests.integration_tests.model_runtime.__mock.openai_embeddings import MockEmbeddingsClass
  16. from tests.integration_tests.model_runtime.__mock.openai_moderation import MockModerationClass
  17. from tests.integration_tests.model_runtime.__mock.openai_remote import MockModelClass
  18. from tests.integration_tests.model_runtime.__mock.openai_speech2text import MockSpeech2TextClass
  19. def mock_openai(
  20. monkeypatch: MonkeyPatch,
  21. methods: list[Literal["completion", "chat", "remote", "moderation", "speech2text", "text_embedding"]],
  22. ) -> Callable[[], None]:
  23. """
  24. mock openai module
  25. :param monkeypatch: pytest monkeypatch fixture
  26. :return: unpatch function
  27. """
  28. def unpatch() -> None:
  29. monkeypatch.undo()
  30. if "completion" in methods:
  31. monkeypatch.setattr(Completions, "create", MockCompletionsClass.completion_create)
  32. if "chat" in methods:
  33. monkeypatch.setattr(ChatCompletions, "create", MockChatClass.chat_create)
  34. if "remote" in methods:
  35. monkeypatch.setattr(Models, "list", MockModelClass.list)
  36. if "moderation" in methods:
  37. monkeypatch.setattr(Moderations, "create", MockModerationClass.moderation_create)
  38. if "speech2text" in methods:
  39. monkeypatch.setattr(Transcriptions, "create", MockSpeech2TextClass.speech2text_create)
  40. if "text_embedding" in methods:
  41. monkeypatch.setattr(Embeddings, "create", MockEmbeddingsClass.create_embeddings)
  42. return unpatch
  43. MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
  44. @pytest.fixture
  45. def setup_openai_mock(request, monkeypatch):
  46. methods = request.param if hasattr(request, "param") else []
  47. if MOCK:
  48. unpatch = mock_openai(monkeypatch, methods=methods)
  49. yield
  50. if MOCK:
  51. unpatch()