google.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from typing import Generator, List
  2. import google.generativeai.types.content_types as content_types
  3. import google.generativeai.types.generation_types as generation_config_types
  4. import google.generativeai.types.safety_types as safety_types
  5. import pytest
  6. from _pytest.monkeypatch import MonkeyPatch
  7. from google.ai import generativelanguage as glm
  8. from google.generativeai import GenerativeModel
  9. from google.generativeai.client import _ClientManager, configure
  10. from google.generativeai.types import GenerateContentResponse
  11. from google.generativeai.types.generation_types import BaseGenerateContentResponse
  12. current_api_key = ''
  13. class MockGoogleResponseClass(object):
  14. _done = False
  15. def __iter__(self):
  16. full_response_text = 'it\'s google!'
  17. for i in range(0, len(full_response_text) + 1, 1):
  18. if i == len(full_response_text):
  19. self._done = True
  20. yield GenerateContentResponse(
  21. done=True,
  22. iterator=None,
  23. result=glm.GenerateContentResponse({
  24. }),
  25. chunks=[]
  26. )
  27. else:
  28. yield GenerateContentResponse(
  29. done=False,
  30. iterator=None,
  31. result=glm.GenerateContentResponse({
  32. }),
  33. chunks=[]
  34. )
  35. class MockGoogleResponseCandidateClass(object):
  36. finish_reason = 'stop'
  37. class MockGoogleClass(object):
  38. @staticmethod
  39. def generate_content_sync() -> GenerateContentResponse:
  40. return GenerateContentResponse(
  41. done=True,
  42. iterator=None,
  43. result=glm.GenerateContentResponse({
  44. }),
  45. chunks=[]
  46. )
  47. @staticmethod
  48. def generate_content_stream() -> Generator[GenerateContentResponse, None, None]:
  49. return MockGoogleResponseClass()
  50. def generate_content(self: GenerativeModel,
  51. contents: content_types.ContentsType,
  52. *,
  53. generation_config: generation_config_types.GenerationConfigType | None = None,
  54. safety_settings: safety_types.SafetySettingOptions | None = None,
  55. stream: bool = False,
  56. **kwargs,
  57. ) -> GenerateContentResponse:
  58. global current_api_key
  59. if len(current_api_key) < 16:
  60. raise Exception('Invalid API key')
  61. if stream:
  62. return MockGoogleClass.generate_content_stream()
  63. return MockGoogleClass.generate_content_sync()
  64. @property
  65. def generative_response_text(self) -> str:
  66. return 'it\'s google!'
  67. @property
  68. def generative_response_candidates(self) -> List[MockGoogleResponseCandidateClass]:
  69. return [MockGoogleResponseCandidateClass()]
  70. def make_client(self: _ClientManager, name: str):
  71. global current_api_key
  72. if name.endswith("_async"):
  73. name = name.split("_")[0]
  74. cls = getattr(glm, name.title() + "ServiceAsyncClient")
  75. else:
  76. cls = getattr(glm, name.title() + "ServiceClient")
  77. # Attempt to configure using defaults.
  78. if not self.client_config:
  79. configure()
  80. client_options = self.client_config.get("client_options", None)
  81. if client_options:
  82. current_api_key = client_options.api_key
  83. def nop(self, *args, **kwargs):
  84. pass
  85. original_init = cls.__init__
  86. cls.__init__ = nop
  87. client: glm.GenerativeServiceClient = cls(**self.client_config)
  88. cls.__init__ = original_init
  89. if not self.default_metadata:
  90. return client
  91. @pytest.fixture
  92. def setup_google_mock(request, monkeypatch: MonkeyPatch):
  93. monkeypatch.setattr(BaseGenerateContentResponse, "text", MockGoogleClass.generative_response_text)
  94. monkeypatch.setattr(BaseGenerateContentResponse, "candidates", MockGoogleClass.generative_response_candidates)
  95. monkeypatch.setattr(GenerativeModel, "generate_content", MockGoogleClass.generate_content)
  96. monkeypatch.setattr(_ClientManager, "make_client", MockGoogleClass.make_client)
  97. yield
  98. monkeypatch.undo()