spark_llm.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import base64
  2. import datetime
  3. import hashlib
  4. import hmac
  5. import json
  6. import queue
  7. from typing import Optional
  8. from urllib.parse import urlparse
  9. import ssl
  10. from datetime import datetime
  11. from time import mktime
  12. from urllib.parse import urlencode
  13. from wsgiref.handlers import format_date_time
  14. import websocket
  15. class SparkLLMClient:
  16. def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
  17. domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
  18. model_api_configs = {
  19. 'spark': {
  20. 'version': 'v1.1',
  21. 'chat_domain': 'general'
  22. },
  23. 'spark-v2': {
  24. 'version': 'v2.1',
  25. 'chat_domain': 'generalv2'
  26. },
  27. 'spark-v3': {
  28. 'version': 'v3.1',
  29. 'chat_domain': 'generalv3'
  30. }
  31. }
  32. api_version = model_api_configs[model_name]['version']
  33. self.chat_domain = model_api_configs[model_name]['chat_domain']
  34. self.api_base = f"wss://{domain}/{api_version}/chat"
  35. self.app_id = app_id
  36. self.ws_url = self.create_url(
  37. urlparse(self.api_base).netloc,
  38. urlparse(self.api_base).path,
  39. self.api_base,
  40. api_key,
  41. api_secret
  42. )
  43. self.queue = queue.Queue()
  44. self.blocking_message = ''
  45. def create_url(self, host: str, path: str, api_base: str, api_key: str, api_secret: str) -> str:
  46. # generate timestamp by RFC1123
  47. now = datetime.now()
  48. date = format_date_time(mktime(now.timetuple()))
  49. signature_origin = "host: " + host + "\n"
  50. signature_origin += "date: " + date + "\n"
  51. signature_origin += "GET " + path + " HTTP/1.1"
  52. # encrypt using hmac-sha256
  53. signature_sha = hmac.new(api_secret.encode('utf-8'), signature_origin.encode('utf-8'),
  54. digestmod=hashlib.sha256).digest()
  55. signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')
  56. authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
  57. authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')
  58. v = {
  59. "authorization": authorization,
  60. "date": date,
  61. "host": host
  62. }
  63. # generate url
  64. url = api_base + '?' + urlencode(v)
  65. return url
  66. def run(self, messages: list, user_id: str,
  67. model_kwargs: Optional[dict] = None, streaming: bool = False):
  68. websocket.enableTrace(False)
  69. ws = websocket.WebSocketApp(
  70. self.ws_url,
  71. on_message=self.on_message,
  72. on_error=self.on_error,
  73. on_close=self.on_close,
  74. on_open=self.on_open
  75. )
  76. ws.messages = messages
  77. ws.user_id = user_id
  78. ws.model_kwargs = model_kwargs
  79. ws.streaming = streaming
  80. ws.run_forever(sslopt={"cert_reqs": ssl.CERT_NONE})
  81. def on_error(self, ws, error):
  82. self.queue.put({
  83. 'status_code': error.status_code,
  84. 'error': error.resp_body.decode('utf-8')
  85. })
  86. ws.close()
  87. def on_close(self, ws, close_status_code, close_reason):
  88. self.queue.put({'done': True})
  89. def on_open(self, ws):
  90. self.blocking_message = ''
  91. data = json.dumps(self.gen_params(
  92. messages=ws.messages,
  93. user_id=ws.user_id,
  94. model_kwargs=ws.model_kwargs
  95. ))
  96. ws.send(data)
  97. def on_message(self, ws, message):
  98. data = json.loads(message)
  99. code = data['header']['code']
  100. if code != 0:
  101. self.queue.put({
  102. 'status_code': 400,
  103. 'error': f"Code: {code}, Error: {data['header']['message']}"
  104. })
  105. ws.close()
  106. else:
  107. choices = data["payload"]["choices"]
  108. status = choices["status"]
  109. content = choices["text"][0]["content"]
  110. if ws.streaming:
  111. self.queue.put({'data': content})
  112. else:
  113. self.blocking_message += content
  114. if status == 2:
  115. if not ws.streaming:
  116. self.queue.put({'data': self.blocking_message})
  117. ws.close()
  118. def gen_params(self, messages: list, user_id: str,
  119. model_kwargs: Optional[dict] = None) -> dict:
  120. data = {
  121. "header": {
  122. "app_id": self.app_id,
  123. "uid": user_id
  124. },
  125. "parameter": {
  126. "chat": {
  127. "domain": self.chat_domain
  128. }
  129. },
  130. "payload": {
  131. "message": {
  132. "text": messages
  133. }
  134. }
  135. }
  136. if model_kwargs:
  137. data['parameter']['chat'].update(model_kwargs)
  138. return data
  139. def subscribe(self):
  140. while True:
  141. content = self.queue.get()
  142. if 'error' in content:
  143. if content['status_code'] == 401:
  144. raise SparkError('[Spark] The credentials you provided are incorrect. '
  145. 'Please double-check and fill them in again.')
  146. elif content['status_code'] == 403:
  147. raise SparkError("[Spark] Sorry, the credentials you provided are access denied. "
  148. "Please try again after obtaining the necessary permissions.")
  149. else:
  150. raise SparkError(f"[Spark] code: {content['status_code']}, error: {content['error']}")
  151. if 'data' not in content:
  152. break
  153. yield content
  154. class SparkError(Exception):
  155. pass