streamable_open_ai.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. import os
  2. from langchain.callbacks.manager import Callbacks
  3. from langchain.schema import LLMResult
  4. from typing import Optional, List, Dict, Any, Mapping
  5. from langchain import OpenAI
  6. from pydantic import root_validator
  7. from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async
  8. class StreamableOpenAI(OpenAI):
  9. @root_validator()
  10. def validate_environment(cls, values: Dict) -> Dict:
  11. """Validate that api key and python package exists in environment."""
  12. try:
  13. import openai
  14. values["client"] = openai.Completion
  15. except ImportError:
  16. raise ValueError(
  17. "Could not import openai python package. "
  18. "Please install it with `pip install openai`."
  19. )
  20. if values["streaming"] and values["n"] > 1:
  21. raise ValueError("Cannot stream results when n > 1.")
  22. if values["streaming"] and values["best_of"] > 1:
  23. raise ValueError("Cannot stream results when best_of > 1.")
  24. return values
  25. @property
  26. def _invocation_params(self) -> Dict[str, Any]:
  27. return {**super()._invocation_params, **{
  28. "api_type": 'openai',
  29. "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
  30. "api_version": None,
  31. "api_key": self.openai_api_key,
  32. "organization": self.openai_organization if self.openai_organization else None,
  33. }}
  34. @property
  35. def _identifying_params(self) -> Mapping[str, Any]:
  36. return {**super()._identifying_params, **{
  37. "api_type": 'openai',
  38. "api_base": os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1"),
  39. "api_version": None,
  40. "api_key": self.openai_api_key,
  41. "organization": self.openai_organization if self.openai_organization else None,
  42. }}
  43. @handle_llm_exceptions
  44. def generate(
  45. self,
  46. prompts: List[str],
  47. stop: Optional[List[str]] = None,
  48. callbacks: Callbacks = None,
  49. **kwargs: Any,
  50. ) -> LLMResult:
  51. return super().generate(prompts, stop, callbacks, **kwargs)
  52. @handle_llm_exceptions_async
  53. async def agenerate(
  54. self,
  55. prompts: List[str],
  56. stop: Optional[List[str]] = None,
  57. callbacks: Callbacks = None,
  58. **kwargs: Any,
  59. ) -> LLMResult:
  60. return await super().agenerate(prompts, stop, callbacks, **kwargs)