# spark_llm.py (5.4 KB)
  1. import base64
  2. import datetime
  3. import hashlib
  4. import hmac
  5. import json
  6. import queue
  7. from typing import Optional
  8. from urllib.parse import urlparse
  9. import ssl
  10. from datetime import datetime
  11. from time import mktime
  12. from urllib.parse import urlencode
  13. from wsgiref.handlers import format_date_time
  14. import websocket
  15. class SparkLLMClient:
  16. def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
  17. domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
  18. api_version = 'v2.1' if model_name == 'spark-v2' else 'v1.1'
  19. self.chat_domain = 'generalv2' if model_name == 'spark-v2' else 'general'
  20. self.api_base = f"wss://{domain}/{api_version}/chat"
  21. self.app_id = app_id
  22. self.ws_url = self.create_url(
  23. urlparse(self.api_base).netloc,
  24. urlparse(self.api_base).path,
  25. self.api_base,
  26. api_key,
  27. api_secret
  28. )
  29. self.queue = queue.Queue()
  30. self.blocking_message = ''
  31. def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
  32. # generate timestamp by RFC1123
  33. now = datetime.now()
  34. date = format_date_time(mktime(now.timetuple()))
  35. signature_origin = "host: " + host + "\n"
  36. signature_origin += "date: " + date + "\n"
  37. signature_origin += "GET " + path + " HTTP/1.1"
  38. # encrypt using hmac-sha256
  39. signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
  40. digestmod=hashlib.sha256).digest()
  41. signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
  42. authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  43. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  44. v = {
  45. "authorization": authorization,
  46. "date": date,
  47. "host": host
  48. }
  49. # generate url
  50. url = api_base + '?' + urlencode(v)
  51. return url
  52. def run(self, messages: list, user_id: str,
  53. model_kwargs: Optional[dict] = None, streaming: bool = False):
  54. websocket.enableTrace(False)
  55. ws = websocket.WebSocketApp(
  56. self.ws_url,
  57. on_message=self.on_message,
  58. on_error=self.on_error,
  59. on_close=self.on_close,
  60. on_open=self.on_open
  61. )
  62. ws.messages = messages
  63. ws.user_id = user_id
  64. ws.model_kwargs = model_kwargs
  65. ws.streaming = streaming
  66. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  67. def on_error(self, ws, error):
  68. self.queue.put({
  69. 'status_code': error.status_code,
  70. 'error': error.resp_body.decode('utf-8')
  71. })
  72. ws.close()
  73. def on_close(self, ws, close_status_code, close_reason):
  74. self.queue.put({'done': True})
  75. def on_open(self, ws):
  76. self.blocking_message = ''
  77. data = json.dumps(self.gen_params(
  78. messages=ws.messages,
  79. user_id=ws.user_id,
  80. model_kwargs=ws.model_kwargs
  81. ))
  82. ws.send(data)
  83. def on_message(self, ws, message):
  84. data = json.loads(message)
  85. code = data['header']['code']
  86. if code != 0:
  87. self.queue.put({
  88. 'status_code': 400,
  89. 'error': f"Code: {code}, Error: {data['header']['message']}"
  90. })
  91. ws.close()
  92. else:
  93. choices = data["payload"]["choices"]
  94. status = choices["status"]
  95. content = choices["text"][0]["content"]
  96. if ws.streaming:
  97. self.queue.put({'data': content})
  98. else:
  99. self.blocking_message += content
  100. if status == 2:
  101. if not ws.streaming:
  102. self.queue.put({'data': self.blocking_message})
  103. ws.close()
  104. def gen_params(self, messages: list, user_id: str,
  105. model_kwargs: Optional[dict] = None) -> dict:
  106. data = {
  107. "header": {
  108. "app_id": self.app_id,
  109. "uid": user_id
  110. },
  111. "parameter": {
  112. "chat": {
  113. "domain": self.chat_domain
  114. }
  115. },
  116. "payload": {
  117. "message": {
  118. "text": messages
  119. }
  120. }
  121. }
  122. if model_kwargs:
  123. data['parameter']['chat'].update(model_kwargs)
  124. return data
  125. def subscribe(self):
  126. while True:
  127. content = self.queue.get()
  128. if 'error' in content:
  129. if content['status_code'] == 401:
  130. raise SparkError('[Spark] The credentials you provided are incorrect. '
  131. 'Please double-check and fill them in again.')
  132. elif content['status_code'] == 403:
  133. raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
  134. "Please try again after obtaining the necessary permissions.")
  135. else:
  136. raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
  137. if 'data' not in content:
  138. break
  139. yield content
  140. class SparkError(Exception):
  141. pass