| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170 | 
							- import os
 
- import re
 
- from typing import Union
 
- import pytest
 
- from _pytest.monkeypatch import MonkeyPatch
 
- from requests import Response
 
- from requests.sessions import Session
 
- from xinference_client.client.restful.restful_client import (
 
-     Client,
 
-     RESTfulChatModelHandle,
 
-     RESTfulEmbeddingModelHandle,
 
-     RESTfulGenerateModelHandle,
 
-     RESTfulRerankModelHandle,
 
- )
 
- from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
 
class MockXinferenceClass:
    """Stand-in implementations of xinference client calls for mock-mode tests.

    The functions defined here are not meant to be called on an instance of
    this class: they are monkeypatched onto ``Client``, ``Session`` and the
    RESTful model handles, so ``self`` is an instance of the *patched* class.
    Shared helpers and constants are therefore always referenced through
    ``MockXinferenceClass`` explicitly, never through ``self``.
    """

    # Model uids the mock accepts by name, in addition to UUID-shaped uids.
    _KNOWN_MODEL_UIDS = frozenset({"generate", "chat", "embedding", "rerank"})

    # Used with ``.match``, which anchors at the start only: a uid is accepted
    # as soon as it *begins* with a lowercase hex UUID.
    _UUID_RE = re.compile(r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}")

    # Minimal http(s) URL shape check shared by every mocked entry point.
    _URL_RE = re.compile(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$")

    # Canned JSON body served for the "generate" and "chat" model uids.
    _LLM_MODEL_INFO = b"""{
        "model_type": "LLM",
        "address": "127.0.0.1:43877",
        "accelerators": [
            "0",
            "1"
        ],
        "model_name": "chatglm3-6b",
        "model_lang": [
            "en"
        ],
        "model_ability": [
            "generate",
            "chat"
        ],
        "model_description": "latest chatglm3",
        "model_format": "pytorch",
        "model_size_in_billions": 7,
        "quantization": "none",
        "model_hub": "huggingface",
        "revision": null,
        "context_length": 2048,
        "replica": 1
    }"""

    # Canned JSON body served for the "embedding" model uid.
    _EMBEDDING_MODEL_INFO = b"""{
        "model_type": "embedding",
        "address": "127.0.0.1:43877",
        "accelerators": [
            "0",
            "1"
        ],
        "model_name": "bge",
        "model_lang": [
            "en"
        ],
        "revision": null,
        "max_tokens": 512
    }"""

    @staticmethod
    def _build_response(status_code: int, content: bytes) -> Response:
        """Return a ``Response`` carrying the given status code and raw body."""
        response = Response()
        response.status_code = status_code
        response._content = content
        return response

    def get_chat_model(
        self: Client, model_uid: str
    ) -> Union[
        RESTfulGenerateModelHandle,
        RESTfulChatModelHandle,
        RESTfulEmbeddingModelHandle,
        RESTfulRerankModelHandle,
    ]:
        """Return the model handle matching one of the well-known model uids.

        Args:
            model_uid: One of "generate", "chat", "embedding" or "rerank".

        Returns:
            A handle of the corresponding type bound to ``self.base_url``,
            created with empty auth headers.

        Raises:
            RuntimeError: With message "404 Not Found" when ``self.base_url``
                is not an http(s) URL or ``model_uid`` is not recognised.
        """
        if not MockXinferenceClass._URL_RE.match(self.base_url):
            raise RuntimeError("404 Not Found")

        handle_classes = {
            "generate": RESTfulGenerateModelHandle,
            "chat": RESTfulChatModelHandle,
            "embedding": RESTfulEmbeddingModelHandle,
            "rerank": RESTfulRerankModelHandle,
        }
        handle_cls = handle_classes.get(model_uid)
        if handle_cls is None:
            raise RuntimeError("404 Not Found")
        return handle_cls(model_uid, base_url=self.base_url, auth_headers={})

    def get(self: Session, url: str, **kwargs) -> Response:
        """Answer GET requests for model descriptions and the cluster auth probe.

        Args:
            url: Requested URL; the model uid is taken from its last path segment.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            A ``Response`` with status 200 and a canned JSON body for the
            "generate", "chat" and "embedding" model uids and for the
            ``v1/cluster/auth`` endpoint; otherwise status 404 with body ``{}``.
        """
        build = MockXinferenceClass._build_response

        if "v1/models/" in url:
            model_uid = url.rsplit("/", 1)[-1]

            uid_accepted = (
                model_uid in MockXinferenceClass._KNOWN_MODEL_UIDS
                or MockXinferenceClass._UUID_RE.match(model_uid) is not None
            )
            if not uid_accepted or not MockXinferenceClass._URL_RE.match(url):
                return build(404, b"{}")

            if model_uid in {"generate", "chat"}:
                return build(200, MockXinferenceClass._LLM_MODEL_INFO)
            if model_uid == "embedding":
                return build(200, MockXinferenceClass._EMBEDDING_MODEL_INFO)

            # "rerank" and UUID-shaped uids have no canned description. Report
            # them as not found instead of returning None, so callers can
            # always inspect ``status_code`` on the result.
            return build(404, b"{}")

        if "v1/cluster/auth" in url:
            return build(
                200,
                b"""{
                "auth": true
            }""",
            )

        # Any other URL is unknown to the mock; still hand back a real Response.
        return build(404, b"{}")

    def _check_cluster_authenticated(self):
        """Mark the cluster as authenticated without contacting any server."""
        self._cluster_authed = True

    def rerank(
        self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool
    ) -> dict:
        """Return the first ``top_n`` documents, each with a fixed relevance score.

        Args:
            documents: Candidate documents, returned in their original order.
            query: Ignored by the mock.
            top_n: Number of documents to keep; ``None`` is treated as 1.
            return_documents: Ignored by the mock; documents are always included.

        Returns:
            ``{"results": [...]}`` where each entry holds ``index``,
            ``document`` and a ``relevance_score`` of 0.9.

        Raises:
            RuntimeError: With message "404 Not Found" when the handle's model
                uid is neither "rerank" nor UUID-shaped, or when its base URL
                is not an http(s) URL.
        """
        uid = self._model_uid
        if uid != "rerank" and not MockXinferenceClass._UUID_RE.match(uid):
            raise RuntimeError("404 Not Found")

        if not MockXinferenceClass._URL_RE.match(self._base_url):
            raise RuntimeError("404 Not Found")

        if top_n is None:
            top_n = 1

        return {
            "results": [
                {"index": position, "document": document, "relevance_score": 0.9}
                for position, document in enumerate(documents[:top_n])
            ]
        }

    def create_embedding(self: RESTfulEmbeddingModelHandle, input: Union[str, list[str]], **kwargs) -> dict:
        """Return a constant 768-dimensional embedding for every input text.

        Args:
            input: A single text or a list of texts.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            An ``Embedding`` with one ``EmbeddingData`` entry per input text;
            the token usage counts are set to the number of inputs.

        Raises:
            RuntimeError: With message "404 Not Found" when the handle's model
                uid is neither "embedding" nor UUID-shaped.
        """
        uid = self._model_uid
        if uid != "embedding" and not MockXinferenceClass._UUID_RE.match(uid):
            raise RuntimeError("404 Not Found")

        texts = [input] if isinstance(input, str) else input
        text_count = len(texts)

        return Embedding(
            object="list",
            model=uid,
            data=[
                # A fresh list per entry so the vectors are not aliased.
                EmbeddingData(index=position, object="embedding", embedding=[1919.810] * 768)
                for position in range(text_count)
            ],
            usage=EmbeddingUsage(prompt_tokens=text_count, total_tokens=text_count),
        )
 
# Mock mode is opt-in: it is enabled only when the MOCK_SWITCH environment
# variable equals "true" (compared case-insensitively).
MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"
 
@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Patch the xinference client with ``MockXinferenceClass`` while a test runs.

    When ``MOCK`` is false the fixture only yields and patches nothing.
    When it is true, the client, HTTP session and model-handle methods listed
    below are replaced before the test and restored right after it.
    """
    if MOCK:
        replacements = (
            (Client, "get_model", MockXinferenceClass.get_chat_model),
            (Client, "_check_cluster_authenticated", MockXinferenceClass._check_cluster_authenticated),
            (Session, "get", MockXinferenceClass.get),
            (RESTfulEmbeddingModelHandle, "create_embedding", MockXinferenceClass.create_embedding),
            (RESTfulRerankModelHandle, "rerank", MockXinferenceClass.rerank),
        )
        for target, attribute, replacement in replacements:
            monkeypatch.setattr(target, attribute, replacement)

    yield

    if MOCK:
        # Restore the originals as soon as the test body has finished.
        monkeypatch.undo()
 
 
  |