openai_chat.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. import re
  2. from collections.abc import Generator
  3. from json import dumps
  4. from time import time
  5. # import monkeypatch
  6. from typing import Any, Literal, Optional, Union
  7. from openai import AzureOpenAI, OpenAI
  8. from openai._types import NOT_GIVEN, NotGiven
  9. from openai.resources.chat.completions import Completions
  10. from openai.types import Completion as CompletionMessage
  11. from openai.types.chat import (
  12. ChatCompletionChunk,
  13. ChatCompletionMessageParam,
  14. ChatCompletionMessageToolCall,
  15. ChatCompletionToolParam,
  16. completion_create_params,
  17. )
  18. from openai.types.chat.chat_completion import ChatCompletion as _ChatCompletion
  19. from openai.types.chat.chat_completion import Choice as _ChatCompletionChoice
  20. from openai.types.chat.chat_completion_chunk import (
  21. Choice,
  22. ChoiceDelta,
  23. ChoiceDeltaFunctionCall,
  24. ChoiceDeltaToolCall,
  25. ChoiceDeltaToolCallFunction,
  26. )
  27. from openai.types.chat.chat_completion_message import ChatCompletionMessage, FunctionCall
  28. from openai.types.chat.chat_completion_message_tool_call import Function
  29. from openai.types.completion_usage import CompletionUsage
  30. from core.model_runtime.errors.invoke import InvokeAuthorizationError
  31. class MockChatClass:
  32. @staticmethod
  33. def generate_function_call(
  34. functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
  35. ) -> Optional[FunctionCall]:
  36. if not functions or len(functions) == 0:
  37. return None
  38. function: completion_create_params.Function = functions[0]
  39. function_name = function["name"]
  40. function_description = function["description"]
  41. function_parameters = function["parameters"]
  42. function_parameters_type = function_parameters["type"]
  43. if function_parameters_type != "object":
  44. return None
  45. function_parameters_properties = function_parameters["properties"]
  46. function_parameters_required = function_parameters["required"]
  47. parameters = {}
  48. for parameter_name, parameter in function_parameters_properties.items():
  49. if parameter_name not in function_parameters_required:
  50. continue
  51. parameter_type = parameter["type"]
  52. if parameter_type == "string":
  53. if "enum" in parameter:
  54. if len(parameter["enum"]) == 0:
  55. continue
  56. parameters[parameter_name] = parameter["enum"][0]
  57. else:
  58. parameters[parameter_name] = "kawaii"
  59. elif parameter_type == "integer":
  60. parameters[parameter_name] = 114514
  61. elif parameter_type == "number":
  62. parameters[parameter_name] = 1919810.0
  63. elif parameter_type == "boolean":
  64. parameters[parameter_name] = True
  65. return FunctionCall(name=function_name, arguments=dumps(parameters))
  66. @staticmethod
  67. def generate_tool_calls(tools=NOT_GIVEN) -> Optional[list[ChatCompletionMessageToolCall]]:
  68. list_tool_calls = []
  69. if not tools or len(tools) == 0:
  70. return None
  71. tool = tools[0]
  72. if "type" in tools and tools["type"] != "function":
  73. return None
  74. function = tool["function"]
  75. function_call = MockChatClass.generate_function_call(functions=[function])
  76. if function_call is None:
  77. return None
  78. list_tool_calls.append(
  79. ChatCompletionMessageToolCall(
  80. id="sakurajima-mai",
  81. function=Function(
  82. name=function_call.name,
  83. arguments=function_call.arguments,
  84. ),
  85. type="function",
  86. )
  87. )
  88. return list_tool_calls
  89. @staticmethod
  90. def mocked_openai_chat_create_sync(
  91. model: str,
  92. functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
  93. tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
  94. ) -> CompletionMessage:
  95. tool_calls = []
  96. function_call = MockChatClass.generate_function_call(functions=functions)
  97. if not function_call:
  98. tool_calls = MockChatClass.generate_tool_calls(tools=tools)
  99. return _ChatCompletion(
  100. id="cmpl-3QJQa5jXJ5Z5X",
  101. choices=[
  102. _ChatCompletionChoice(
  103. finish_reason="content_filter",
  104. index=0,
  105. message=ChatCompletionMessage(
  106. content="elaina", role="assistant", function_call=function_call, tool_calls=tool_calls
  107. ),
  108. )
  109. ],
  110. created=int(time()),
  111. model=model,
  112. object="chat.completion",
  113. system_fingerprint="",
  114. usage=CompletionUsage(
  115. prompt_tokens=2,
  116. completion_tokens=1,
  117. total_tokens=3,
  118. ),
  119. )
  120. @staticmethod
  121. def mocked_openai_chat_create_stream(
  122. model: str,
  123. functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
  124. tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
  125. ) -> Generator[ChatCompletionChunk, None, None]:
  126. tool_calls = []
  127. function_call = MockChatClass.generate_function_call(functions=functions)
  128. if not function_call:
  129. tool_calls = MockChatClass.generate_tool_calls(tools=tools)
  130. full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
  131. for i in range(0, len(full_text) + 1):
  132. if i == len(full_text):
  133. yield ChatCompletionChunk(
  134. id="cmpl-3QJQa5jXJ5Z5X",
  135. choices=[
  136. Choice(
  137. delta=ChoiceDelta(
  138. content="",
  139. function_call=ChoiceDeltaFunctionCall(
  140. name=function_call.name,
  141. arguments=function_call.arguments,
  142. )
  143. if function_call
  144. else None,
  145. role="assistant",
  146. tool_calls=[
  147. ChoiceDeltaToolCall(
  148. index=0,
  149. id="misaka-mikoto",
  150. function=ChoiceDeltaToolCallFunction(
  151. name=tool_calls[0].function.name,
  152. arguments=tool_calls[0].function.arguments,
  153. ),
  154. type="function",
  155. )
  156. ]
  157. if tool_calls and len(tool_calls) > 0
  158. else None,
  159. ),
  160. finish_reason="function_call",
  161. index=0,
  162. )
  163. ],
  164. created=int(time()),
  165. model=model,
  166. object="chat.completion.chunk",
  167. system_fingerprint="",
  168. usage=CompletionUsage(
  169. prompt_tokens=2,
  170. completion_tokens=17,
  171. total_tokens=19,
  172. ),
  173. )
  174. else:
  175. yield ChatCompletionChunk(
  176. id="cmpl-3QJQa5jXJ5Z5X",
  177. choices=[
  178. Choice(
  179. delta=ChoiceDelta(
  180. content=full_text[i],
  181. role="assistant",
  182. ),
  183. finish_reason="content_filter",
  184. index=0,
  185. )
  186. ],
  187. created=int(time()),
  188. model=model,
  189. object="chat.completion.chunk",
  190. system_fingerprint="",
  191. )
  192. def chat_create(
  193. self: Completions,
  194. *,
  195. messages: list[ChatCompletionMessageParam],
  196. model: Union[
  197. str,
  198. Literal[
  199. "gpt-4-1106-preview",
  200. "gpt-4-vision-preview",
  201. "gpt-4",
  202. "gpt-4-0314",
  203. "gpt-4-0613",
  204. "gpt-4-32k",
  205. "gpt-4-32k-0314",
  206. "gpt-4-32k-0613",
  207. "gpt-3.5-turbo-1106",
  208. "gpt-3.5-turbo",
  209. "gpt-3.5-turbo-16k",
  210. "gpt-3.5-turbo-0301",
  211. "gpt-3.5-turbo-0613",
  212. "gpt-3.5-turbo-16k-0613",
  213. ],
  214. ],
  215. functions: list[completion_create_params.Function] | NotGiven = NOT_GIVEN,
  216. response_format: completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
  217. stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
  218. tools: list[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
  219. **kwargs: Any,
  220. ):
  221. openai_models = [
  222. "gpt-4-1106-preview",
  223. "gpt-4-vision-preview",
  224. "gpt-4",
  225. "gpt-4-0314",
  226. "gpt-4-0613",
  227. "gpt-4-32k",
  228. "gpt-4-32k-0314",
  229. "gpt-4-32k-0613",
  230. "gpt-3.5-turbo-1106",
  231. "gpt-3.5-turbo",
  232. "gpt-3.5-turbo-16k",
  233. "gpt-3.5-turbo-0301",
  234. "gpt-3.5-turbo-0613",
  235. "gpt-3.5-turbo-16k-0613",
  236. ]
  237. azure_openai_models = ["gpt35", "gpt-4v", "gpt-35-turbo"]
  238. if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
  239. raise InvokeAuthorizationError("Invalid base url")
  240. if model in openai_models + azure_openai_models:
  241. if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
  242. # sometime, provider use OpenAI compatible API will not have api key or have different api key format
  243. # so we only check if model is in openai_models
  244. raise InvokeAuthorizationError("Invalid api key")
  245. if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
  246. raise InvokeAuthorizationError("Invalid api key")
  247. if stream:
  248. return MockChatClass.mocked_openai_chat_create_stream(model=model, functions=functions, tools=tools)
  249. return MockChatClass.mocked_openai_chat_create_sync(model=model, functions=functions, tools=tools)