test_anthropic_provider.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. from typing import List, Optional, Any
  2. import anthropic
  3. import httpx
  4. import pytest
  5. from unittest.mock import patch
  6. import json
  7. from langchain.callbacks.manager import CallbackManagerForLLMRun
  8. from langchain.schema import BaseMessage, ChatResult, ChatGeneration, AIMessage
  9. from core.model_providers.providers.anthropic_provider import AnthropicProvider
  10. from core.model_providers.providers.base import CredentialsValidateFailedError
  11. from models.provider import ProviderType, Provider
  12. PROVIDER_NAME = 'anthropic'
  13. MODEL_PROVIDER_CLASS = AnthropicProvider
  14. VALIDATE_CREDENTIAL_KEY = 'anthropic_api_key'
  15. def mock_chat_generate(messages: List[BaseMessage],
  16. stop: Optional[List[str]] = None,
  17. run_manager: Optional[CallbackManagerForLLMRun] = None,
  18. **kwargs: Any):
  19. return ChatResult(generations=[ChatGeneration(message=AIMessage(content='answer'))])
  20. def mock_chat_generate_invalid(messages: List[BaseMessage],
  21. stop: Optional[List[str]] = None,
  22. run_manager: Optional[CallbackManagerForLLMRun] = None,
  23. **kwargs: Any):
  24. raise anthropic.APIStatusError('Invalid credentials',
  25. request=httpx._models.Request(
  26. method='POST',
  27. url='https://api.anthropic.com/v1/completions',
  28. ),
  29. response=httpx._models.Response(
  30. status_code=401,
  31. ),
  32. body=None
  33. )
  34. def encrypt_side_effect(tenant_id, encrypt_key):
  35. return f'encrypted_{encrypt_key}'
  36. def decrypt_side_effect(tenant_id, encrypted_key):
  37. return encrypted_key.replace('encrypted_', '')
  38. @patch('langchain.chat_models.ChatAnthropic._generate', side_effect=mock_chat_generate)
  39. def test_is_provider_credentials_valid_or_raise_valid(mock_create):
  40. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'valid_key'})
  41. @patch('langchain.chat_models.ChatAnthropic._generate', side_effect=mock_chat_generate_invalid)
  42. def test_is_provider_credentials_valid_or_raise_invalid(mock_create):
  43. # raise CredentialsValidateFailedError if anthropic_api_key is not in credentials
  44. with pytest.raises(CredentialsValidateFailedError):
  45. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
  46. # raise CredentialsValidateFailedError if anthropic_api_key is invalid
  47. with pytest.raises(CredentialsValidateFailedError):
  48. MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({VALIDATE_CREDENTIAL_KEY: 'invalid_key'})
  49. @patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
  50. def test_encrypt_credentials(mock_encrypt):
  51. api_key = 'valid_key'
  52. result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', {VALIDATE_CREDENTIAL_KEY: api_key})
  53. mock_encrypt.assert_called_with('tenant_id', api_key)
  54. assert result[VALIDATE_CREDENTIAL_KEY] == f'encrypted_{api_key}'
  55. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  56. def test_get_credentials_custom(mock_decrypt):
  57. provider = Provider(
  58. id='provider_id',
  59. tenant_id='tenant_id',
  60. provider_name=PROVIDER_NAME,
  61. provider_type=ProviderType.CUSTOM.value,
  62. encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: 'encrypted_valid_key'}),
  63. is_valid=True,
  64. )
  65. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  66. result = model_provider.get_provider_credentials()
  67. assert result[VALIDATE_CREDENTIAL_KEY] == 'valid_key'
  68. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  69. def test_get_credentials_obfuscated(mock_decrypt):
  70. api_key = 'valid_key'
  71. provider = Provider(
  72. id='provider_id',
  73. tenant_id='tenant_id',
  74. provider_name=PROVIDER_NAME,
  75. provider_type=ProviderType.CUSTOM.value,
  76. encrypted_config=json.dumps({VALIDATE_CREDENTIAL_KEY: f'encrypted_{api_key}'}),
  77. is_valid=True,
  78. )
  79. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  80. result = model_provider.get_provider_credentials(obfuscated=True)
  81. middle_token = result[VALIDATE_CREDENTIAL_KEY][6:-2]
  82. assert len(middle_token) == max(len(api_key) - 8, 0)
  83. assert all(char == '*' for char in middle_token)
  84. @patch('core.model_providers.providers.hosted.hosted_model_providers.anthropic')
  85. def test_get_credentials_hosted(mock_hosted):
  86. provider = Provider(
  87. id='provider_id',
  88. tenant_id='tenant_id',
  89. provider_name=PROVIDER_NAME,
  90. provider_type=ProviderType.SYSTEM.value,
  91. encrypted_config='',
  92. is_valid=True,
  93. )
  94. model_provider = MODEL_PROVIDER_CLASS(provider=provider)
  95. mock_hosted.api_key = 'hosted_key'
  96. result = model_provider.get_provider_credentials()
  97. assert result[VALIDATE_CREDENTIAL_KEY] == 'hosted_key'