test_text_embedding.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.zhipuai.text_embedding.text_embedding import ZhipuAITextEmbeddingModel
  6. def test_validate_credentials():
  7. model = ZhipuAITextEmbeddingModel()
  8. with pytest.raises(CredentialsValidateFailedError):
  9. model.validate_credentials(model="text_embedding", credentials={"api_key": "invalid_key"})
  10. model.validate_credentials(model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")})
  11. def test_invoke_model():
  12. model = ZhipuAITextEmbeddingModel()
  13. result = model.invoke(
  14. model="text_embedding",
  15. credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")},
  16. texts=["hello", "world"],
  17. user="abc-123",
  18. )
  19. assert isinstance(result, TextEmbeddingResult)
  20. assert len(result.embeddings) == 2
  21. assert result.usage.total_tokens > 0
  22. def test_get_num_tokens():
  23. model = ZhipuAITextEmbeddingModel()
  24. num_tokens = model.get_num_tokens(
  25. model="text_embedding", credentials={"api_key": os.environ.get("ZHIPUAI_API_KEY")}, texts=["hello", "world"]
  26. )
  27. assert num_tokens == 2