"""Integration tests for the Huggingface TEI rerank model.

The tests read the server location and API key from the ``TEI_RERANK_SERVER_URL``
and ``TEI_API_KEY`` environment variables. When ``MOCK_SWITCH`` is set to
``true`` the ``TeiHelper`` methods are replaced by ``MockTEIClass`` ones.
"""

import os

import pytest

from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.huggingface_tei.rerank.rerank import (
    HuggingfaceTeiRerankModel,
)
from core.model_runtime.model_providers.huggingface_tei.text_embedding.text_embedding import TeiHelper
from tests.integration_tests.model_runtime.__mock.huggingface_tei import MockTEIClass

# True only when the MOCK_SWITCH environment variable equals "true" (case-insensitive).
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"

# TeiHelper attributes that are swapped for their MockTEIClass namesakes in mock mode.
_MOCKED_TEI_HELPER_METHODS = (
    "get_tei_extra_parameter",
    "invoke_tokenize",
    "invoke_embeddings",
    "invoke_rerank",
)


def _tei_credentials() -> dict:
    """Build the credentials mapping shared by every test in this module.

    The environment is read at call time, so ``server_url`` is ``None`` when
    ``TEI_RERANK_SERVER_URL`` is unset and ``api_key`` falls back to ``""``.
    """
    return {
        "server_url": os.environ.get("TEI_RERANK_SERVER_URL"),
        "api_key": os.environ.get("TEI_API_KEY", ""),
    }


@pytest.fixture
def setup_tei_mock(request, monkeypatch: pytest.MonkeyPatch):
    """Patch ``TeiHelper`` with the ``MockTEIClass`` methods when ``MOCK`` is on.

    ``request`` is accepted but not used by the fixture body.
    """
    if MOCK:
        for method_name in _MOCKED_TEI_HELPER_METHODS:
            monkeypatch.setattr(TeiHelper, method_name, getattr(MockTEIClass, method_name))
    # No explicit monkeypatch.undo() here: pytest reverts the patches itself at
    # teardown, and undo() would also revert patches made by any other fixture
    # or test sharing this monkeypatch instance.
    yield


@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_validate_credentials(setup_tei_mock):
    """Check credential validation; in mock mode also check a rejected model."""
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = "reranker"

    if MOCK:
        # TEI Provider will check model type by API endpoint, at real server, the model type is correct.
        # So we don't need to check model type here. Only check in mock
        with pytest.raises(CredentialsValidateFailedError):
            model.validate_credentials(
                model="embedding",
                credentials=_tei_credentials(),
            )

    model.validate_credentials(
        model=model_name,
        credentials=_tei_credentials(),
    )


@pytest.mark.parametrize("setup_tei_mock", [["none"]], indirect=True)
def test_invoke_model(setup_tei_mock):
    """Invoke the reranker and expect only the first document to pass the threshold."""
    model = HuggingfaceTeiRerankModel()
    # model name is only used in mock
    model_name = "reranker"

    result = model.invoke(
        model=model_name,
        credentials=_tei_credentials(),
        query="Who is Kasumi?",
        docs=[
            'Kasumi is a girl\'s name of Japanese origin meaning "mist".',
            "Her music is a kawaii bass, a mix of future bass, pop, and kawaii music ",
            "and she leads a team named PopiParty.",
        ],
        score_threshold=0.8,
    )

    assert isinstance(result, RerankResult)
    assert len(result.docs) == 1
    assert result.docs[0].index == 0
    assert result.docs[0].score >= 0.8