import os
import re
from typing import List, Optional, Union

import pytest
from _pytest.monkeypatch import MonkeyPatch
from requests import Response
from requests.exceptions import ConnectionError
from requests.sessions import Session
from xinference_client.client.restful.restful_client import (
    Client,
    RESTfulChatglmCppChatModelHandle,
    RESTfulChatModelHandle,
    RESTfulEmbeddingModelHandle,
    RESTfulGenerateModelHandle,
    RESTfulRerankModelHandle,
)
from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage

# A model uid in lowercase-hex UUID form. Matched with fullmatch() so that a
# uid with trailing junk ("<uuid>xyz") is rejected, not just leading junk.
_UUID_PATTERN = re.compile(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}')

# Loose "looks like an http(s) URL" check; anything failing it is treated as
# an unreachable server by the mocks below.
_URL_PATTERN = re.compile(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$')

# Canned body for GET .../v1/models/{generate,chat}.
_LLM_MODEL_DESCRIPTION = b'''{
        "model_type": "LLM",
        "address": "127.0.0.1:43877",
        "accelerators": [
            "0",
            "1"
        ],
        "model_name": "chatglm3-6b",
        "model_lang": [
            "en"
        ],
        "model_ability": [
            "generate",
            "chat"
        ],
        "model_description": "latest chatglm3",
        "model_format": "pytorch",
        "model_size_in_billions": 7,
        "quantization": "none",
        "model_hub": "huggingface",
        "revision": null,
        "context_length": 2048,
        "replica": 1
    }'''

# Canned body for GET .../v1/models/embedding.
_EMBEDDING_MODEL_DESCRIPTION = b'''{
        "model_type": "embedding",
        "address": "127.0.0.1:43877",
        "accelerators": [
            "0",
            "1"
        ],
        "model_name": "bge",
        "model_lang": [
            "en"
        ],
        "revision": null,
        "max_tokens": 512
}'''

# Canned body for GET .../v1/cluster/auth.
_CLUSTER_AUTH_RESPONSE = b'''{
    "auth": true
}'''

# Model uids for which GET .../v1/models/<uid> returns a description.
_MODEL_DESCRIPTIONS = {
    'generate': _LLM_MODEL_DESCRIPTION,
    'chat': _LLM_MODEL_DESCRIPTION,
    'embedding': _EMBEDDING_MODEL_DESCRIPTION,
}


class MockXinferenceClass(object):
    """Replacements for xinference client methods used by ``setup_xinference_mock``.

    Each method is patched onto a real client class, so ``self`` is an instance
    of that class (see each method's ``self`` annotation), not of this one.
    """

    def get_chat_model(self: Client, model_uid: str) -> Union[
        RESTfulChatglmCppChatModelHandle,
        RESTfulGenerateModelHandle,
        RESTfulChatModelHandle,
        RESTfulEmbeddingModelHandle,
        RESTfulRerankModelHandle,
    ]:
        """Stand-in for ``Client.get_model``.

        Args:
            model_uid: One of ``'generate'``, ``'chat'``, ``'embedding'`` or
                ``'rerank'``.

        Returns:
            A handle of the matching type bound to ``self.base_url``.

        Raises:
            RuntimeError: ``'404 Not Found'`` if ``self.base_url`` is not an
                http(s) URL or ``model_uid`` is not one of the known uids.
        """
        if not _URL_PATTERN.match(self.base_url):
            raise RuntimeError('404 Not Found')

        handle_types = {
            'generate': RESTfulGenerateModelHandle,
            'chat': RESTfulChatModelHandle,
            'embedding': RESTfulEmbeddingModelHandle,
            'rerank': RESTfulRerankModelHandle,
        }
        handle_type = handle_types.get(model_uid)
        if handle_type is None:
            raise RuntimeError('404 Not Found')
        return handle_type(model_uid, base_url=self.base_url, auth_headers={})

    def get(self: Session, url: str, **kwargs) -> Response:
        """Stand-in for ``Session.get`` serving canned Xinference REST responses.

        Serves ``.../v1/models/<uid>`` for uids in ``_MODEL_DESCRIPTIONS`` (when
        ``url`` is an http(s) URL) and ``.../v1/cluster/auth``.  Every other
        request gets an empty 404 ``Response``.  A ``Response`` is always
        returned, never ``None``, so callers can rely on ``status_code``.
        """
        response = Response()
        response.status_code = 404

        if 'v1/models/' in url:
            content = _MODEL_DESCRIPTIONS.get(url.split('/')[-1])
            if content is None or not _URL_PATTERN.match(url):
                return response
        elif 'v1/cluster/auth' in url:
            content = _CLUSTER_AUTH_RESPONSE
        else:
            return response

        response.status_code = 200
        # Setting the private body buffer lets Response.json()/.text work
        # without a real HTTP transport.
        response._content = content
        return response

    def _check_cluster_authenticated(self):
        """Stand-in for ``Client._check_cluster_authenticated``: always authenticated."""
        self._cluster_authed = True

    def rerank(
        self: RESTfulRerankModelHandle,
        documents: List[str],
        query: str,
        top_n: Optional[int] = None,
        **kwargs,
    ) -> dict:
        """Stand-in for ``RESTfulRerankModelHandle.rerank``.

        ``query`` and any extra keyword arguments are accepted for signature
        compatibility and ignored.

        Returns:
            ``{'results': [...]}`` holding the first ``top_n`` documents (1 when
            ``top_n`` is None), in input order, each scored 0.9.

        Raises:
            RuntimeError: ``'404 Not Found'`` if the handle's model uid is
                neither a UUID nor ``'rerank'``, or its base URL is not http(s).
        """
        if not (_UUID_PATTERN.fullmatch(self._model_uid) or self._model_uid == 'rerank'):
            raise RuntimeError('404 Not Found')
        if not _URL_PATTERN.match(self._base_url):
            raise RuntimeError('404 Not Found')

        if top_n is None:
            top_n = 1

        return {
            'results': [
                {
                    'index': index,
                    'document': document,
                    'relevance_score': 0.9,
                }
                for index, document in enumerate(documents[:top_n])
            ]
        }

    def create_embedding(
        self: RESTfulEmbeddingModelHandle,
        input: Union[str, List[str]],
        **kwargs,
    ) -> Embedding:
        """Stand-in for ``RESTfulEmbeddingModelHandle.create_embedding``.

        Args:
            input: A single string or a list of strings to "embed".

        Returns:
            An ``Embedding`` with one 768-element constant vector per input and
            usage counts equal to the number of inputs.

        Raises:
            RuntimeError: ``'404 Not Found'`` if the handle's model uid is
                neither a UUID nor ``'embedding'``.
        """
        if not (_UUID_PATTERN.fullmatch(self._model_uid) or self._model_uid == 'embedding'):
            raise RuntimeError('404 Not Found')

        if isinstance(input, str):
            input = [input]
        input_count = len(input)

        return Embedding(
            object="list",
            model=self._model_uid,
            data=[
                EmbeddingData(
                    index=index,
                    object="embedding",
                    embedding=[1919.810] * 768,
                )
                for index in range(input_count)
            ],
            usage=EmbeddingUsage(
                prompt_tokens=input_count,
                total_tokens=input_count,
            ),
        )


# The mocks are installed only when MOCK_SWITCH=true; otherwise tests talk to a
# real Xinference server.
MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'


@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Patch the xinference client with ``MockXinferenceClass`` when ``MOCK`` is set.

    The patches are undone after the test; pytest's ``monkeypatch`` fixture
    would also undo them at teardown, so the explicit ``undo()`` is belt and braces.
    """
    if MOCK:
        monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
        monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
        monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
        monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
        monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)

    yield

    if MOCK:
        monkeypatch.undo()