google.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. from collections.abc import Generator
  2. import google.generativeai.types.generation_types as generation_config_types
  3. import pytest
  4. from _pytest.monkeypatch import MonkeyPatch
  5. from google.ai import generativelanguage as glm
  6. from google.ai.generativelanguage_v1beta.types import content as gag_content
  7. from google.generativeai import GenerativeModel
  8. from google.generativeai.client import _ClientManager, configure
  9. from google.generativeai.types import GenerateContentResponse, content_types, safety_types
  10. from google.generativeai.types.generation_types import BaseGenerateContentResponse
# API key captured from the client options by the patched ``make_client``;
# the patched ``generate_content`` raises "Invalid API key" while it is shorter
# than 16 characters.
current_api_key: str = ""
  12. class MockGoogleResponseClass:
  13. _done = False
  14. def __iter__(self):
  15. full_response_text = "it's google!"
  16. for i in range(0, len(full_response_text) + 1, 1):
  17. if i == len(full_response_text):
  18. self._done = True
  19. yield GenerateContentResponse(
  20. done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
  21. )
  22. else:
  23. yield GenerateContentResponse(
  24. done=False, iterator=None, result=glm.GenerateContentResponse({}), chunks=[]
  25. )
  26. class MockGoogleResponseCandidateClass:
  27. finish_reason = "stop"
  28. @property
  29. def content(self) -> gag_content.Content:
  30. return gag_content.Content(parts=[gag_content.Part(text="it's google!")])
  31. class MockGoogleClass:
  32. @staticmethod
  33. def generate_content_sync() -> GenerateContentResponse:
  34. return GenerateContentResponse(done=True, iterator=None, result=glm.GenerateContentResponse({}), chunks=[])
  35. @staticmethod
  36. def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
  37. return MockGoogleResponseClass()
  38. def generate_content(
  39. self: GenerativeModel,
  40. contents: content_types.ContentsType,
  41. *,
  42. generation_config: generation_config_types.GenerationConfigType | None = None,
  43. safety_settings: safety_types.SafetySettingOptions | None = None,
  44. stream: bool = False,
  45. **kwargs,
  46. ) -> GenerateContentResponse:
  47. global current_api_key
  48. if len(current_api_key) < 16:
  49. raise Exception("Invalid API key")
  50. if stream:
  51. return MockGoogleClass.generate_content_stream()
  52. return MockGoogleClass.generate_content_sync()
  53. @property
  54. def generative_response_text(self) -> str:
  55. return "it's google!"
  56. @property
  57. def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
  58. return [MockGoogleResponseCandidateClass()]
  59. def make_client(self: _ClientManager, name: str):
  60. global current_api_key
  61. if name.endswith("_async"):
  62. name = name.split("_")[0]
  63. cls = getattr(glm, name.title() + "ServiceAsyncClient")
  64. else:
  65. cls = getattr(glm, name.title() + "ServiceClient")
  66. # Attempt to configure using defaults.
  67. if not self.client_config:
  68. configure()
  69. client_options = self.client_config.get("client_options", None)
  70. if client_options:
  71. current_api_key = client_options.api_key
  72. def nop(self, *args, **kwargs):
  73. pass
  74. original_init = cls.__init__
  75. cls.__init__ = nop
  76. client: glm.GenerativeServiceClient = cls(**self.client_config)
  77. cls.__init__ = original_init
  78. if not self.default_metadata:
  79. return client
  80. @pytest.fixture
  81. def setup_google_mock(request, monkeypatch: MonkeyPatch):
  82. monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
  83. monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
  84. monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
  85. monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
  86. yield
  87. monkeypatch.undo()