"""Functionality for splitting text.""" from __future__ import annotations from typing import Any, Optional from core.model_manager import ModelInstance from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.rag.splitter.text_splitter import ( TS, Collection, Literal, RecursiveCharacterTextSplitter, Set, TokenTextSplitter, Union, ) class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): """ This class is used to implement from_gpt2_encoder, to prevent using of tiktoken """ @classmethod def from_encoder( cls: type[TS], embedding_model_instance: Optional[ModelInstance], allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037 disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037 **kwargs: Any, ): def _token_encoder(texts: list[str]) -> list[int]: if not texts: return [] if embedding_model_instance: return embedding_model_instance.get_text_embedding_num_tokens(texts=texts) else: return [GPT2Tokenizer.get_num_tokens(text) for text in texts] if issubclass(cls, TokenTextSplitter): extra_kwargs = { "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2", "allowed_special": allowed_special, "disallowed_special": disallowed_special, } kwargs = {**kwargs, **extra_kwargs} return cls(length_function=_token_encoder, **kwargs) class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter): def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any): """Create a new TextSplitter.""" super().__init__(**kwargs) self._fixed_separator = fixed_separator self._separators = separators or ["\n\n", "\n", " ", ""] def split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" if self._fixed_separator: chunks = text.split(self._fixed_separator) else: chunks = [text] final_chunks = [] chunks_lengths = self._length_function(chunks) for chunk, chunk_length in zip(chunks, chunks_lengths): if chunk_length > self._chunk_size: final_chunks.extend(self.recursive_split_text(chunk)) else: final_chunks.append(chunk) return final_chunks def recursive_split_text(self, text: str) -> list[str]: """Split incoming text and return chunks.""" final_chunks = [] separator = self._separators[-1] new_separators = [] for i, _s in enumerate(self._separators): if _s == "": separator = _s break if _s in text: separator = _s new_separators = self._separators[i + 1 :] break # Now that we have the separator, split the text if separator: if separator == " ": splits = text.split() else: splits = text.split(separator) else: splits = list(text) splits = [s for s in splits if (s not in {"", "\n"})] _good_splits = [] _good_splits_lengths = [] # cache the lengths of the splits _separator = "" if self._keep_separator else separator s_lens = self._length_function(splits) if _separator != "": for s, s_len in zip(splits, s_lens): if s_len < self._chunk_size: _good_splits.append(s) _good_splits_lengths.append(s_len) else: if _good_splits: merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) final_chunks.extend(merged_text) _good_splits = [] _good_splits_lengths = [] if not new_separators: final_chunks.append(s) else: other_info = self._split_text(s, new_separators) final_chunks.extend(other_info) if _good_splits: merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) final_chunks.extend(merged_text) else: current_part = "" current_length = 0 overlap_part = "" 
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
    """Splitter that first cuts on a fixed separator, then recursively
    re-splits any chunk that is still over the size limit."""

    def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First pass: split on the fixed separator only.
        if self._fixed_separator:
            chunks = text.split(self._fixed_separator)
        else:
            chunks = [text]

        final_chunks = []
        chunks_lengths = self._length_function(chunks)
        for chunk, chunk_length in zip(chunks, chunks_lengths):
            if chunk_length > self._chunk_size:
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)

        return final_chunks

    def recursive_split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        separator = self._separators[-1]
        new_separators = []

        # Pick the first separator that occurs in the text; keep the remaining
        # separators for deeper recursion.
        for i, _s in enumerate(self._separators):
            if _s == "":
                separator = _s
                break
            if _s in text:
                separator = _s
                new_separators = self._separators[i + 1 :]
                break

        # Now that we have the separator, split the text.
        if separator:
            if separator == " ":
                splits = text.split()
            else:
                splits = text.split(separator)
        else:
            splits = list(text)
        splits = [s for s in splits if (s not in {"", "\n"})]

        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
        _separator = "" if self._keep_separator else separator
        s_lens = self._length_function(splits)

        if _separator != "":
            for s, s_len in zip(splits, s_lens):
                if s_len < self._chunk_size:
                    _good_splits.append(s)
                    _good_splits_lengths.append(s_len)
                else:
                    # Flush the accumulated short splits, then recurse into the
                    # oversized split with the remaining separators.
                    if _good_splits:
                        merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                        final_chunks.extend(merged_text)
                        _good_splits = []
                        _good_splits_lengths = []
                    if not new_separators:
                        final_chunks.append(s)
                    else:
                        other_info = self._split_text(s, new_separators)
                        final_chunks.extend(other_info)
            if _good_splits:
                merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                final_chunks.extend(merged_text)
        else:
            # The separator is dropped, so pack splits greedily and carry the
            # tail of each chunk forward as overlap into the next chunk.
            current_part = ""
            current_length = 0
            overlap_part = ""
            overlap_part_length = 0
            for s, s_len in zip(splits, s_lens):
                if current_length + s_len <= self._chunk_size - self._chunk_overlap:
                    current_part += s
                    current_length += s_len
                elif current_length + s_len <= self._chunk_size:
                    current_part += s
                    current_length += s_len
                    overlap_part += s
                    overlap_part_length += s_len
                else:
                    final_chunks.append(current_part)
                    current_part = overlap_part + s
                    current_length = s_len + overlap_part_length
                    overlap_part = ""
                    overlap_part_length = 0
            if current_part:
                final_chunks.append(current_part)

        return final_chunks
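# Minimal demo sketch (illustrative values only; assumes the GPT-2 tokenizer
# assets bundled with the model runtime are available locally):
if __name__ == "__main__":
    splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
        embedding_model_instance=None,  # None -> count tokens with GPT2Tokenizer
        fixed_separator="\n\n",
        chunk_size=500,
        chunk_overlap=50,
    )
    sample = "First paragraph.\n\nSecond, much longer paragraph that may need further splitting."
    for piece in splitter.split_text(sample):
        print(repr(piece))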