xinference.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import os
  2. import re
  3. from typing import Union
  4. import pytest
  5. from _pytest.monkeypatch import MonkeyPatch
  6. from requests import Response
  7. from requests.exceptions import ConnectionError
  8. from requests.sessions import Session
  9. from xinference_client.client.restful.restful_client import (
  10. Client,
  11. RESTfulChatglmCppChatModelHandle,
  12. RESTfulChatModelHandle,
  13. RESTfulEmbeddingModelHandle,
  14. RESTfulGenerateModelHandle,
  15. RESTfulRerankModelHandle,
  16. )
  17. from xinference_client.types import Embedding, EmbeddingData, EmbeddingUsage
  18. class MockXinferenceClass:
  19. def get_chat_model(self: Client, model_uid: str) -> Union[RESTfulChatglmCppChatModelHandle, RESTfulGenerateModelHandle, RESTfulChatModelHandle]:
  20. if not re.match(r'https?:\/\/[^\s\/$.?#].[^\s]*$', self.base_url):
  21. raise RuntimeError('404 Not Found')
  22. if 'generate' == model_uid:
  23. return RESTfulGenerateModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  24. if 'chat' == model_uid:
  25. return RESTfulChatModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  26. if 'embedding' == model_uid:
  27. return RESTfulEmbeddingModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  28. if 'rerank' == model_uid:
  29. return RESTfulRerankModelHandle(model_uid, base_url=self.base_url, auth_headers={})
  30. raise RuntimeError('404 Not Found')
  31. def get(self: Session, url: str, **kwargs):
  32. response = Response()
  33. if 'v1/models/' in url:
  34. # get model uid
  35. model_uid = url.split('/')[-1] or ''
  36. if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
  37. model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
  38. response.status_code = 404
  39. response._content = b'{}'
  40. return response
  41. # check if url is valid
  42. if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
  43. response.status_code = 404
  44. response._content = b'{}'
  45. return response
  46. if model_uid in ['generate', 'chat']:
  47. response.status_code = 200
  48. response._content = b'''{
  49. "model_type": "LLM",
  50. "address": "127.0.0.1:43877",
  51. "accelerators": [
  52. "0",
  53. "1"
  54. ],
  55. "model_name": "chatglm3-6b",
  56. "model_lang": [
  57. "en"
  58. ],
  59. "model_ability": [
  60. "generate",
  61. "chat"
  62. ],
  63. "model_description": "latest chatglm3",
  64. "model_format": "pytorch",
  65. "model_size_in_billions": 7,
  66. "quantization": "none",
  67. "model_hub": "huggingface",
  68. "revision": null,
  69. "context_length": 2048,
  70. "replica": 1
  71. }'''
  72. return response
  73. elif model_uid == 'embedding':
  74. response.status_code = 200
  75. response._content = b'''{
  76. "model_type": "embedding",
  77. "address": "127.0.0.1:43877",
  78. "accelerators": [
  79. "0",
  80. "1"
  81. ],
  82. "model_name": "bge",
  83. "model_lang": [
  84. "en"
  85. ],
  86. "revision": null,
  87. "max_tokens": 512
  88. }'''
  89. return response
  90. elif 'v1/cluster/auth' in url:
  91. response.status_code = 200
  92. response._content = b'''{
  93. "auth": true
  94. }'''
  95. return response
  96. def _check_cluster_authenticated(self):
  97. self._cluster_authed = True
  98. def rerank(self: RESTfulRerankModelHandle, documents: list[str], query: str, top_n: int, return_documents: bool) -> dict:
  99. # check if self._model_uid is a valid uuid
  100. if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
  101. self._model_uid != 'rerank':
  102. raise RuntimeError('404 Not Found')
  103. if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._base_url):
  104. raise RuntimeError('404 Not Found')
  105. if top_n is None:
  106. top_n = 1
  107. return {
  108. 'results': [
  109. {
  110. 'index': i,
  111. 'document': doc,
  112. 'relevance_score': 0.9
  113. }
  114. for i, doc in enumerate(documents[:top_n])
  115. ]
  116. }
  117. def create_embedding(
  118. self: RESTfulGenerateModelHandle,
  119. input: Union[str, list[str]],
  120. **kwargs
  121. ) -> dict:
  122. # check if self._model_uid is a valid uuid
  123. if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', self._model_uid) and \
  124. self._model_uid != 'embedding':
  125. raise RuntimeError('404 Not Found')
  126. if isinstance(input, str):
  127. input = [input]
  128. ipt_len = len(input)
  129. embedding = Embedding(
  130. object="list",
  131. model=self._model_uid,
  132. data=[
  133. EmbeddingData(
  134. index=i,
  135. object="embedding",
  136. embedding=[1919.810 for _ in range(768)]
  137. )
  138. for i in range(ipt_len)
  139. ],
  140. usage=EmbeddingUsage(
  141. prompt_tokens=ipt_len,
  142. total_tokens=ipt_len
  143. )
  144. )
  145. return embedding
  146. MOCK = os.getenv('MOCK_SWITCH', 'false').lower() == 'true'
  147. @pytest.fixture
  148. def setup_xinference_mock(request, monkeypatch: MonkeyPatch):
  149. if MOCK:
  150. monkeypatch.setattr(Client, 'get_model', MockXinferenceClass.get_chat_model)
  151. monkeypatch.setattr(Client, '_check_cluster_authenticated', MockXinferenceClass._check_cluster_authenticated)
  152. monkeypatch.setattr(Session, 'get', MockXinferenceClass.get)
  153. monkeypatch.setattr(RESTfulEmbeddingModelHandle, 'create_embedding', MockXinferenceClass.create_embedding)
  154. monkeypatch.setattr(RESTfulRerankModelHandle, 'rerank', MockXinferenceClass.rerank)
  155. yield
  156. if MOCK:
  157. monkeypatch.undo()