| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168 | 
							- import base64
 
- import datetime
 
- import hashlib
 
- import hmac
 
- import json
 
- import queue
 
- from typing import Optional
 
- from urllib.parse import urlparse
 
- import ssl
 
- from datetime import datetime
 
- from time import mktime
 
- from urllib.parse import urlencode
 
- from wsgiref.handlers import format_date_time
 
- import websocket
 
class SparkLLMClient:
    """Minimal WebSocket client for the iFlytek Spark chat API.

    ``run`` performs one chat request over a WebSocket connection and pushes
    results into ``self.queue`` instead of returning them; ``subscribe``
    drains that queue as a generator. ``run`` blocks until the connection
    closes, so the two are presumably driven from different threads —
    NOTE(review): confirm against callers.

    Queue items are dicts of one of these shapes:

    * ``{'data': str}`` — a streamed chunk, or the full reply in blocking mode;
    * ``{'status_code': int, 'error': str}`` — a failure;
    * ``{'done': True}`` — the connection was closed.
    """

    def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str,
                 api_domain: Optional[str] = None, verify_ssl: bool = True):
        """Configure the endpoint and pre-compute a signed connection URL.

        Args:
            model_name: ``'spark-v2'`` selects the ``v2.1`` endpoint and the
                ``generalv2`` chat domain; any other value selects ``v1.1`` /
                ``general``.
            app_id: application id sent in every request header.
            api_key: key embedded in the signed ``authorization`` parameter.
            api_secret: secret used to HMAC-SHA256-sign the connection URL.
            api_domain: host override; defaults to ``spark-api.xf-yun.com``.
            verify_ssl: verify the server's TLS certificate (default). Pass
                ``False`` only to restore the old, insecure behaviour.
        """
        domain = api_domain or 'spark-api.xf-yun.com'
        is_v2 = model_name == 'spark-v2'
        api_version = 'v2.1' if is_v2 else 'v1.1'
        self.chat_domain = 'generalv2' if is_v2 else 'general'
        self.api_base = f"wss://{domain}/{api_version}/chat"
        self.app_id = app_id
        self.verify_ssl = verify_ssl
        # Credentials are kept so run() can re-sign the URL for every connection.
        self._api_key = api_key
        self._api_secret = api_secret
        # Kept for backward compatibility; run() refreshes it before connecting.
        self.ws_url = self._build_ws_url()
        self.queue = queue.Queue()
        self.blocking_message = ''

    def _build_ws_url(self) -> str:
        """Return a freshly signed WebSocket URL for ``self.api_base``."""
        parsed = urlparse(self.api_base)
        return self.create_url(parsed.netloc, parsed.path, self.api_base,
                               self._api_key, self._api_secret)

    def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
        """Build the authenticated WebSocket URL.

        The signed text is ``host``, ``date`` and the request line
        ``GET <path> HTTP/1.1``, joined by newlines and HMAC-SHA256'd with
        *api_secret*. The base64 signature is wrapped in an authorization
        string, base64-encoded again, and appended to *api_base* as a query
        string together with ``date`` and ``host``.

        Returns:
            ``api_base + '?' + urlencode({authorization, date, host})``.
        """
        # RFC 1123 date (e.g. "Mon, 01 Jan 2024 00:00:00 GMT"); part of the signed text.
        date = format_date_time(mktime(datetime.now().timetuple()))
        signature_origin = (
            f"host: {host}\n"
            f"date: {date}\n"
            f"GET {path} HTTP/1.1"
        )
        signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode('utf-8')
        authorization_origin = (
            f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature_sha_base64}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode('utf-8')
        query = urlencode({"authorization": authorization, "date": date, "host": host})
        return f"{api_base}?{query}"

    def run(self, messages: list, user_id: str,
            model_kwargs: Optional[dict] = None, streaming: bool = False):
        """Send one chat request and block until the connection closes.

        Results are not returned; they are pushed onto ``self.queue`` by the
        callbacks below and can be consumed with ``subscribe``.

        Args:
            messages: sent verbatim as ``payload.message.text``.
            user_id: sent as ``header.uid``.
            model_kwargs: merged into ``parameter.chat`` (see ``gen_params``).
            streaming: queue each chunk as it arrives instead of one final reply.
        """
        # Re-sign on every call: the URL embeds the signing time, so a URL
        # built once in __init__ goes stale for long-lived clients.
        self.ws_url = self._build_ws_url()

        websocket.enableTrace(False)
        ws = websocket.WebSocketApp(
            self.ws_url,
            on_message=self.on_message,
            on_error=self.on_error,
            on_close=self.on_close,
            on_open=self.on_open,
        )
        # Per-request state, read back by the callbacks through the ws object.
        ws.messages = messages
        ws.user_id = user_id
        ws.model_kwargs = model_kwargs
        ws.streaming = streaming
        # sslopt=None keeps the library's default certificate verification.
        sslopt = None if self.verify_ssl else {"cert_reqs": ssl.CERT_NONE}
        ws.run_forever(sslopt=sslopt)

    def on_error(self, ws, error):
        """Queue an error item for ``subscribe`` and close the connection.

        HTTP bad-status errors carry ``status_code`` and ``resp_body``.
        websocket-client may also pass other exceptions (connection failures,
        or exceptions raised inside the other callbacks) that lack those
        attributes. Those are reported with status 500 and their ``str()``
        instead of crashing this callback.
        """
        status_code = getattr(error, 'status_code', None)
        if status_code is None:
            status_code = 500
        body = getattr(error, 'resp_body', None)
        if isinstance(body, bytes):
            body = body.decode('utf-8', errors='replace')
        message = body or str(error) or type(error).__name__
        self.queue.put({
            'status_code': status_code,
            'error': message
        })
        ws.close()

    def on_close(self, ws, close_status_code, close_reason):
        """Queue the ``{'done': True}`` sentinel that ends ``subscribe``."""
        self.queue.put({'done': True})

    def on_open(self, ws):
        """Reset the blocking-mode buffer and send the request payload."""
        self.blocking_message = ''
        data = json.dumps(self.gen_params(
            messages=ws.messages,
            user_id=ws.user_id,
            model_kwargs=ws.model_kwargs
        ))
        ws.send(data)

    def on_message(self, ws, message):
        """Handle one JSON frame from the server.

        A non-zero ``header.code`` is queued as a 400 error and closes the
        connection. Otherwise the text content is queued immediately
        (streaming) or appended to ``blocking_message``. When
        ``choices.status == 2`` the accumulated reply is queued (blocking
        mode) and the connection is closed.
        """
        data = json.loads(message)
        header = data['header']
        code = header['code']
        if code != 0:
            self.queue.put({
                'status_code': 400,
                'error': f"Code: {code}, Error: {header.get('message')}"
            })
            ws.close()
            return

        choices = data["payload"]["choices"]
        status = choices["status"]
        # Tolerate frames with an empty/missing text list or content field.
        texts = choices.get("text") or []
        content = texts[0].get("content", "") if texts else ""
        if ws.streaming:
            self.queue.put({'data': content})
        else:
            self.blocking_message += content

        # This client treats status 2 as the final frame of the reply.
        if status == 2:
            if not ws.streaming:
                self.queue.put({'data': self.blocking_message})
            ws.close()

    def gen_params(self, messages: list, user_id: str,
                   model_kwargs: Optional[dict] = None) -> dict:
        """Build the request body sent on connection open.

        *model_kwargs*, if non-empty, is merged into ``parameter.chat`` and
        can therefore override ``domain``.
        """
        data = {
            "header": {
                "app_id": self.app_id,
                "uid": user_id
            },
            "parameter": {
                "chat": {
                    "domain": self.chat_domain
                }
            },
            "payload": {
                "message": {
                    "text": messages
                }
            }
        }
        if model_kwargs:
            data['parameter']['chat'].update(model_kwargs)
        return data

    def subscribe(self):
        """Yield queued ``{'data': ...}`` items until a non-data item arrives.

        Blocks on ``self.queue``. Stops at the ``{'done': True}`` sentinel.

        Raises:
            SparkError: on an error item. Status 401 and 403 get dedicated
                credential/permission messages.
        """
        while True:
            content = self.queue.get()
            if 'error' in content:
                if content['status_code'] == 401:
                    raise SparkError('[Spark] The credentials you provided are incorrect. '
                                     'Please double-check and fill them in again.')
                elif content['status_code'] == 403:
                    raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
                                     "Please try again after obtaining the necessary permissions.")
                else:
                    raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
            if 'data' not in content:
                break
            yield content
 
class SparkError(Exception):
    """Error raised by ``SparkLLMClient.subscribe`` when the Spark API reports a failure."""
 
 
  |