# test_azure_openai_provider.py: unit tests for AzureOpenAIProvider.
  1. import pytest
  2. from unittest.mock import patch, MagicMock
  3. import json
  4. from core.model_providers.models.entity.model_params import ModelType
  5. from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
  6. from core.model_providers.providers.base import CredentialsValidateFailedError
  7. from models.provider import ProviderType, Provider, ProviderModel
  8. PROVIDER_NAME = 'azure_openai'
  9. MODEL_PROVIDER_CLASS = AzureOpenAIProvider
  10. VALIDATE_CREDENTIAL = {
  11. 'openai_api_base': 'https://xxxx.openai.azure.com/',
  12. 'openai_api_key': 'valid_key',
  13. 'base_model_name': 'gpt-35-turbo'
  14. }
  15. def encrypt_side_effect(tenant_id, encrypt_key):
  16. return f'encrypted_{encrypt_key}'
  17. def decrypt_side_effect(tenant_id, encrypted_key):
  18. return encrypted_key.replace('encrypted_', '')
  19. def test_is_model_credentials_valid_or_raise(mocker):
  20. mocker.patch('langchain.chat_models.base.BaseChatModel.generate', return_value=None)
  21. # assert True if credentials is valid
  22. MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
  23. model_name='test_model_name',
  24. model_type=ModelType.TEXT_GENERATION,
  25. credentials=VALIDATE_CREDENTIAL
  26. )
  27. def test_is_model_credentials_valid_or_raise_invalid():
  28. # raise CredentialsValidateFailedError if credentials is not in credentials
  29. with pytest.raises(CredentialsValidateFailedError):
  30. MODEL_PROVIDER_CLASS.is_model_credentials_valid_or_raise(
  31. model_name='test_model_name',
  32. model_type=ModelType.TEXT_GENERATION,
  33. credentials={}
  34. )
  35. @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
  36. def test_encrypt_model_credentials(mock_encrypt):
  37. openai_api_key = 'valid_key'
  38. result = MODEL_PROVIDER_CLASS.encrypt_model_credentials(
  39. tenant_id='tenant_id',
  40. model_name='test_model_name',
  41. model_type=ModelType.TEXT_GENERATION,
  42. credentials={'openai_api_key': openai_api_key}
  43. )
  44. mock_encrypt.assert_called_with('tenant_id', openai_api_key)
  45. assert result['openai_api_key'] == f'encrypted_{openai_api_key}'
  46. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  47. def test_get_model_credentials_custom(mock_decrypt, mocker):
  48. provider = Provider(
  49. id='provider_id',
  50. tenant_id='tenant_id',
  51. provider_name=PROVIDER_NAME,
  52. provider_type=ProviderType.CUSTOM.value,
  53. encrypted_config=None,
  54. is_valid=True,
  55. )
  56. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  57. encrypted_credential['openai_api_key'] = 'encrypted_' + encrypted_credential['openai_api_key']
  58. mock_query = MagicMock()
  59. mock_query.filter.return_value.first.return_value = ProviderModel(
  60. encrypted_config=json.dumps(encrypted_credential)
  61. )
  62. mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
  63. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  64. result = model_provider.get_model_credentials(
  65. model_name='test_model_name',
  66. model_type=ModelType.TEXT_GENERATION
  67. )
  68. assert result['openai_api_key'] == 'valid_key'
  69. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  70. def test_get_model_credentials_obfuscated(mock_decrypt, mocker):
  71. provider = Provider(
  72. id='provider_id',
  73. tenant_id='tenant_id',
  74. provider_name=PROVIDER_NAME,
  75. provider_type=ProviderType.CUSTOM.value,
  76. encrypted_config=None,
  77. is_valid=True,
  78. )
  79. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  80. encrypted_credential['openai_api_key'] = 'encrypted_' + encrypted_credential['openai_api_key']
  81. mock_query = MagicMock()
  82. mock_query.filter.return_value.first.return_value = ProviderModel(
  83. encrypted_config=json.dumps(encrypted_credential)
  84. )
  85. mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
  86. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  87. result = model_provider.get_model_credentials(
  88. model_name='test_model_name',
  89. model_type=ModelType.TEXT_GENERATION,
  90. obfuscated=True
  91. )
  92. middle_token = result['openai_api_key'][6:-2]
  93. assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['openai_api_key']) - 8, 0)
  94. assert all(char == '*' for char in middle_token)