test_text_embedding.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.openai.text_embedding.text_embedding import OpenAITextEmbeddingModel
  6. from tests.integration_tests.model_runtime.__mock.openai import setup_openai_mock
  7. @pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
  8. def test_validate_credentials(setup_openai_mock):
  9. model = OpenAITextEmbeddingModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(model="text-embedding-ada-002", credentials={"openai_api_key": "invalid_key"})
  12. model.validate_credentials(
  13. model="text-embedding-ada-002", credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY")}
  14. )
  15. @pytest.mark.parametrize("setup_openai_mock", [["text_embedding"]], indirect=True)
  16. def test_invoke_model(setup_openai_mock):
  17. model = OpenAITextEmbeddingModel()
  18. result = model.invoke(
  19. model="text-embedding-ada-002",
  20. credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"},
  21. texts=["hello", "world", " ".join(["long_text"] * 100), " ".join(["another_long_text"] * 100)],
  22. user="abc-123",
  23. )
  24. assert isinstance(result, TextEmbeddingResult)
  25. assert len(result.embeddings) == 4
  26. assert result.usage.total_tokens == 2
  27. def test_get_num_tokens():
  28. model = OpenAITextEmbeddingModel()
  29. num_tokens = model.get_num_tokens(
  30. model="text-embedding-ada-002",
  31. credentials={"openai_api_key": os.environ.get("OPENAI_API_KEY"), "openai_api_base": "https://api.openai.com"},
  32. texts=["hello", "world"],
  33. )
  34. assert num_tokens == 2