azure_provider.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. import json
  2. import logging
  3. from typing import Optional, Union
  4. import requests
  5. from core.llm.provider.base import BaseProvider
  6. from core.llm.provider.errors import ValidateFailedError
  7. from models.provider import ProviderName
  8. class AzureProvider(BaseProvider):
  9. def get_models(self, model_id: Optional[str] = None, credentials: Optional[dict] = None) -> list[dict]:
  10. credentials = self.get_credentials(model_id) if not credentials else credentials
  11. url = "{}/openai/deployments?api-version={}".format(
  12. str(credentials.get('openai_api_base')),
  13. str(credentials.get('openai_api_version'))
  14. )
  15. headers = {
  16. "api-key": str(credentials.get('openai_api_key')),
  17. "content-type": "application/json; charset=utf-8"
  18. }
  19. response = requests.get(url, headers=headers)
  20. if response.status_code == 200:
  21. result = response.json()
  22. return [{
  23. 'id': deployment['id'],
  24. 'name': '{} ({})'.format(deployment['id'], deployment['model'])
  25. } for deployment in result['data'] if deployment['status'] == 'succeeded']
  26. else:
  27. if response.status_code == 401:
  28. raise AzureAuthenticationError()
  29. else:
  30. raise AzureRequestFailedError('Failed to request Azure OpenAI. Status code: {}'.format(response.status_code))
  31. def get_credentials(self, model_id: Optional[str] = None) -> dict:
  32. """
  33. Returns the API credentials for Azure OpenAI as a dictionary.
  34. """
  35. config = self.get_provider_api_key(model_id=model_id)
  36. config['openai_api_type'] = 'azure'
  37. if model_id == 'text-embedding-ada-002':
  38. config['deployment'] = model_id.replace('.', '') if model_id else None
  39. config['chunk_size'] = 1
  40. else:
  41. config['deployment_name'] = model_id.replace('.', '') if model_id else None
  42. return config
  43. def get_provider_name(self):
  44. return ProviderName.AZURE_OPENAI
  45. def get_provider_configs(self, obfuscated: bool = False) -> Union[str | dict]:
  46. """
  47. Returns the provider configs.
  48. """
  49. try:
  50. config = self.get_provider_api_key()
  51. except:
  52. config = {
  53. 'openai_api_type': 'azure',
  54. 'openai_api_version': '2023-03-15-preview',
  55. 'openai_api_base': '',
  56. 'openai_api_key': ''
  57. }
  58. if obfuscated:
  59. if not config.get('openai_api_key'):
  60. config = {
  61. 'openai_api_type': 'azure',
  62. 'openai_api_version': '2023-03-15-preview',
  63. 'openai_api_base': '',
  64. 'openai_api_key': ''
  65. }
  66. config['openai_api_key'] = self.obfuscated_token(config.get('openai_api_key'))
  67. return config
  68. return config
  69. def get_token_type(self):
  70. # TODO: change to dict when implemented
  71. return dict
  72. def config_validate(self, config: Union[dict | str]):
  73. """
  74. Validates the given config.
  75. """
  76. try:
  77. if not isinstance(config, dict):
  78. raise ValueError('Config must be a object.')
  79. if 'openai_api_version' not in config:
  80. config['openai_api_version'] = '2023-03-15-preview'
  81. models = self.get_models(credentials=config)
  82. if not models:
  83. raise ValidateFailedError("Please add deployments for 'text-davinci-003', "
  84. "'gpt-3.5-turbo', 'text-embedding-ada-002' (required) "
  85. "and 'gpt-4', 'gpt-35-turbo-16k' (optional).")
  86. fixed_model_ids = [
  87. 'text-davinci-003',
  88. 'gpt-35-turbo',
  89. 'text-embedding-ada-002'
  90. ]
  91. current_model_ids = [model['id'] for model in models]
  92. missing_model_ids = [fixed_model_id for fixed_model_id in fixed_model_ids if
  93. fixed_model_id not in current_model_ids]
  94. if missing_model_ids:
  95. raise ValidateFailedError("Please add deployments for '{}'.".format(", ".join(missing_model_ids)))
  96. except ValidateFailedError as e:
  97. raise e
  98. except AzureAuthenticationError:
  99. raise ValidateFailedError('Validation failed, please check your API Key.')
  100. except (requests.ConnectionError, requests.RequestException):
  101. raise ValidateFailedError('Validation failed, please check your API Base Endpoint.')
  102. except AzureRequestFailedError as ex:
  103. raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
  104. except Exception as ex:
  105. logging.exception('Azure OpenAI Credentials validation failed')
  106. raise ValidateFailedError('Validation failed, error: {}.'.format(str(ex)))
  107. def get_encrypted_token(self, config: Union[dict | str]):
  108. """
  109. Returns the encrypted token.
  110. """
  111. return json.dumps({
  112. 'openai_api_type': 'azure',
  113. 'openai_api_version': '2023-03-15-preview',
  114. 'openai_api_base': config['openai_api_base'],
  115. 'openai_api_key': self.encrypt_token(config['openai_api_key'])
  116. })
  117. def get_decrypted_token(self, token: str):
  118. """
  119. Returns the decrypted token.
  120. """
  121. config = json.loads(token)
  122. config['openai_api_key'] = self.decrypt_token(config['openai_api_key'])
  123. return config
  124. class AzureAuthenticationError(Exception):
  125. pass
  126. class AzureRequestFailedError(Exception):
  127. pass