test_rerank.py 1.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.rerank_entities import RerankResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.cohere.rerank.rerank import CohereRerankModel
  def test_validate_credentials():
      """Exercise credential validation for the Cohere rerank model.

      A made-up API key must be rejected with CredentialsValidateFailedError.
      The key taken from the COHERE_API_KEY environment variable must then
      validate without raising.
      """
      rerank_model = CohereRerankModel()
      model_name = "rerank-english-v2.0"

      # Negative case: a deliberately bogus key has to be refused.
      bogus_credentials = {"api_key": "invalid_key"}
      with pytest.raises(CredentialsValidateFailedError):
          rerank_model.validate_credentials(model=model_name, credentials=bogus_credentials)

      # Positive case: the real key from the environment must be accepted.
      # NOTE(review): os.environ.get returns None when COHERE_API_KEY is unset,
      # so this call then receives api_key=None. It presumably fails instead of
      # skipping; confirm this is the intended behavior for environments without the key.
      env_credentials = {"api_key": os.environ.get("COHERE_API_KEY")}
      rerank_model.validate_credentials(model=model_name, credentials=env_credentials)
  def test_invoke_model():
      """Rerank two candidate passages against a query and check the filtered result.

      The query asks for the capital of the United States. Of the two
      candidates, only the Washington, D.C. passage (position 1) is expected
      in the result, with a score at or above the 0.8 threshold.
      """
      threshold = 0.8
      candidates = [
          "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
          "Census, Carson City had a population of 55,274.",
          "Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) "
          "is the capital of the United States. It is a federal district. The President of the USA and many major "
          "national government offices are in the territory. This makes it the political center of the United "
          "States of America.",
      ]

      rerank_model = CohereRerankModel()
      result = rerank_model.invoke(
          model="rerank-english-v2.0",
          credentials={"api_key": os.environ.get("COHERE_API_KEY")},
          query="What is the capital of the United States?",
          docs=candidates,
          score_threshold=threshold,
      )

      assert isinstance(result, RerankResult)
      # Exactly one passage is expected to clear the threshold...
      assert len(result.docs) == 1
      top_doc = result.docs[0]
      # ...and it should be the Washington, D.C. passage, i.e. candidates[1].
      assert top_doc.index == 1
      assert top_doc.score >= threshold