test_text_embedding.py

import os

import pytest

from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel
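
# These are integration-style tests: they assume a reachable Ollama server whose
# URL is provided via the OLLAMA_BASE_URL environment variable, and that the
# "mistral:text" model used below is available on that server.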


def test_validate_credentials():
    model = OllamaEmbeddingModel()

    # Validation against an unreachable base_url is expected to fail.
    with pytest.raises(CredentialsValidateFailedError):
        model.validate_credentials(
            model="mistral:text",
            credentials={
                "base_url": "http://localhost:21434",
                "mode": "chat",
                "context_size": 4096,
            },
        )

    # Validation against the configured Ollama server should succeed.
    model.validate_credentials(
        model="mistral:text",
        credentials={
            "base_url": os.environ.get("OLLAMA_BASE_URL"),
            "mode": "chat",
            "context_size": 4096,
        },
    )


def test_invoke_model():
    model = OllamaEmbeddingModel()

    result = model.invoke(
        model="mistral:text",
        credentials={
            "base_url": os.environ.get("OLLAMA_BASE_URL"),
            "mode": "chat",
            "context_size": 4096,
        },
        texts=["hello", "world"],
        user="abc-123",
    )

    # Two input texts should yield two embeddings and a reported usage of 2 tokens.
    assert isinstance(result, TextEmbeddingResult)
    assert len(result.embeddings) == 2
    assert result.usage.total_tokens == 2


def test_get_num_tokens():
    model = OllamaEmbeddingModel()

    num_tokens = model.get_num_tokens(
        model="mistral:text",
        credentials={
            "base_url": os.environ.get("OLLAMA_BASE_URL"),
            "mode": "chat",
            "context_size": 4096,
        },
        texts=["hello", "world"],
    )

    # Two single-word texts are counted as two tokens.
    assert num_tokens == 2
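

# A possible way to run these tests locally (example values, not taken from this file):
#   OLLAMA_BASE_URL=http://localhost:11434 pytest test_text_embedding.py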