spark_llm.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. import base64
  2. import datetime
  3. import hashlib
  4. import hmac
  5. import json
  6. import queue
  7. from typing import Optional
  8. from urllib.parse import urlparse
  9. import ssl
  10. from datetime import datetime
  11. from time import mktime
  12. from urllib.parse import urlencode
  13. from wsgiref.handlers import format_date_time
  14. import websocket
  15. class SparkLLMClient:
  16. def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
  17. domain = 'spark-api.xf-yun.com'
  18. endpoint = 'chat'
  19. if api_domain:
  20. domain = api_domain
  21. if model_name == 'spark-v3':
  22. endpoint = 'multimodal'
  23. model_api_configs = {
  24. 'spark': {
  25. 'version': 'v1.1',
  26. 'chat_domain': 'general'
  27. },
  28. 'spark-v2': {
  29. 'version': 'v2.1',
  30. 'chat_domain': 'generalv2'
  31. },
  32. 'spark-v3': {
  33. 'version': 'v3.1',
  34. 'chat_domain': 'generalv3'
  35. }
  36. }
  37. api_version = model_api_configs[model_name]['version']
  38. self.chat_domain = model_api_configs[model_name]['chat_domain']
  39. self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
  40. self.app_id = app_id
  41. self.ws_url = self.create_url(
  42. urlparse(self.api_base).netloc,
  43. urlparse(self.api_base).path,
  44. self.api_base,
  45. api_key,
  46. api_secret
  47. )
  48. self.queue = queue.Queue()
  49. self.blocking_message = ''
  50. def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
  51. # generate timestamp by RFC1123
  52. now = datetime.now()
  53. date = format_date_time(mktime(now.timetuple()))
  54. signature_origin = "host: " + host + "\n"
  55. signature_origin += "date: " + date + "\n"
  56. signature_origin += "GET " + path + " HTTP/1.1"
  57. # encrypt using hmac-sha256
  58. signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
  59. digestmod=hashlib.sha256).digest()
  60. signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
  61. authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  62. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  63. v = {
  64. "authorization": authorization,
  65. "date": date,
  66. "host": host
  67. }
  68. # generate url
  69. url = api_base + '?' + urlencode(v)
  70. return url
  71. def run(self, messages: list, user_id: str,
  72. model_kwargs: Optional[dict] = None, streaming: bool = False):
  73. websocket.enableTrace(False)
  74. ws = websocket.WebSocketApp(
  75. self.ws_url,
  76. on_message=self.on_message,
  77. on_error=self.on_error,
  78. on_close=self.on_close,
  79. on_open=self.on_open
  80. )
  81. ws.messages = messages
  82. ws.user_id = user_id
  83. ws.model_kwargs = model_kwargs
  84. ws.streaming = streaming
  85. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  86. def on_error(self, ws, error):
  87. self.queue.put({
  88. 'status_code': error.status_code,
  89. 'error': error.resp_body.decode('utf-8')
  90. })
  91. ws.close()
  92. def on_close(self, ws, close_status_code, close_reason):
  93. self.queue.put({'done': True})
  94. def on_open(self, ws):
  95. self.blocking_message = ''
  96. data = json.dumps(self.gen_params(
  97. messages=ws.messages,
  98. user_id=ws.user_id,
  99. model_kwargs=ws.model_kwargs
  100. ))
  101. ws.send(data)
  102. def on_message(self, ws, message):
  103. data = json.loads(message)
  104. code = data['header']['code']
  105. if code != 0:
  106. self.queue.put({
  107. 'status_code': 400,
  108. 'error': f"Code: {code}, Error: {data['header']['message']}"
  109. })
  110. ws.close()
  111. else:
  112. choices = data["payload"]["choices"]
  113. status = choices["status"]
  114. content = choices["text"][0]["content"]
  115. if ws.streaming:
  116. self.queue.put({'data': content})
  117. else:
  118. self.blocking_message += content
  119. if status == 2:
  120. if not ws.streaming:
  121. self.queue.put({'data': self.blocking_message})
  122. ws.close()
  123. def gen_params(self, messages: list, user_id: str,
  124. model_kwargs: Optional[dict] = None) -> dict:
  125. data = {
  126. "header": {
  127. "app_id": self.app_id,
  128. "uid": user_id
  129. },
  130. "parameter": {
  131. "chat": {
  132. "domain": self.chat_domain
  133. }
  134. },
  135. "payload": {
  136. "message": {
  137. "text": messages
  138. }
  139. }
  140. }
  141. if model_kwargs:
  142. data['parameter']['chat'].update(model_kwargs)
  143. return data
  144. def subscribe(self):
  145. while True:
  146. content = self.queue.get()
  147. if 'error' in content:
  148. if content['status_code'] == 401:
  149. raise SparkError('[Spark] The credentials you provided are incorrect. '
  150. 'Please double-check and fill them in again.')
  151. elif content['status_code'] == 403:
  152. raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
  153. "Please try again after obtaining the necessary permissions.")
  154. else:
  155. raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
  156. if 'data' not in content:
  157. break
  158. yield content
  159. class SparkError(Exception):
  160. pass