openai_chat.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234
  1. import re
  2. from json import dumps, loads
  3. from time import sleep, time
  4. # import monkeypatch
  5. from typing import Any, Generator, List, Literal, Optional, Union
  6. import openai.types.chat.completion_create_params as completion_create_params
  7. from core.model_runtime.errors.invoke import InvokeAuthorizationError
  8. from openai import AzureOpenAI, OpenAI
  9. from openai._types import NOT_GIVEN, NotGiven
  10. from openai.resources.chat.completions import Completions
  11. from openai.types import Completion as CompletionMessage
  12. from openai.types.chat import (ChatCompletion, ChatCompletionChunk, ChatCompletionMessageParam,
  13. ChatCompletionMessageToolCall, ChatCompletionToolChoiceOptionParam,
  14. ChatCompletionToolParam)
  15. from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion
  16. from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice
  17. from openai.types.chat.chat_completion_chunk import (Choice, ChoiceDelta, ChoiceDeltaFunctionCall, ChoiceDeltaToolCall,
  18. ChoiceDeltaToolCallFunction)
  19. from openai.types.chat.chat_completion_message import ChatCompletionMessage, FunctionCall
  20. from openai.types.chat.chat_completion_message_tool_call import Function
  21. from openai.types.completion_usage import CompletionUsage
  22. class MockChatClass(object):
  23. @staticmethod
  24. def generate_function_call(
  25. functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
  26. ) -> Optional[FunctionCall]:
  27. if not functions or len(functions) == 0:
  28. return None
  29. function: completion_create_params.Function = functions[0]
  30. function_name = function['name']
  31. function_description = function['description']
  32. function_parameters = function['parameters']
  33. function_parameters_type = function_parameters['type']
  34. if function_parameters_type != 'object':
  35. return None
  36. function_parameters_properties = function_parameters['properties']
  37. function_parameters_required = function_parameters['required']
  38. parameters = {}
  39. for parameter_name, parameter in function_parameters_properties.items():
  40. if parameter_name not in function_parameters_required:
  41. continue
  42. parameter_type = parameter['type']
  43. if parameter_type == 'string':
  44. if 'enum' in parameter:
  45. if len(parameter['enum']) == 0:
  46. continue
  47. parameters[parameter_name] = parameter['enum'][0]
  48. else:
  49. parameters[parameter_name] = 'kawaii'
  50. elif parameter_type == 'integer':
  51. parameters[parameter_name] = 114514
  52. elif parameter_type == 'number':
  53. parameters[parameter_name] = 1919810.0
  54. elif parameter_type == 'boolean':
  55. parameters[parameter_name] = True
  56. return FunctionCall(name=function_name, arguments=dumps(parameters))
  57. @staticmethod
  58. def generate_tool_calls(
  59. tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
  60. ) -> Optional[List[ChatCompletionMessageToolCall]]:
  61. list_tool_calls = []
  62. if not tools or len(tools) == 0:
  63. return None
  64. tool: ChatCompletionToolParam = tools[0]
  65. if tools['type'] != 'function':
  66. return None
  67. function = tool['function']
  68. function_call = MockChatClass.generate_function_call(functions=[function])
  69. if function_call is None:
  70. return None
  71. list_tool_calls.append(ChatCompletionMessageToolCall(
  72. id='sakurajima-mai',
  73. function=Function(
  74. name=function_call.name,
  75. arguments=function_call.arguments,
  76. ),
  77. type='function'
  78. ))
  79. return list_tool_calls
  80. @staticmethod
  81. def mocked_openai_chat_create_sync(
  82. model: str,
  83. functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
  84. tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
  85. ) -> CompletionMessage:
  86. tool_calls = []
  87. function_call = MockChatClass.generate_function_call(functions=functions)
  88. if not function_call:
  89. tool_calls = MockChatClass.generate_tool_calls(tools=tools)
  90. sleep(1)
  91. return _ChatCompletion(
  92. id='cmpl-3QJQa5jXJ5Z5X',
  93. choices=[
  94. _ChatCompletionChoice(
  95. finish_reason='content_filter',
  96. index=0,
  97. message=ChatCompletionMessage(
  98. content='elaina',
  99. role='assistant',
  100. function_call=function_call,
  101. tool_calls=tool_calls
  102. )
  103. )
  104. ],
  105. created=int(time()),
  106. model=model,
  107. object='chat.completion',
  108. system_fingerprint='',
  109. usage=CompletionUsage(
  110. prompt_tokens=2,
  111. completion_tokens=1,
  112. total_tokens=3,
  113. )
  114. )
  115. @staticmethod
  116. def mocked_openai_chat_create_stream(
  117. model: str,
  118. functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
  119. tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
  120. ) -> Generator[ChatCompletionChunk, None, None]:
  121. tool_calls = []
  122. function_call = MockChatClass.generate_function_call(functions=functions)
  123. if not function_call:
  124. tool_calls = MockChatClass.generate_tool_calls(tools=tools)
  125. full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
  126. for i in range(0, len(full_text) + 1):
  127. sleep(0.1)
  128. if i == len(full_text):
  129. yield ChatCompletionChunk(
  130. id='cmpl-3QJQa5jXJ5Z5X',
  131. choices=[
  132. Choice(
  133. delta=ChoiceDelta(
  134. content='',
  135. function_call=ChoiceDeltaFunctionCall(
  136. name=function_call.name,
  137. arguments=function_call.arguments,
  138. ) if function_call else None,
  139. role='assistant',
  140. tool_calls=[
  141. ChoiceDeltaToolCall(
  142. index=0,
  143. id='misaka-mikoto',
  144. function=ChoiceDeltaToolCallFunction(
  145. name=tool_calls[0].function.name,
  146. arguments=tool_calls[0].function.arguments,
  147. ),
  148. type='function'
  149. )
  150. ] if tool_calls and len(tool_calls) > 0 else None
  151. ),
  152. finish_reason='function_call',
  153. index=0,
  154. )
  155. ],
  156. created=int(time()),
  157. model=model,
  158. object='chat.completion.chunk',
  159. system_fingerprint='',
  160. usage=CompletionUsage(
  161. prompt_tokens=2,
  162. completion_tokens=17,
  163. total_tokens=19,
  164. ),
  165. )
  166. else:
  167. yield ChatCompletionChunk(
  168. id='cmpl-3QJQa5jXJ5Z5X',
  169. choices=[
  170. Choice(
  171. delta=ChoiceDelta(
  172. content=full_text[i],
  173. role='assistant',
  174. ),
  175. finish_reason='content_filter',
  176. index=0,
  177. )
  178. ],
  179. created=int(time()),
  180. model=model,
  181. object='chat.completion.chunk',
  182. system_fingerprint='',
  183. )
  184. def chat_create(self: Completions, *,
  185. messages: List[ChatCompletionMessageParam],
  186. model: Union[str,Literal[
  187. "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
  188. "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
  189. "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
  190. "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613"],
  191. ],
  192. functions: List[completion_create_params.Function] | NotGiven = NOT_GIVEN,
  193. response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
  194. stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
  195. tools: List[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
  196. **kwargs: Any,
  197. ):
  198. openai_models = [
  199. "gpt-4-1106-preview", "gpt-4-vision-preview", "gpt-4", "gpt-4-0314", "gpt-4-0613",
  200. "gpt-4-32k", "gpt-4-32k-0314", "gpt-4-32k-0613",
  201. "gpt-3.5-turbo-1106", "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0301",
  202. "gpt-3.5-turbo-0613", "gpt-3.5-turbo-16k-0613",
  203. ]
  204. azure_openai_models = [
  205. "gpt35", "gpt-4v", "gpt-35-turbo"
  206. ]
  207. if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
  208. raise InvokeAuthorizationError('Invalid base url')
  209. if model in openai_models + azure_openai_models:
  210. if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
  211. # sometime, provider use OpenAI compatible API will not have api key or have different api key format
  212. # so we only check if model is in openai_models
  213. raise InvokeAuthorizationError('Invalid api key')
  214. if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
  215. raise InvokeAuthorizationError('Invalid api key')
  216. if stream:
  217. return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
  218. return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)