# test_chatglm_provider.py: tests for ChatGLMProvider
  1. import pytest
  2. from unittest.mock import patch
  3. import json
  4. from langchain.schema import LLMResult, Generation, AIMessage, ChatResult, ChatGeneration
  5. from core.model_providers.providers.base import CredentialsValidateFailedError
  6. from core.model_providers.providers.chatglm_provider import ChatGLMProvider
  7. from core.model_providers.providers.spark_provider import SparkProvider
  8. from models.provider import ProviderType, Provider
  9. PROVIDER_NAME = 'chatglm'
  10. MODEL_PROVIDER_CLASS = ChatGLMProvider
  11. VALIDATE_CREDENTIAL = {
  12. 'api_base': 'valid_api_base',
  13. }
  14. def encrypt_side_effect(tenant_id, encrypt_key):
  15. return f'encrypted_{encrypt_key}'
  16. def decrypt_side_effect(tenant_id, encrypted_key):
  17. return encrypted_key.replace('encrypted_', '')
  18. def test_is_provider_credentials_valid_or_raise_valid(mocker):
  19. mocker.patch('langchain.llms.chatglm.ChatGLM._call',
  20. return_value="abc")
  21. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
  22. def test_is_provider_credentials_valid_or_raise_invalid():
  23. # raise CredentialsValidateFailedError if api_key is not in credentials
  24. with pytest.raises(CredentialsValidateFailedError):
  25. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
  26. credential = VALIDATE_CREDENTIAL.copy()
  27. credential['api_base'] = 'invalid_api_base'
  28. # raise CredentialsValidateFailedError if api_key is invalid
  29. with pytest.raises(CredentialsValidateFailedError):
  30. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
  31. @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
  32. def test_encrypt_credentials(mock_encrypt):
  33. result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
  34. assert result['api_base'] == f'encrypted_{VALIDATE_CREDENTIAL["api_base"]}'
  35. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  36. def test_get_credentials_custom(mock_decrypt):
  37. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  38. encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
  39. provider = Provider(
  40. id='provider_id',
  41. tenant_id='tenant_id',
  42. provider_name=PROVIDER_NAME,
  43. provider_type=ProviderType.CUSTOM.value,
  44. encrypted_config=json.dumps(encrypted_credential),
  45. is_valid=True,
  46. )
  47. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  48. result = model_provider.get_provider_credentials()
  49. assert result['api_base'] == 'valid_api_base'
  50. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  51. def test_get_credentials_obfuscated(mock_decrypt):
  52. encrypted_credential = VALIDATE_CREDENTIAL.copy()
  53. encrypted_credential['api_base'] = 'encrypted_' + encrypted_credential['api_base']
  54. provider = Provider(
  55. id='provider_id',
  56. tenant_id='tenant_id',
  57. provider_name=PROVIDER_NAME,
  58. provider_type=ProviderType.CUSTOM.value,
  59. encrypted_config=json.dumps(encrypted_credential),
  60. is_valid=True,
  61. )
  62. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  63. result = model_provider.get_provider_credentials(obfuscated=True)
  64. middle_token = result['api_base'][6:-2]
  65. assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_base']) - 8, 0)
  66. assert all(char == '*' for char in middle_token)