test_spark_provider.py 3.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import pytest
  2. from unittest.mock import patch
  3. import json
  4. from langchain.schema import LLMResult, Generation, AIMessage, ChatResult, ChatGeneration
  5. from core.model_providers.providers.base import CredentialsValidateFailedError
  6. from core.model_providers.providers.spark_provider import SparkProvider
  7. from models.provider import ProviderType, Provider
  8. PROVIDER_NAME = 'spark'
  9. MODEL_PROVIDER_CLASS = SparkProvider
  10. VALIDATE_CREDENTIAL = {
  11. 'app_id': 'valid_app_id',
  12. 'api_key': 'valid_key',
  13. 'api_secret': 'valid_secret'
  14. }
  15. def encrypt_side_effect(tenant_id, encrypt_key):
  16. return f'encrypted_{encrypt_key}'
  17. def decrypt_side_effect(tenant_id, encrypted_key):
  18. return encrypted_key.replace('encrypted_', '')
  19. def test_is_provider_credentials_valid_or_raise_valid(mocker):
  20. mocker.patch('core.third_party.langchain.llms.spark.ChatSpark._generate',
  21. return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content="abc"))]))
  22. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
  23. def test_is_provider_credentials_valid_or_raise_invalid():
  24. # raise CredentialsValidateFailedError if api_key is not in credentials
  25. with pytest.raises(CredentialsValidateFailedError):
  26. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
  27. credential = VALIDATE_CREDENTIAL.copy()
  28. del credential['api_key']
  29. # raise CredentialsValidateFailedError if api_key is invalid
  30. with pytest.raises(CredentialsValidateFailedError):
  31. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
  32. @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
  33. def test_encrypt_credentials(mock_encrypt):
  34. result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
  35. assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
  36. assert result['api_secret'] == f'encrypted_{VALIDATE_CREDENTIAL["api_secret"]}'
  37. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  38. def test_get_credentials_custom(mock_decrypt):
  39. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  40. encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
  41. encrypted_credential['api_secret'] = 'encrypted_' + encrypted_credential['api_secret']
  42. provider = Provider(
  43. id='provider_id',
  44. tenant_id='tenant_id',
  45. provider_name=PROVIDER_NAME,
  46. provider_type=ProviderType.CUSTOM.value,
  47. encrypted_config=json.dumps(encrypted_credential),
  48. is_valid=True,
  49. )
  50. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  51. result = model_provider.get_provider_credentials()
  52. assert result['api_key'] == 'valid_key'
  53. assert result['api_secret'] == 'valid_secret'
  54. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  55. def test_get_credentials_obfuscated(mock_decrypt):
  56. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  57. encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
  58. encrypted_credential['api_secret'] = 'encrypted_' + encrypted_credential['api_secret']
  59. provider = Provider(
  60. id='provider_id',
  61. tenant_id='tenant_id',
  62. provider_name=PROVIDER_NAME,
  63. provider_type=ProviderType.CUSTOM.value,
  64. encrypted_config=json.dumps(encrypted_credential),
  65. is_valid=True,
  66. )
  67. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  68. result = model_provider.get_provider_credentials(obfuscated=True)
  69. middle_token = result['api_key'][6:-2]
  70. middle_secret = result['api_secret'][6:-2]
  71. assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
  72. assert len(middle_secret) == max(len(VALIDATE_CREDENTIAL['api_secret']) - 8, 0)
  73. assert all(char == '*' for char in middle_token)
  74. assert all(char == '*' for char in middle_secret)