test_azure_openai_model.py

import json
import os
from unittest.mock import patch, MagicMock

import pytest
from langchain.schema import ChatGeneration, AIMessage

from core.model_providers.models.entity.model_params import ModelKwargs, ModelType
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
from models.provider import Provider, ProviderType, ProviderModel


def get_mock_provider():
    # Bare Provider record for the custom Azure OpenAI provider; the actual
    # credentials come from the mocked ProviderModel configured below.
    return Provider(
        id='provider_id',
        tenant_id='tenant_id',
        provider_name='azure_openai',
        provider_type=ProviderType.CUSTOM.value,
        encrypted_config='',
        is_valid=True,
    )


def get_mock_azure_openai_model(model_name, mocker):
    model_kwargs = ModelKwargs(
        max_tokens=10,
        temperature=0
    )
    valid_openai_api_base = os.environ['AZURE_OPENAI_API_BASE']
    valid_openai_api_key = os.environ['AZURE_OPENAI_API_KEY']
    provider = AzureOpenAIProvider(provider=get_mock_provider())

    # Stub the database query so the provider model lookup returns a
    # ProviderModel whose config points at the Azure OpenAI deployment.
    mock_query = MagicMock()
    mock_query.filter.return_value.first.return_value = ProviderModel(
        provider_name='azure_openai',
        model_name=model_name,
        model_type=ModelType.TEXT_GENERATION.value,
        encrypted_config=json.dumps({
            'openai_api_base': valid_openai_api_base,
            'openai_api_key': valid_openai_api_key,
            'base_model_name': model_name
        }),
        is_valid=True,
    )
    mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)

    return AzureOpenAIModel(
        model_provider=provider,
        name=model_name,
        model_kwargs=model_kwargs
    )


def decrypt_side_effect(tenant_id, encrypted_openai_api_key):
    # The mocked config stores the key in plain text, so "decryption" is a no-op.
    return encrypted_openai_api_key


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_num_tokens(mock_decrypt, mocker):
    openai_model = get_mock_azure_openai_model('text-davinci-003', mocker)
    rst = openai_model.get_num_tokens([PromptMessage(content='you are a kindness Assistant.')])
    assert rst == 6


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt, mocker):
    openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
    rst = openai_model.get_num_tokens([
        PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
        PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
    ])
    assert rst == 22


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    openai_model = get_mock_azure_openai_model('gpt-35-turbo', mocker)
    messages = [PromptMessage(content='Human: Are you Human? you MUST only answer `y` or `n`? \nAssistant: ')]
    rst = openai_model.run(
        messages,
        stop=['\nHuman:'],
    )
    assert len(rst.content) > 0
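

# A minimal sketch of how these tests can be run locally, assuming pytest and
# pytest-mock are installed and the environment variables below point at a real
# Azure OpenAI resource (the endpoint and key values here are placeholders):
#
#   export AZURE_OPENAI_API_BASE='https://<your-resource>.openai.azure.com/'
#   export AZURE_OPENAI_API_KEY='<your-api-key>'
#   pytest test_azure_openai_model.py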