test_base_model_provider.py (2.4 KB)
  1. from unittest.mock import MagicMock
  2. import pytest
  3. from core.model_providers.error import QuotaExceededError
  4. from core.model_providers.models.entity.model_params import ModelType
  5. from models.provider import Provider, ProviderType
  6. from tests.unit_tests.model_providers.fake_model_provider import FakeModelProvider
  7. def test_get_supported_model_list(mocker):
  8. mocker.patch.object(
  9. FakeModelProvider,
  10. 'get_rules',
  11. return_value={'support_provider_types': ['custom'], 'model_flexibility': 'configurable'}
  12. )
  13. mock_provider_model = MagicMock()
  14. mock_provider_model.model_name = 'test_model'
  15. mock_query = MagicMock()
  16. mock_query.filter.return_value.order_by.return_value.all.return_value = [mock_provider_model]
  17. mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
  18. provider = FakeModelProvider(provider=Provider())
  19. result = provider.get_supported_model_list(ModelType.TEXT_GENERATION)
  20. assert result == [{'id': 'test_model', 'name': 'test_model'}]
  21. def test_check_quota_over_limit(mocker):
  22. mocker.patch.object(
  23. FakeModelProvider,
  24. 'get_rules',
  25. return_value={'support_provider_types': ['system']}
  26. )
  27. mock_query = MagicMock()
  28. mock_query.filter.return_value.first.return_value = None
  29. mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
  30. provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.SYSTEM.value))
  31. with pytest.raises(QuotaExceededError):
  32. provider.check_quota_over_limit()
  33. def test_check_quota_not_over_limit(mocker):
  34. mocker.patch.object(
  35. FakeModelProvider,
  36. 'get_rules',
  37. return_value={'support_provider_types': ['system']}
  38. )
  39. mock_query = MagicMock()
  40. mock_query.filter.return_value.first.return_value = Provider()
  41. mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
  42. provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.SYSTEM.value))
  43. assert provider.check_quota_over_limit() is None
  44. def test_check_custom_quota_over_limit(mocker):
  45. mocker.patch.object(
  46. FakeModelProvider,
  47. 'get_rules',
  48. return_value={'support_provider_types': ['custom']}
  49. )
  50. provider = FakeModelProvider(provider=Provider(provider_type=ProviderType.CUSTOM.value))
  51. assert provider.check_quota_over_limit() is None