test_rerank.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. import os
  2. import pytest
  3. from core.model_runtime.entities.rerank_entities import RerankResult
  4. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  5. from core.model_runtime.model_providers.xinference.rerank.rerank import XinferenceRerankModel
  6. from tests.integration_tests.model_runtime.__mock.xinference import MOCK, setup_xinference_mock
  7. @pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
  8. def test_validate_credentials(setup_xinference_mock):
  9. model = XinferenceRerankModel()
  10. with pytest.raises(CredentialsValidateFailedError):
  11. model.validate_credentials(
  12. model='bge-reranker-base',
  13. credentials={
  14. 'server_url': 'awdawdaw',
  15. 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
  16. }
  17. )
  18. model.validate_credentials(
  19. model='bge-reranker-base',
  20. credentials={
  21. 'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
  22. 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
  23. }
  24. )
  25. @pytest.mark.parametrize('setup_xinference_mock', [['none']], indirect=True)
  26. def test_invoke_model(setup_xinference_mock):
  27. model = XinferenceRerankModel()
  28. result = model.invoke(
  29. model='bge-reranker-base',
  30. credentials={
  31. 'server_url': os.environ.get('XINFERENCE_SERVER_URL'),
  32. 'model_uid': os.environ.get('XINFERENCE_RERANK_MODEL_UID')
  33. },
  34. query="Who is Kasumi?",
  35. docs=[
  36. "Kasumi is a girl's name of Japanese origin meaning \"mist\".",
  37. "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
  38. "and she leads a team named PopiParty."
  39. ],
  40. score_threshold=0.8
  41. )
  42. assert isinstance(result, RerankResult)
  43. assert len(result.docs) == 1
  44. assert result.docs[0].index == 0
  45. assert result.docs[0].score >= 0.8