123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178 |
- import os
- import re
- from typing import Union
- import pytest
- from _pytest.monkeypatch import MonkeyPatch
- from requests import Response
- from requests.exceptions import ConnectionError
- from requests.sessions import Session
- from xinference_client.client.restful.restful_client import (
- Client,
- RESTfulChatglmCppChatModelHandle,
- RESTfulChatModelHandle,
- RESTfulEmbeddingModelHandle,
- RESTfulGenerateModelHandle,
- RESTfulRerankModelHandle,
- )
- from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
# Shape of a UUID-style model uid. Used with ``match`` (anchored at the start
# only), exactly as the inline checks this replaces were.
_MODEL_UID_PATTERN = re.compile(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}')

# Minimal sanity check for an http(s) URL.
_HTTP_URL_PATTERN = re.compile(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$')

# Symbolic model uids accepted by the mocked ``v1/models/<uid>`` endpoint.
_KNOWN_MODEL_UIDS = ('generate', 'chat', 'embedding', 'rerank')

# Canned JSON payloads served by the mocked ``Session.get``.
_LLM_MODEL_INFO = b'''{
    "model_type": "LLM",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "chatglm3-6b",
    "model_lang": [
        "en"
    ],
    "model_ability": [
        "generate",
        "chat"
    ],
    "model_description": "latest chatglm3",
    "model_format": "pytorch",
    "model_size_in_billions": 7,
    "quantization": "none",
    "model_hub": "huggingface",
    "revision": null,
    "context_length": 2048,
    "replica": 1
}'''

_EMBEDDING_MODEL_INFO = b'''{
    "model_type": "embedding",
    "address": "127.0.0.1:43877",
    "accelerators": [
        "0",
        "1"
    ],
    "model_name": "bge",
    "model_lang": [
        "en"
    ],
    "revision": null,
    "max_tokens": 512
}'''

_CLUSTER_AUTH_INFO = b'''{
    "auth": true
}'''


def _build_response(status_code: int, content: bytes) -> Response:
    """Return a ``requests.Response`` carrying the given status and raw body."""
    response = Response()
    response.status_code = status_code
    response._content = content
    return response


class MockXinferenceClass:
    """Stand-in implementations for the xinference REST client.

    The functions defined here are not called through an instance of this
    class. They are meant to be monkeypatched onto the class named in each
    ``self`` annotation, so ``self`` is an instance of that class at call time.
    """

    def get_chat_model(self: Client, model_uid: str) -> Union[
        RESTfulChatglmCppChatModelHandle,
        RESTfulGenerateModelHandle,
        RESTfulChatModelHandle,
        RESTfulEmbeddingModelHandle,
        RESTfulRerankModelHandle,
    ]:
        """Return a model handle for one of the symbolic model uids.

        :param model_uid: one of ``generate``, ``chat``, ``embedding`` or
            ``rerank``.
        :return: a handle of the matching type, bound to ``self.base_url``
            with empty auth headers.
        :raises RuntimeError: ``'404 Not Found'`` when ``self.base_url`` is not
            an http(s) URL or ``model_uid`` is not one of the names above.
        """
        if not _HTTP_URL_PATTERN.match(self.base_url):
            raise RuntimeError('404 Not Found')

        # Built per call so the handle classes are only looked up when used.
        handle_types = {
            'generate': RESTfulGenerateModelHandle,
            'chat': RESTfulChatModelHandle,
            'embedding': RESTfulEmbeddingModelHandle,
            'rerank': RESTfulRerankModelHandle,
        }
        handle_type = handle_types.get(model_uid)
        if handle_type is None:
            raise RuntimeError('404 Not Found')
        return handle_type(model_uid, base_url=self.base_url, auth_headers={})

    def get(self: Session, url: str, **kwargs) -> Response:
        """Answer GET requests for the model-info and cluster-auth endpoints.

        :param url: requested URL; extra keyword arguments are accepted and
            ignored.
        :return: a 200 response with a canned JSON body for
            ``v1/models/generate``, ``v1/models/chat``, ``v1/models/embedding``
            and ``v1/cluster/auth``; a 404 response with body ``{}`` for
            everything else.
        """
        if 'v1/models/' in url:
            # The model uid is the last path segment.
            model_uid = url.split('/')[-1]
            if not _MODEL_UID_PATTERN.match(model_uid) and model_uid not in _KNOWN_MODEL_UIDS:
                return _build_response(404, b'{}')

            if not _HTTP_URL_PATTERN.match(url):
                return _build_response(404, b'{}')

            if model_uid in ('generate', 'chat'):
                return _build_response(200, _LLM_MODEL_INFO)
            if model_uid == 'embedding':
                return _build_response(200, _EMBEDDING_MODEL_INFO)
        elif 'v1/cluster/auth' in url:
            return _build_response(200, _CLUSTER_AUTH_INFO)

        # No canned payload for this request (e.g. the ``rerank`` uid, a
        # UUID-style uid, or an unrelated endpoint). Always hand back a real
        # response so callers can inspect ``status_code`` instead of
        # receiving ``None``.
        return _build_response(404, b'{}')

    def _check_cluster_authenticated(self):
        """Mark the cluster as authenticated without contacting the server."""
        self._cluster_authed = True

    def rerank(
        self: RESTfulRerankModelHandle,
        documents: list[str],
        query: str,
        top_n: Union[int, None],
        return_documents: bool,
    ) -> dict:
        """Return a fixed-score ranking of the first ``top_n`` documents.

        :param documents: candidate documents, returned in their input order.
        :param query: accepted for signature compatibility; not used.
        :param top_n: number of documents to return; ``None`` means 1.
        :param return_documents: accepted for signature compatibility; the
            document text is always included.
        :return: ``{'results': [...]}`` where each entry holds ``index``,
            ``document`` and a constant ``relevance_score`` of 0.9.
        :raises RuntimeError: ``'404 Not Found'`` when ``self._model_uid`` is
            neither UUID-shaped nor ``'rerank'``, or ``self._base_url`` is not
            an http(s) URL.
        """
        if not _MODEL_UID_PATTERN.match(self._model_uid) and self._model_uid != 'rerank':
            raise RuntimeError('404 Not Found')

        if not _HTTP_URL_PATTERN.match(self._base_url):
            raise RuntimeError('404 Not Found')

        if top_n is None:
            top_n = 1

        return {
            'results': [
                {
                    'index': index,
                    'document': document,
                    'relevance_score': 0.9,
                }
                for index, document in enumerate(documents[:top_n])
            ]
        }

    def create_embedding(
        self: RESTfulEmbeddingModelHandle,
        input: Union[str, list[str]],
        **kwargs
    ) -> Embedding:
        """Return one constant 768-dimensional embedding per input text.

        :param input: a single text or a list of texts; extra keyword
            arguments are accepted and ignored.
        :return: an ``Embedding`` whose ``data`` has one entry per input and
            whose token usage counts equal the number of inputs.
        :raises RuntimeError: ``'404 Not Found'`` when ``self._model_uid`` is
            neither UUID-shaped nor ``'embedding'``.
        """
        if not _MODEL_UID_PATTERN.match(self._model_uid) and self._model_uid != 'embedding':
            raise RuntimeError('404 Not Found')

        texts = [input] if isinstance(input, str) else input
        text_count = len(texts)

        return Embedding(
            object="list",
            model=self._model_uid,
            data=[
                EmbeddingData(
                    index=index,
                    object="embedding",
                    embedding=[1919.810] * 768,
                )
                for index in range(text_count)
            ],
            usage=EmbeddingUsage(
                prompt_tokens=text_count,
                total_tokens=text_count,
            ),
        )
# Mocking is opt-in: the patches are installed only when the MOCK_SWITCH
# environment variable is set to "true" (compared case-insensitively).
MOCK = os.environ.get('MOCK_SWITCH', 'false').lower() == 'true'
@pytest.fixture
def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
    """Install the xinference mocks for one test when ``MOCK`` is enabled.

    When ``MOCK`` is false the fixture only yields, leaving the real client
    classes untouched. When it is true, each listed attribute is replaced
    with its ``MockXinferenceClass`` counterpart before the test runs and the
    patches are undone afterwards.
    """
    # (class to patch, attribute name, replacement), applied in this order.
    replacements = (
        (Client, 'get_model', MockXinferenceClass.get_chat_model),
        (Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated),
        (Session, 'get', MockXinferenceClass.get),
        (RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding),
        (RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank),
    )

    if MOCK:
        for target, attribute, replacement in replacements:
            monkeypatch.setattr(target, attribute, replacement)

    yield

    if MOCK:
        monkeypatch.undo()
|