test_text_embedding.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.jina.text_embedding.text_embedding import JinaTextEmbeddingModel
  6. def test_validate_credentials():
  7. model = JinaTextEmbeddingModel()
  8. with pytest.raises(CredentialsValidateFailedError):
  9. model.validate_credentials(model="jina-embeddings-v2-base-en", credentials={"api_key": "invalid_key"})
  10. model.validate_credentials(
  11. model="jina-embeddings-v2-base-en", credentials={"api_key": os.environ.get("JINA_API_KEY")}
  12. )
  13. def test_invoke_model():
  14. model = JinaTextEmbeddingModel()
  15. result = model.invoke(
  16. model="jina-embeddings-v2-base-en",
  17. credentials={
  18. "api_key": os.environ.get("JINA_API_KEY"),
  19. },
  20. texts=["hello", "world"],
  21. user="abc-123",
  22. )
  23. assert isinstance(result, TextEmbeddingResult)
  24. assert len(result.embeddings) == 2
  25. assert result.usage.total_tokens == 6
  26. def test_get_num_tokens():
  27. model = JinaTextEmbeddingModel()
  28. num_tokens = model.get_num_tokens(
  29. model="jina-embeddings-v2-base-en",
  30. credentials={
  31. "api_key": os.environ.get("JINA_API_KEY"),
  32. },
  33. texts=["hello", "world"],
  34. )
  35. assert num_tokens == 6