from typing import Optional, Any, List

import openai
from llama_index.embeddings.base import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbeddingMode, OpenAIEmbeddingModelType, _QUERY_MODE_MODEL_DICT, \
    _TEXT_MODE_MODEL_DICT
from tenacity import wait_random_exponential, retry, stop_after_attempt

from core.llm.error_handle_wraps import handle_llm_exceptions, handle_llm_exceptions_async

# Largest number of texts accepted in a single batched embedding request.
_MAX_BATCH_SIZE = 2048


def _with_openai_retry(func):
    """Apply the retry policy shared by every embedding request in this module.

    Up to 6 attempts, a random exponential wait bounded to 1-20 seconds
    between attempts, and the last exception re-raised as-is (``reraise=True``)
    instead of being wrapped in a tenacity ``RetryError``.
    """
    return retry(
        reraise=True,
        wait=wait_random_exponential(min=1, max=20),
        stop=stop_after_attempt(6),
    )(func)


def _check_batch_size(list_of_text: List[str]) -> None:
    """Reject batches larger than ``_MAX_BATCH_SIZE``.

    Raises:
        AssertionError: if more than 2048 texts are supplied. The exception
            type and message match the ``assert`` statement this replaces, but
            the check is raised explicitly so it is not stripped under
            ``python -O``.
    """
    if len(list_of_text) > _MAX_BATCH_SIZE:
        raise AssertionError("The batch size should not be larger than 2048.")


def _replace_newlines(list_of_text: List[str]) -> List[str]:
    """Return the texts with every newline replaced by a single space."""
    # replace newlines, which can negatively affect performance.
    return [text.replace("\n", " ") for text in list_of_text]


def _embeddings_in_input_order(data) -> List[List[float]]:
    """Extract the embedding vectors from response items, sorted by their ``index`` field."""
    ordered = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [item["embedding"] for item in ordered]


@_with_openai_retry
def _create_embeddings(list_of_text: List[str], engine: Optional[str], api_key: Optional[str],
                       **kwargs) -> List[List[float]]:
    """Send one batched embedding request (retried) and return vectors in input order."""
    response = openai.Embedding.create(input=list_of_text, engine=engine, api_key=api_key, **kwargs)
    return _embeddings_in_input_order(response.data)


@_with_openai_retry
async def _acreate_embeddings(list_of_text: List[str], engine: Optional[str], api_key: Optional[str],
                              **kwargs) -> List[List[float]]:
    """Asynchronously send one batched embedding request (retried) and return vectors in input order."""
    response = await openai.Embedding.acreate(input=list_of_text, engine=engine, api_key=api_key, **kwargs)
    return _embeddings_in_input_order(response.data)


@_with_openai_retry
def get_embedding(
        text: str,
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    Args:
        text: Text to embed; newlines are replaced with spaces before sending.
        engine: Engine / deployment name forwarded to ``openai.Embedding.create``.
        api_key: API key forwarded to ``openai.Embedding.create``.
        **kwargs: Extra keyword arguments forwarded unchanged to the request.

    Returns:
        The ``embedding`` field of the first item in the response ``data``.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")
    response = openai.Embedding.create(input=[text], engine=engine, api_key=api_key, **kwargs)
    return response["data"][0]["embedding"]


@_with_openai_retry
async def aget_embedding(text: str, engine: Optional[str] = None, api_key: Optional[str] = None,
                         **kwargs) -> List[float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    Args:
        text: Text to embed; newlines are replaced with spaces before sending.
        engine: Engine / deployment name forwarded to ``openai.Embedding.acreate``.
        api_key: API key forwarded to ``openai.Embedding.acreate``.
        **kwargs: Extra keyword arguments forwarded unchanged to the request.

    Returns:
        The ``embedding`` field of the first item in the response ``data``.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")
    response = await openai.Embedding.acreate(input=[text], engine=engine, api_key=api_key, **kwargs)
    return response["data"][0]["embedding"]


def get_embeddings(
        list_of_text: List[str],
        engine: Optional[str] = None,
        api_key: Optional[str] = None,
        **kwargs
) -> List[List[float]]:
    """Get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    Args:
        list_of_text: Texts to embed in one request (at most 2048).
        engine: Engine / deployment name forwarded to ``openai.Embedding.create``.
        api_key: API key forwarded to ``openai.Embedding.create``.
        **kwargs: Extra keyword arguments forwarded unchanged to the request.

    Returns:
        One embedding per input text, in the same order as ``list_of_text``.

    Raises:
        AssertionError: if ``list_of_text`` holds more than 2048 items.
    """
    # Validate before entering the retried request: a batch that is too large
    # can never succeed, so it must fail immediately rather than after six
    # back-off attempts.
    _check_batch_size(list_of_text)
    return _create_embeddings(_replace_newlines(list_of_text), engine, api_key, **kwargs)


async def aget_embeddings(
        list_of_text: List[str], engine: Optional[str] = None, api_key: Optional[str] = None, **kwargs
) -> List[List[float]]:
    """Asynchronously get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    Args:
        list_of_text: Texts to embed in one request (at most 2048).
        engine: Engine / deployment name forwarded to ``openai.Embedding.acreate``.
        api_key: API key forwarded to ``openai.Embedding.acreate``.
        **kwargs: Extra keyword arguments forwarded unchanged to the request.

    Returns:
        One embedding per input text, in the same order as ``list_of_text``.

    Raises:
        AssertionError: if ``list_of_text`` holds more than 2048 items.
    """
    # Validate before entering the retried request (see get_embeddings).
    _check_batch_size(list_of_text)
    return await _acreate_embeddings(_replace_newlines(list_of_text), engine, api_key, **kwargs)


class OpenAIEmbedding(BaseEmbedding):
    """Embedding model that forwards per-instance OpenAI connection settings on every request.

    NOTE(review): only ``_get_query_embedding`` is wrapped with
    ``handle_llm_exceptions``; ``handle_llm_exceptions_async`` is imported but
    not applied to the async methods here. Confirm whether the text-embedding
    paths are meant to stay unwrapped before changing this, since wrapping
    them may change the exception types callers observe.
    """

    def __init__(
            self,
            mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
            model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
            deployment_name: Optional[str] = None,
            openai_api_key: Optional[str] = None,
            **kwargs: Any,
    ) -> None:
        """Init params.

        Args:
            mode: Value convertible to ``OpenAIEmbeddingMode``.
            model: Value convertible to ``OpenAIEmbeddingModelType``.
            deployment_name: When not None, used as the engine for every
                request instead of the (mode, model) lookup.
            openai_api_key: API key passed with every request.
            **kwargs: ``embed_batch_size`` and ``tokenizer`` are forwarded to
                ``BaseEmbedding.__init__`` when present; ``openai_api_type``,
                ``openai_api_version`` and ``openai_api_base`` are stored on
                the instance (None when absent). Other keys are ignored.
        """
        # Only these two options are passed on to the base class.
        base_kwargs = {name: kwargs[name] for name in ('embed_batch_size', 'tokenizer') if name in kwargs}
        super().__init__(**base_kwargs)

        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name
        self.openai_api_key = openai_api_key
        self.openai_api_type = kwargs.get('openai_api_type')
        self.openai_api_version = kwargs.get('openai_api_version')
        self.openai_api_base = kwargs.get('openai_api_base')

    def _resolve_engine(self, model_dict: dict) -> str:
        """Return the engine name to request.

        An explicit ``deployment_name`` wins; otherwise the engine is looked
        up in ``model_dict`` by the ``(mode, model)`` pair.

        Raises:
            ValueError: if no deployment name is set and the pair is not in
                ``model_dict``.
        """
        if self.deployment_name is not None:
            return self.deployment_name

        key = (self.mode, self.model)
        if key not in model_dict:
            raise ValueError(f"Invalid mode, model combination: {key}")
        return model_dict[key]

    def _openai_request_kwargs(self) -> dict:
        """Return the connection keyword arguments sent with every request."""
        return {
            'api_key': self.openai_api_key,
            'api_type': self.openai_api_type,
            'api_version': self.openai_api_version,
            'api_base': self.openai_api_base,
        }

    @handle_llm_exceptions
    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding (engine resolved from the query-mode table)."""
        engine = self._resolve_engine(_QUERY_MODE_MODEL_DICT)
        return get_embedding(query, engine=engine, **self._openai_request_kwargs())

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding (engine resolved from the text-mode table)."""
        engine = self._resolve_engine(_TEXT_MODE_MODEL_DICT)
        return get_embedding(text, engine=engine, **self._openai_request_kwargs())

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        engine = self._resolve_engine(_TEXT_MODE_MODEL_DICT)
        return await aget_embedding(text, engine=engine, **self._openai_request_kwargs())

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        Sends all ``texts`` in a single batched request via ``get_embeddings``
        rather than calling ``_get_text_embedding`` once per text.
        """
        engine = self._resolve_engine(_TEXT_MODE_MODEL_DICT)
        return get_embeddings(texts, engine=engine, **self._openai_request_kwargs())

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings in a single batched request."""
        engine = self._resolve_engine(_TEXT_MODE_MODEL_DICT)
        return await aget_embeddings(texts, engine=engine, **self._openai_request_kwargs())