test_huggingface_hub_embedding.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. import json
  2. import os
  3. from unittest.mock import patch, MagicMock
  4. from core.model_providers.models.entity.model_params import ModelType
  5. from core.model_providers.models.embedding.huggingface_embedding import HuggingfaceEmbedding
  6. from core.model_providers.providers.huggingface_hub_provider import HuggingfaceHubProvider
  7. from models.provider import Provider, ProviderType, ProviderModel
# Default sentence-transformer model used by the hosted-inference-API tests
# (produces 384-dimensional embeddings, as asserted below).
DEFAULT_MODEL_NAME = 'obrizum/all-MiniLM-L6-v2'
  9. def get_mock_provider():
  10. return Provider(
  11. id='provider_id',
  12. tenant_id='tenant_id',
  13. provider_name='huggingface_hub',
  14. provider_type=ProviderType.CUSTOM.value,
  15. encrypted_config='',
  16. is_valid=True,
  17. )
  18. def get_mock_embedding_model(model_name, huggingfacehub_api_type, mocker):
  19. valid_api_key = os.environ['HUGGINGFACE_API_KEY']
  20. endpoint_url = os.environ['HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL']
  21. model_provider = HuggingfaceHubProvider(provider=get_mock_provider())
  22. credentials = {
  23. 'huggingfacehub_api_type': huggingfacehub_api_type,
  24. 'huggingfacehub_api_token': valid_api_key,
  25. 'task_type': 'feature-extraction'
  26. }
  27. if huggingfacehub_api_type == 'inference_endpoints':
  28. credentials['huggingfacehub_endpoint_url'] = endpoint_url
  29. mock_query = MagicMock()
  30. mock_query.filter.return_value.first.return_value = ProviderModel(
  31. provider_name='huggingface_hub',
  32. model_name=model_name,
  33. model_type=ModelType.EMBEDDINGS.value,
  34. encrypted_config=json.dumps(credentials),
  35. is_valid=True,
  36. )
  37. mocker.patch('extensions.ext_database.db.session.query',
  38. return_value=mock_query)
  39. return HuggingfaceEmbedding(
  40. model_provider=model_provider,
  41. name=model_name
  42. )
  43. def decrypt_side_effect(tenant_id, encrypted_api_key):
  44. return encrypted_api_key
  45. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  46. def test_hosted_inference_api_embed_documents(mock_decrypt, mocker):
  47. embedding_model = get_mock_embedding_model(
  48. DEFAULT_MODEL_NAME,
  49. 'hosted_inference_api',
  50. mocker)
  51. rst = embedding_model.client.embed_documents(['test', 'test1'])
  52. assert isinstance(rst, list)
  53. assert len(rst) == 2
  54. assert len(rst[0]) == 384
  55. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  56. def test_endpoint_url_inference_api_embed_documents(mock_decrypt, mocker):
  57. embedding_model = get_mock_embedding_model(
  58. '',
  59. 'inference_endpoints',
  60. mocker)
  61. mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
  62. , return_value=bytes(json.dumps([[1, 2, 3], [4, 5, 6]]), 'utf-8'))
  63. rst = embedding_model.client.embed_documents(['test', 'test1'])
  64. assert isinstance(rst, list)
  65. assert len(rst) == 2
  66. assert len(rst[0]) == 3
  67. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  68. def test_endpoint_url_inference_api_embed_documents_two(mock_decrypt, mocker):
  69. embedding_model = get_mock_embedding_model(
  70. '',
  71. 'inference_endpoints',
  72. mocker)
  73. mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
  74. , return_value=bytes(json.dumps([[[[1,2,3],[4,5,6],[7,8,9]]],[[[1,2,3],[4,5,6],[7,8,9]]]]), 'utf-8'))
  75. rst = embedding_model.client.embed_documents(['test', 'test1'])
  76. assert isinstance(rst, list)
  77. assert len(rst) == 2
  78. assert len(rst[0]) == 3
  79. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  80. def test_hosted_inference_api_embed_query(mock_decrypt, mocker):
  81. embedding_model = get_mock_embedding_model(
  82. DEFAULT_MODEL_NAME,
  83. 'hosted_inference_api',
  84. mocker)
  85. rst = embedding_model.client.embed_query('test')
  86. assert isinstance(rst, list)
  87. assert len(rst) == 384
  88. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  89. def test_endpoint_url_inference_api_embed_query(mock_decrypt, mocker):
  90. embedding_model = get_mock_embedding_model(
  91. '',
  92. 'inference_endpoints',
  93. mocker)
  94. mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
  95. , return_value=bytes(json.dumps([[1, 2, 3]]), 'utf-8'))
  96. rst = embedding_model.client.embed_query('test')
  97. assert isinstance(rst, list)
  98. assert len(rst) == 3
  99. @patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
  100. def test_endpoint_url_inference_api_embed_query_two(mock_decrypt, mocker):
  101. embedding_model = get_mock_embedding_model(
  102. '',
  103. 'inference_endpoints',
  104. mocker)
  105. mocker.patch('core.third_party.langchain.embeddings.huggingface_hub_embedding.InferenceClient.post'
  106. , return_value=bytes(json.dumps([[[[1,2,3],[4,5,6],[7,8,9]]]]), 'utf-8'))
  107. rst = embedding_model.client.embed_query('test')
  108. assert isinstance(rst, list)
  109. assert len(rst) == 3