"""Functionality for splitting text."""

from __future__ import annotations

from typing import Any, Optional

from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.rag.splitter.text_splitter import (
    TS,
    Collection,
    Literal,
    RecursiveCharacterTextSplitter,
    Set,
    TokenTextSplitter,
    Union,
)


class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """Recursive character splitter whose ``from_encoder`` counts tokens without tiktoken.

    Token lengths come from the embedding model's own tokenizer when a model
    instance is supplied, and from the bundled GPT-2 tokenizer otherwise.
    """

    @classmethod
    def from_encoder(
        cls: type[TS],
        embedding_model_instance: Optional[ModelInstance],
        # NOTE: shared default set; this method only forwards it and never mutates it.
        allowed_special: Union[Literal["all"], Set[str]] = set(),  # noqa: UP037
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",  # noqa: UP037
        **kwargs: Any,
    ):
        """Build a splitter whose ``length_function`` counts tokens.

        Args:
            embedding_model_instance: Model whose ``get_text_embedding_num_tokens``
                is used to count tokens. When it is ``None`` (or otherwise falsy),
                ``GPT2Tokenizer.get_num_tokens`` is used instead.
            allowed_special: Forwarded to the constructor only when ``cls`` is a
                ``TokenTextSplitter`` subclass.
            disallowed_special: Forwarded to the constructor only when ``cls`` is
                a ``TokenTextSplitter`` subclass.
            **kwargs: Other constructor arguments. For ``TokenTextSplitter``
                subclasses, ``model_name``, ``allowed_special`` and
                ``disallowed_special`` are set here and override any same-named
                entries in ``kwargs``.

        Returns:
            An instance of ``cls`` constructed with the token-counting
            ``length_function``.
        """

        def _token_encoder(text: str) -> int:
            # Empty strings count as 0 tokens without calling any tokenizer.
            if not text:
                return 0

            if embedding_model_instance:
                return embedding_model_instance.get_text_embedding_num_tokens(texts=[text])
            return GPT2Tokenizer.get_num_tokens(text)

        if issubclass(cls, TokenTextSplitter):
            # Token-based splitters additionally receive the model name and the
            # special-token settings; these take precedence over caller kwargs.
            extra_kwargs = {
                "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_token_encoder, **kwargs)


class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
    """Split on a fixed separator first, then recursively split oversized pieces.

    ``split_text`` divides the text on ``fixed_separator``; any piece longer than
    the chunk size is handed to ``recursive_split_text``, which tries the
    fallback ``separators`` in order.
    """

    def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any):
        """Create a new TextSplitter.

        Args:
            fixed_separator: Separator applied first by ``split_text``. An empty
                string disables the fixed split, so the whole text is one piece.
            separators: Fallback separators tried in order by
                ``recursive_split_text``. ``None`` or an empty list selects the
                default list (blank line, newline, space, then ``""`` for
                per-character splitting).
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks.

        The text is first split on the fixed separator (if any). Pieces whose
        length exceeds the chunk size are split further with
        ``recursive_split_text``; all other pieces are returned unchanged and in
        order (this normally includes empty strings left by adjacent separators).
        """
        chunks = text.split(self._fixed_separator) if self._fixed_separator else [text]

        final_chunks: list[str] = []
        for chunk in chunks:
            if self._length_function(chunk) > self._chunk_size:
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)

        return final_chunks

    def recursive_split_text(self, text: str) -> list[str]:
        """Split ``text`` into chunks using the fallback separators.

        The first separator found in ``text`` is used; ``""`` always matches and
        splits into single characters. If none match, the last separator is
        used. Consecutive pieces shorter than the chunk size are merged with
        ``_merge_splits``; longer pieces are split again recursively. A piece
        that cannot be divided any further is returned whole instead of
        recursing forever.
        """
        final_chunks: list[str] = []

        # Pick the first separator present in the text ("" stops the search);
        # fall back to the last configured separator if none is present.
        separator = self._separators[-1]
        for candidate in self._separators:
            if candidate == "" or candidate in text:
                separator = candidate
                break

        splits = text.split(separator) if separator else list(text)

        # Recursing on a piece only makes progress if this level really divided
        # the text on a non-empty separator. With "" every piece is a single
        # character, and with an absent separator splits == [text]; in both
        # cases a recursive call would choose the same separator on the same
        # input and never terminate.
        can_recurse = bool(separator) and separator in text

        # Merge small pieces, recursively splitting pieces that are too long.
        # Note: here a piece of exactly chunk_size is not a "good" split (strict
        # ``<``), whereas split_text only recurses on pieces strictly larger.
        good_splits: list[str] = []
        good_splits_lengths: list[int] = []  # lengths cached for _merge_splits
        for piece in splits:
            piece_len = self._length_function(piece)
            if piece_len < self._chunk_size:
                good_splits.append(piece)
                good_splits_lengths.append(piece_len)
                continue

            # Flush pending small pieces first so output order matches input order.
            if good_splits:
                final_chunks.extend(self._merge_splits(good_splits, separator, good_splits_lengths))
                good_splits = []
                good_splits_lengths = []

            if can_recurse:
                final_chunks.extend(self.recursive_split_text(piece))
            else:
                # No configured separator can shrink this piece; emit it whole.
                final_chunks.append(piece)

        if good_splits:
            final_chunks.extend(self._merge_splits(good_splits, separator, good_splits_lengths))

        return final_chunks