streamable_open_ai.py

import os
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

from langchain import OpenAI
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from pydantic import root_validator

from core.llm.wrappers.openai_wrapper import handle_openai_exceptions


class StreamableOpenAI(OpenAI):
    """OpenAI completion model with custom timeout/retry defaults and wrapped exception handling."""

    request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
    """Timeout for requests to the OpenAI completion API: 5 seconds to connect, 300 seconds to read."""
    max_retries: int = 1
    """Maximum number of retries to make when generating."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key and python package exist in the environment."""
        try:
            import openai

            values["client"] = openai.Completion
        except ImportError:
            raise ValueError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            )
        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Parameters passed to the client on each call, pinned to the public OpenAI API and this instance's credentials."""
        return {**super()._invocation_params, **{
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Parameters that identify this LLM configuration (e.g. for caching and serialization)."""
        return {**super()._identifying_params, **{
            "api_type": 'openai',
            "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
            "api_version": None,
            "api_key": self.openai_api_key,
            "organization": self.openai_organization if self.openai_organization else None,
        }}

    @handle_openai_exceptions
    def generate(
            self,
            prompts: List[str],
            stop: Optional[List[str]] = None,
            callbacks: Callbacks = None,
            **kwargs: Any,
    ) -> LLMResult:
        # Delegate to the base implementation; the decorator maps OpenAI errors
        # to the application's own exception types.
        return super().generate(prompts, stop, callbacks, **kwargs)

    @classmethod
    def get_kwargs_from_model_params(cls, params: dict):
        # Completion models accept the model params unchanged.
        return params
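

# Usage sketch (not part of the original module): a hypothetical example of how this
# wrapper might be instantiated and invoked. The constructor arguments come from the
# LangChain OpenAI base class; the model name and API key below are placeholders.
if __name__ == "__main__":
    llm = StreamableOpenAI(
        model_name="text-davinci-003",
        streaming=False,  # streaming=True requires n == 1 and best_of == 1
        openai_api_key="sk-placeholder",
        max_tokens=256,
    )
    result: LLMResult = llm.generate(["Say hello in one short sentence."])
    print(result.generations[0][0].text)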