test_embeddings.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.xinference.text_embedding.text_embedding import XinferenceTextEmbeddingModel
  6. from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock
  7. @pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
  8. def test_validate_credentials(setup_xinference_mock):
  9. model = XinferenceTextEmbeddingModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model='bge-base-en',
  13. credentials={
  14. 'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
  15. 'model_uid': 'www ' + os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
  16. }
  17. )
  18. model.validate_credentials(
  19. model='bge-base-en',
  20. credentials={
  21. 'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
  22. 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
  23. }
  24. )
  25. @pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
  26. def test_invoke_model(setup_xinference_mock):
  27. model = XinferenceTextEmbeddingModel()
  28. result = model.invoke(
  29. model='bge-base-en',
  30. credentials={
  31. 'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
  32. 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
  33. },
  34. texts=[
  35. "hello",
  36. "world"
  37. ],
  38. user="abc-123"
  39. )
  40. assert isinstance(result, TextEmbeddingResult)
  41. assert len(result.embeddings) == 2
  42. assert result.usage.total_tokens > 0
  43. def test_get_num_tokens():
  44. model = XinferenceTextEmbeddingModel()
  45. num_tokens = model.get_num_tokens(
  46. model='bge-base-en',
  47. credentials={
  48. 'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
  49. 'model_uid': os.environ.get('XINFERENCE_EMBEDDINGS_MODEL_UID')
  50. },
  51. texts=[
  52. "hello",
  53. "world"
  54. ]
  55. )
  56. assert num_tokens == 2