openai_completion.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131
  1. import re
  2. from collections.abc import Generator
  3. from time import time
  4. # import monkeypatch
  5. from typing import Any, Literal, Optional, Union
  6. from openai import AzureOpenAI, BadRequestError, OpenAI
  7. from openai._types import NOT_GIVEN, NotGiven
  8. from openai.resources.completions import Completions
  9. from openai.types import Completion as CompletionMessage
  10. from openai.types.completion import CompletionChoice
  11. from openai.types.completion_usage import CompletionUsage
  12. from core.model_runtime.errors.invoke import InvokeAuthorizationError
  13. class MockCompletionsClass:
  14. @staticmethod
  15. def mocked_openai_completion_create_sync(model: str) -> CompletionMessage:
  16. return CompletionMessage(
  17. id="cmpl-3QJQa5jXJ5Z5X",
  18. object="text_completion",
  19. created=int(time()),
  20. model=model,
  21. system_fingerprint="",
  22. choices=[
  23. CompletionChoice(
  24. text="mock",
  25. index=0,
  26. logprobs=None,
  27. finish_reason="stop",
  28. )
  29. ],
  30. usage=CompletionUsage(
  31. prompt_tokens=2,
  32. completion_tokens=1,
  33. total_tokens=3,
  34. ),
  35. )
  36. @staticmethod
  37. def mocked_openai_completion_create_stream(model: str) -> Generator[CompletionMessage, None, None]:
  38. full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
  39. for i in range(0, len(full_text) + 1):
  40. if i == len(full_text):
  41. yield CompletionMessage(
  42. id="cmpl-3QJQa5jXJ5Z5X",
  43. object="text_completion",
  44. created=int(time()),
  45. model=model,
  46. system_fingerprint="",
  47. choices=[
  48. CompletionChoice(
  49. text="",
  50. index=0,
  51. logprobs=None,
  52. finish_reason="stop",
  53. )
  54. ],
  55. usage=CompletionUsage(
  56. prompt_tokens=2,
  57. completion_tokens=17,
  58. total_tokens=19,
  59. ),
  60. )
  61. else:
  62. yield CompletionMessage(
  63. id="cmpl-3QJQa5jXJ5Z5X",
  64. object="text_completion",
  65. created=int(time()),
  66. model=model,
  67. system_fingerprint="",
  68. choices=[
  69. CompletionChoice(text=full_text[i], index=0, logprobs=None, finish_reason="content_filter")
  70. ],
  71. )
  72. def completion_create(
  73. self: Completions,
  74. *,
  75. model: Union[
  76. str,
  77. Literal[
  78. "babbage-002",
  79. "davinci-002",
  80. "gpt-3.5-turbo-instruct",
  81. "text-davinci-003",
  82. "text-davinci-002",
  83. "text-davinci-001",
  84. "code-davinci-002",
  85. "text-curie-001",
  86. "text-babbage-001",
  87. "text-ada-001",
  88. ],
  89. ],
  90. prompt: Union[str, list[str], list[int], list[list[int]], None],
  91. stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
  92. **kwargs: Any,
  93. ):
  94. openai_models = [
  95. "babbage-002",
  96. "davinci-002",
  97. "gpt-3.5-turbo-instruct",
  98. "text-davinci-003",
  99. "text-davinci-002",
  100. "text-davinci-001",
  101. "code-davinci-002",
  102. "text-curie-001",
  103. "text-babbage-001",
  104. "text-ada-001",
  105. ]
  106. azure_openai_models = ["gpt-35-turbo-instruct"]
  107. if not re.match(r"^(https?):\/\/[^\s\/$.?#].[^\s]*$", str(self._client.base_url)):
  108. raise InvokeAuthorizationError("Invalid base url")
  109. if model in openai_models + azure_openai_models:
  110. if not re.match(r"sk-[a-zA-Z0-9]{24,}$", self._client.api_key) and type(self._client) == OpenAI:
  111. # sometime, provider use OpenAI compatible API will not have api key or have different api key format
  112. # so we only check if model is in openai_models
  113. raise InvokeAuthorizationError("Invalid api key")
  114. if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
  115. raise InvokeAuthorizationError("Invalid api key")
  116. if not prompt:
  117. raise BadRequestError("Invalid prompt")
  118. if stream:
  119. return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
  120. return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)