test_text_embedding.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.ollama.text_embedding.text_embedding import OllamaEmbeddingModel
  6. def test_validate_credentials():
  7. model = OllamaEmbeddingModel()
  8. with pytest.raises(CredentialsValidateFailedError):
  9. model.validate_credentials(
  10. model='mistral:text',
  11. credentials={
  12. 'base_url': 'http://localhost:21434',
  13. 'mode': 'chat',
  14. 'context_size': 4096,
  15. }
  16. )
  17. model.validate_credentials(
  18. model='mistral:text',
  19. credentials={
  20. 'base_url': os.environ.get('OLLAMA_BASE_URL'),
  21. 'mode': 'chat',
  22. 'context_size': 4096,
  23. }
  24. )
  25. def test_invoke_model():
  26. model = OllamaEmbeddingModel()
  27. result = model.invoke(
  28. model='mistral:text',
  29. credentials={
  30. 'base_url': os.environ.get('OLLAMA_BASE_URL'),
  31. 'mode': 'chat',
  32. 'context_size': 4096,
  33. },
  34. texts=[
  35. "hello",
  36. "world"
  37. ],
  38. user="abc-123"
  39. )
  40. assert isinstance(result, TextEmbeddingResult)
  41. assert len(result.embeddings) == 2
  42. assert result.usage.total_tokens == 2
  43. def test_get_num_tokens():
  44. model = OllamaEmbeddingModel()
  45. num_tokens = model.get_num_tokens(
  46. model='mistral:text',
  47. credentials={
  48. 'base_url': os.environ.get('OLLAMA_BASE_URL'),
  49. 'mode': 'chat',
  50. 'context_size': 4096,
  51. },
  52. texts=[
  53. "hello",
  54. "world"
  55. ]
  56. )
  57. assert num_tokens == 2