test_tongyi_provider.py 3.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import pytest
  2. from unittest.mock import patch
  3. import json
  4. from langchain.schema import LLMResult, Generation
  5. from core.model_providers.providers.base import CredentialsValidateFailedError
  6. from core.model_providers.providers.minimax_provider import MinimaxProvider
  7. from core.model_providers.providers.tongyi_provider import TongyiProvider
  8. from models.provider import ProviderType, Provider
  9. PROVIDER_NAME = 'tongyi'
  10. MODEL_PROVIDER_CLASS = TongyiProvider
  11. VALIDATE_CREDENTIAL = {
  12. 'dashscope_api_key': 'valid_key'
  13. }
  14. def encrypt_side_effect(tenant_id, encrypt_key):
  15. return f'encrypted_{encrypt_key}'
  16. def decrypt_side_effect(tenant_id, encrypted_key):
  17. return encrypted_key.replace('encrypted_', '')
  18. def test_is_provider_credentials_valid_or_raise_valid(mocker):
  19. mocker.patch('core.third_party.langchain.llms.tongyi_llm.EnhanceTongyi._generate', return_value=LLMResult(generations=[[Generation(text="abc")]]))
  20. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
  21. def test_is_provider_credentials_valid_or_raise_invalid():
  22. # raise CredentialsValidateFailedError if api_key is not in credentials
  23. with pytest.raises(CredentialsValidateFailedError):
  24. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
  25. credential = VALIDATE_CREDENTIAL.copy()
  26. credential['dashscope_api_key'] = 'invalid_key'
  27. # raise CredentialsValidateFailedError if api_key is invalid
  28. with pytest.raises(CredentialsValidateFailedError):
  29. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
  30. @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
  31. def test_encrypt_credentials(mock_encrypt):
  32. api_key = 'valid_key'
  33. result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
  34. mock_encrypt.assert_called_with('tenant_id', api_key)
  35. assert result['dashscope_api_key'] == f'encrypted_{api_key}'
  36. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  37. def test_get_credentials_custom(mock_decrypt):
  38. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  39. encrypted_credential['dashscope_api_key'] = 'encrypted_' + encrypted_credential['dashscope_api_key']
  40. provider = Provider(
  41. id='provider_id',
  42. tenant_id='tenant_id',
  43. provider_name=PROVIDER_NAME,
  44. provider_type=ProviderType.CUSTOM.value,
  45. encrypted_config=json.dumps(encrypted_credential),
  46. is_valid=True,
  47. )
  48. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  49. result = model_provider.get_provider_credentials()
  50. assert result['dashscope_api_key'] == 'valid_key'
  51. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  52. def test_get_credentials_obfuscated(mock_decrypt):
  53. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  54. encrypted_credential['dashscope_api_key'] = 'encrypted_' + encrypted_credential['dashscope_api_key']
  55. provider = Provider(
  56. id='provider_id',
  57. tenant_id='tenant_id',
  58. provider_name=PROVIDER_NAME,
  59. provider_type=ProviderType.CUSTOM.value,
  60. encrypted_config=json.dumps(encrypted_credential),
  61. is_valid=True,
  62. )
  63. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  64. result = model_provider.get_provider_credentials(obfuscated=True)
  65. middle_token = result['dashscope_api_key'][6:-2]
  66. assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['dashscope_api_key']) - 8, 0)
  67. assert all(char == '*' for char in middle_token)