Pārlūkot izejas kodu

fix: replace os.path.join with yarl (#2690)

Yeuoly 1 gadu atpakaļ
vecāks
revīzija
95733796f0

+ 5 - 3
api/core/model_runtime/model_providers/xinference/xinference_helper.py

@@ -1,10 +1,10 @@
-from os import path
 from threading import Lock
 from threading import Lock
 from time import time
 from time import time
 
 
 from requests.adapters import HTTPAdapter
 from requests.adapters import HTTPAdapter
 from requests.exceptions import ConnectionError, MissingSchema, Timeout
 from requests.exceptions import ConnectionError, MissingSchema, Timeout
 from requests.sessions import Session
 from requests.sessions import Session
+from yarl import URL
 
 
 
 
 class XinferenceModelExtraParameter:
 class XinferenceModelExtraParameter:
@@ -55,7 +55,10 @@ class XinferenceHelper:
             get xinference model extra parameter like model_format and model_handle_type
             get xinference model extra parameter like model_format and model_handle_type
         """
         """
 
 
-        url = path.join(server_url, 'v1/models', model_uid)
+        if not model_uid or not model_uid.strip() or not server_url or not server_url.strip():
+            raise RuntimeError('model_uid is empty')
+
+        url = str(URL(server_url) / 'v1' / 'models' / model_uid)
 
 
         # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
         # this method is surrounded by a lock, and default requests may hang forever, so we just set a Adapter with max_retries=3
         session = Session()
         session = Session()
@@ -66,7 +69,6 @@ class XinferenceHelper:
             response = session.get(url, timeout=10)
             response = session.get(url, timeout=10)
         except (MissingSchema, ConnectionError, Timeout) as e:
         except (MissingSchema, ConnectionError, Timeout) as e:
             raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
             raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
-
         if response.status_code != 200:
         if response.status_code != 200:
             raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
             raise RuntimeError(f'get xinference model extra parameter failed, status code: {response.status_code}, response: {response.text}')
         
         

+ 2 - 1
api/requirements.txt

@@ -68,4 +68,5 @@ pydub~=0.25.1
 gmpy2~=2.1.5
 gmpy2~=2.1.5
 numexpr~=2.9.0
 numexpr~=2.9.0
 duckduckgo-search==4.4.3
 duckduckgo-search==4.4.3
-arxiv==2.1.0
+arxiv==2.1.0
+yarl~=1.9.4

+ 41 - 39
api/tests/integration_tests/model_runtime/__mock/xinference.py

@@ -32,68 +32,70 @@ class MockXinferenceClass(object):
         response = Response()
         response = Response()
         if 'v1/models/' in url:
         if 'v1/models/' in url:
             # get model uid
             # get model uid
-            model_uid = url.split('/')[-1]
+            model_uid = url.split('/')[-1] or ''
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
             if not re.match(r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}', model_uid) and \
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 model_uid not in ['generate', 'chat', 'embedding', 'rerank']:
                 response.status_code = 404
                 response.status_code = 404
+                response._content = b'{}'
                 return response
                 return response
 
 
             # check if url is valid
             # check if url is valid
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
             if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', url):
                 response.status_code = 404
                 response.status_code = 404
+                response._content = b'{}'
                 return response
                 return response
             
             
             if model_uid in ['generate', 'chat']:
             if model_uid in ['generate', 'chat']:
                 response.status_code = 200
                 response.status_code = 200
                 response._content = b'''{
                 response._content = b'''{
-        "model_type": "LLM",
-        "address": "127.0.0.1:43877",
-        "accelerators": [
-            "0",
-            "1"
-        ],
-        "model_name": "chatglm3-6b",
-        "model_lang": [
-            "en"
-        ],
-        "model_ability": [
-            "generate",
-            "chat"
-        ],
-        "model_description": "latest chatglm3",
-        "model_format": "pytorch",
-        "model_size_in_billions": 7,
-        "quantization": "none",
-        "model_hub": "huggingface",
-        "revision": null,
-        "context_length": 2048,
-        "replica": 1
-    }'''
+                    "model_type": "LLM",
+                    "address": "127.0.0.1:43877",
+                    "accelerators": [
+                        "0",
+                        "1"
+                    ],
+                    "model_name": "chatglm3-6b",
+                    "model_lang": [
+                        "en"
+                    ],
+                    "model_ability": [
+                        "generate",
+                        "chat"
+                    ],
+                    "model_description": "latest chatglm3",
+                    "model_format": "pytorch",
+                    "model_size_in_billions": 7,
+                    "quantization": "none",
+                    "model_hub": "huggingface",
+                    "revision": null,
+                    "context_length": 2048,
+                    "replica": 1
+                }'''
                 return response
                 return response
             
             
             elif model_uid == 'embedding':
             elif model_uid == 'embedding':
                 response.status_code = 200
                 response.status_code = 200
                 response._content = b'''{
                 response._content = b'''{
-        "model_type": "embedding",
-        "address": "127.0.0.1:43877",
-        "accelerators": [
-            "0",
-            "1"
-        ],
-        "model_name": "bge",
-        "model_lang": [
-            "en"
-        ],
-        "revision": null,
-        "max_tokens": 512
-}'''
+                    "model_type": "embedding",
+                    "address": "127.0.0.1:43877",
+                    "accelerators": [
+                        "0",
+                        "1"
+                    ],
+                    "model_name": "bge",
+                    "model_lang": [
+                        "en"
+                    ],
+                    "revision": null,
+                    "max_tokens": 512
+                }'''
                 return response
                 return response
             
             
         elif 'v1/cluster/auth' in url:
         elif 'v1/cluster/auth' in url:
             response.status_code = 200
             response.status_code = 200
             response._content = b'''{
             response._content = b'''{
-    "auth": true
-}'''
+                "auth": true
+            }'''
             return response
             return response
         
         
     def _check_cluster_authenticated(self):
     def _check_cluster_authenticated(self):