# test_embeddings.py — pytest tests for NomicTextEmbeddingModel.

  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.nomic.text_embedding.text_embedding import NomicTextEmbeddingModel
  6. from tests.integration_tests.model_runtime.__mock.nomic_embeddings import setup_nomic_mock
  7. @pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
  8. def test_validate_credentials(setup_nomic_mock):
  9. model = NomicTextEmbeddingModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model="nomic-embed-text-v1.5",
  13. credentials={
  14. "nomic_api_key": "invalid_key",
  15. },
  16. )
  17. model.validate_credentials(
  18. model="nomic-embed-text-v1.5",
  19. credentials={
  20. "nomic_api_key": os.environ.get("NOMIC_API_KEY"),
  21. },
  22. )
  23. @pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
  24. def test_invoke_model(setup_nomic_mock):
  25. model = NomicTextEmbeddingModel()
  26. result = model.invoke(
  27. model="nomic-embed-text-v1.5",
  28. credentials={
  29. "nomic_api_key": os.environ.get("NOMIC_API_KEY"),
  30. },
  31. texts=["hello", "world"],
  32. user="foo",
  33. )
  34. assert isinstance(result, TextEmbeddingResult)
  35. assert result.model == "nomic-embed-text-v1.5"
  36. assert len(result.embeddings) == 2
  37. assert result.usage.total_tokens == 2
  38. @pytest.mark.parametrize("setup_nomic_mock", [["text_embedding"]], indirect=True)
  39. def test_get_num_tokens(setup_nomic_mock):
  40. model = NomicTextEmbeddingModel()
  41. num_tokens = model.get_num_tokens(
  42. model="nomic-embed-text-v1.5",
  43. credentials={
  44. "nomic_api_key": os.environ.get("NOMIC_API_KEY"),
  45. },
  46. texts=["hello", "world"],
  47. )
  48. assert num_tokens == 2