test_rerank.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import os
  2. from unittest.mock import Mock, patch
  3. import pytest
  4. from core.model_runtime.entities.rerank_entities import RerankResult
  5. from core.model_runtime.errors.validate import CredentialsValidateFailedError
  6. from core.model_runtime.model_providers.mixedbread.rerank.rerank import MixedBreadRerankModel
  7. def test_validate_credentials():
  8. model = MixedBreadRerankModel()
  9. with pytest.raises(CredentialsValidateFailedError):
  10. model.validate_credentials(
  11. model="mxbai-rerank-large-v1",
  12. credentials={"api_key": "invalid_key"},
  13. )
  14. with patch("httpx.post") as mock_post:
  15. mock_response = Mock()
  16. mock_response.json.return_value = {
  17. "usage": {"prompt_tokens": 86, "total_tokens": 86},
  18. "model": "mixedbread-ai/mxbai-rerank-large-v1",
  19. "data": [
  20. {
  21. "index": 0,
  22. "score": 0.06762695,
  23. "input": "Carson City is the capital city of the American state of Nevada. At the 2010 United "
  24. "States Census, Carson City had a population of 55,274.",
  25. "object": "text_document",
  26. },
  27. {
  28. "index": 1,
  29. "score": 0.057403564,
  30. "input": "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific "
  31. "Ocean that are a political division controlled by the United States. Its capital is "
  32. "Saipan.",
  33. "object": "text_document",
  34. },
  35. ],
  36. "object": "list",
  37. "top_k": 2,
  38. "return_input": True,
  39. }
  40. mock_response.status_code = 200
  41. mock_post.return_value = mock_response
  42. model.validate_credentials(
  43. model="mxbai-rerank-large-v1",
  44. credentials={
  45. "api_key": os.environ.get("MIXEDBREAD_API_KEY"),
  46. },
  47. )
  48. def test_invoke_model():
  49. model = MixedBreadRerankModel()
  50. with patch("httpx.post") as mock_post:
  51. mock_response = Mock()
  52. mock_response.json.return_value = {
  53. "usage": {"prompt_tokens": 56, "total_tokens": 56},
  54. "model": "mixedbread-ai/mxbai-rerank-large-v1",
  55. "data": [
  56. {
  57. "index": 0,
  58. "score": 0.6044922,
  59. "input": "Kasumi is a girl name of Japanese origin meaning mist.",
  60. "object": "text_document",
  61. },
  62. {
  63. "index": 1,
  64. "score": 0.0703125,
  65. "input": "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a "
  66. "team named PopiParty.",
  67. "object": "text_document",
  68. },
  69. ],
  70. "object": "list",
  71. "top_k": 2,
  72. "return_input": "true",
  73. }
  74. mock_response.status_code = 200
  75. mock_post.return_value = mock_response
  76. result = model.invoke(
  77. model="mxbai-rerank-large-v1",
  78. credentials={
  79. "api_key": os.environ.get("MIXEDBREAD_API_KEY"),
  80. },
  81. query="Who is Kasumi?",
  82. docs=[
  83. "Kasumi is a girl name of Japanese origin meaning mist.",
  84. "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music and she leads a team named "
  85. "PopiParty.",
  86. ],
  87. score_threshold=0.5,
  88. )
  89. assert isinstance(result, RerankResult)
  90. assert len(result.docs) == 1
  91. assert result.docs[0].index == 0
  92. assert result.docs[0].score >= 0.5