google.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. from collections.abc import Generator
  2. import google.generativeai.types.content_types as content_types
  3. import google.generativeai.types.generation_types as generation_config_types
  4. import google.generativeai.types.safety_types as safety_types
  5. import pytest
  6. from _pytest.monkeypatch import MonkeyPatch
  7. from google.ai import generativelanguage as glm
  8. from google.ai.generativelanguage_v1beta.types import content as gag_content
  9. from google.generativeai import GenerativeModel
  10. from google.generativeai.client import _ClientManager, configure
  11. from google.generativeai.types import GenerateContentResponse
  12. from google.generativeai.types.generation_types import BaseGenerateContentResponse
  13. current_api_key = ''
  14. class MockGoogleResponseClass:
  15. _done = False
  16. def __iter__(self):
  17. full_response_text = 'it\'s google!'
  18. for i in range(0, len(full_response_text) + 1, 1):
  19. if i == len(full_response_text):
  20. self._done = True
  21. yield GenerateContentResponse(
  22. done=True,
  23. iterator=None,
  24. result=glm.GenerateContentResponse({
  25. }),
  26. chunks=[]
  27. )
  28. else:
  29. yield GenerateContentResponse(
  30. done=False,
  31. iterator=None,
  32. result=glm.GenerateContentResponse({
  33. }),
  34. chunks=[]
  35. )
  36. class MockGoogleResponseCandidateClass:
  37. finish_reason = 'stop'
  38. @property
  39. def content(self) -> gag_content.Content:
  40. return gag_content.Content(
  41. parts=[
  42. gag_content.Part(text='it\'s google!')
  43. ]
  44. )
  45. class MockGoogleClass:
  46. @staticmethod
  47. def generate_content_sync() -> GenerateContentResponse:
  48. return GenerateContentResponse(
  49. done=True,
  50. iterator=None,
  51. result=glm.GenerateContentResponse({
  52. }),
  53. chunks=[]
  54. )
  55. @staticmethod
  56. def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
  57. return MockGoogleResponseClass()
  58. def generate_content(self: GenerativeModel,
  59. contents: content_types.ContentsType,
  60. *,
  61. generation_config: generation_config_types.GenerationConfigType | None = None,
  62. safety_settings: safety_types.SafetySettingOptions | None = None,
  63. stream: bool = False,
  64. **kwargs,
  65. ) -> GenerateContentResponse:
  66. global current_api_key
  67. if len(current_api_key) < 16:
  68. raise Exception('Invalid API key')
  69. if stream:
  70. return MockGoogleClass.generate_content_stream()
  71. return MockGoogleClass.generate_content_sync()
  72. @property
  73. def generative_response_text(self) -> str:
  74. return 'it\'s google!'
  75. @property
  76. def generative_response_candidates(self) -> list[MockGoogleResponseCandidateClass]:
  77. return [MockGoogleResponseCandidateClass()]
  78. def make_client(self: _ClientManager, name: str):
  79. global current_api_key
  80. if name.endswith("_async"):
  81. name = name.split("_")[0]
  82. cls = getattr(glm, name.title() + "ServiceAsyncClient")
  83. else:
  84. cls = getattr(glm, name.title() + "ServiceClient")
  85. # Attempt to configure using defaults.
  86. if not self.client_config:
  87. configure()
  88. client_options = self.client_config.get("client_options", None)
  89. if client_options:
  90. current_api_key = client_options.api_key
  91. def nop(self, *args, **kwargs):
  92. pass
  93. original_init = cls.__init__
  94. cls.__init__ = nop
  95. client: glm.GenerativeServiceClient = cls(**self.client_config)
  96. cls.__init__ = original_init
  97. if not self.default_metadata:
  98. return client
  99. @pytest.fixture
  100. def setup_google_mock(request, monkeypatch: MonkeyPatch):
  101. monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
  102. monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
  103. monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
  104. monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
  105. yield
  106. monkeypatch.undo()