openai_completion.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. import re
  2. from collections.abc import Generator
  3. from time import time
  4. # import monkeypatch
  5. from typing import Any, Literal, Optional, Union
  6. from openai import AzureOpenAI, BadRequestError, OpenAI
  7. from openai._types import NOT_GIVEN, NotGiven
  8. from openai.resources.completions import Completions
  9. from openai.types import Completion as CompletionMessage
  10. from openai.types.completion import CompletionChoice
  11. from openai.types.completion_usage import CompletionUsage
  12. from core.model_runtime.errors.invoke import InvokeAuthorizationError
  13. class MockCompletionsClass:
  14. @staticmethod
  15. def mocked_openai_completion_create_sync(
  16. model: str
  17. ) -> CompletionMessage:
  18. return CompletionMessage(
  19. id="cmpl-3QJQa5jXJ5Z5X",
  20. object="text_completion",
  21. created=int(time()),
  22. model=model,
  23. system_fingerprint="",
  24. choices=[
  25. CompletionChoice(
  26. text="mock",
  27. index=0,
  28. logprobs=None,
  29. finish_reason="stop",
  30. )
  31. ],
  32. usage=CompletionUsage(
  33. prompt_tokens=2,
  34. completion_tokens=1,
  35. total_tokens=3,
  36. )
  37. )
  38. @staticmethod
  39. def mocked_openai_completion_create_stream(
  40. model: str
  41. ) -> Generator[CompletionMessage, None, None]:
  42. full_text = "Hello, world!\n\n```python\nprint('Hello, world!')\n```"
  43. for i in range(0, len(full_text) + 1):
  44. if i == len(full_text):
  45. yield CompletionMessage(
  46. id="cmpl-3QJQa5jXJ5Z5X",
  47. object="text_completion",
  48. created=int(time()),
  49. model=model,
  50. system_fingerprint="",
  51. choices=[
  52. CompletionChoice(
  53. text="",
  54. index=0,
  55. logprobs=None,
  56. finish_reason="stop",
  57. )
  58. ],
  59. usage=CompletionUsage(
  60. prompt_tokens=2,
  61. completion_tokens=17,
  62. total_tokens=19,
  63. ),
  64. )
  65. else:
  66. yield CompletionMessage(
  67. id="cmpl-3QJQa5jXJ5Z5X",
  68. object="text_completion",
  69. created=int(time()),
  70. model=model,
  71. system_fingerprint="",
  72. choices=[
  73. CompletionChoice(
  74. text=full_text[i],
  75. index=0,
  76. logprobs=None,
  77. finish_reason="content_filter"
  78. )
  79. ],
  80. )
  81. def completion_create(self: Completions, *, model: Union[
  82. str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo-instruct",
  83. "text-davinci-003", "text-davinci-002", "text-davinci-001",
  84. "code-davinci-002", "text-curie-001", "text-babbage-001",
  85. "text-ada-001"],
  86. ],
  87. prompt: Union[str, list[str], list[int], list[list[int]], None],
  88. stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN,
  89. **kwargs: Any
  90. ):
  91. openai_models = [
  92. "babbage-002", "davinci-002", "gpt-3.5-turbo-instruct", "text-davinci-003", "text-davinci-002", "text-davinci-001",
  93. "code-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001",
  94. ]
  95. azure_openai_models = [
  96. "gpt-35-turbo-instruct"
  97. ]
  98. if not re.match(r'^(https?):\/\/[^\s\/$.?#].[^\s]*$', self._client.base_url.__str__()):
  99. raise InvokeAuthorizationError('Invalid base url')
  100. if model in openai_models + azure_openai_models:
  101. if not re.match(r'sk-[a-zA-Z0-9]{24,}$', self._client.api_key) and type(self._client) == OpenAI:
  102. # sometime, provider use OpenAI compatible API will not have api key or have different api key format
  103. # so we only check if model is in openai_models
  104. raise InvokeAuthorizationError('Invalid api key')
  105. if len(self._client.api_key) < 18 and type(self._client) == AzureOpenAI:
  106. raise InvokeAuthorizationError('Invalid api key')
  107. if not prompt:
  108. raise BadRequestError('Invalid prompt')
  109. if stream:
  110. return MockCompletionsClass.mocked_openai_completion_create_stream(model=model)
  111. return MockCompletionsClass.mocked_openai_completion_create_sync(model=model)