- """Functionality for splitting text."""
- from __future__ import annotations
- from typing import Any, Optional
- from core.model_manager import ModelInstance
- from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
- from core.rag.splitter.text_splitter import (
- TS,
- Collection,
- Literal,
- RecursiveCharacterTextSplitter,
- Set,
- TokenTextSplitter,
- Union,
- )


class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """
    Adds ``from_encoder`` so token lengths come from the embedding model's own
    tokenizer (or a local GPT-2 tokenizer as a fallback) instead of tiktoken.
    """

    @classmethod
    def from_encoder(
        cls: type[TS],
        embedding_model_instance: Optional[ModelInstance],
        allowed_special: Union[Literal["all"], Set[str]] = set(),  # noqa: UP037
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",  # noqa: UP037
        **kwargs: Any,
    ):
        def _token_encoder(texts: list[str]) -> list[int]:
            """Return the token count for each text in the batch."""
            if not texts:
                return []

            if embedding_model_instance:
                # Count tokens with the embedding model's own tokenizer.
                return embedding_model_instance.get_text_embedding_num_tokens(texts=texts)
            else:
                # Fall back to the bundled GPT-2 tokenizer so tiktoken is not required.
                return [GPT2Tokenizer.get_num_tokens(text) for text in texts]

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_token_encoder, **kwargs)
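
    # A minimal usage sketch, not part of the original module; the values are
    # illustrative. With embedding_model_instance=None, _token_encoder falls
    # back to the local GPT-2 tokenizer, so tiktoken is never needed:
    #
    #     splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
    #         embedding_model_instance=None,
    #         chunk_size=512,
    #         chunk_overlap=50,
    #     )
    #     chunks = splitter.split_text(long_text)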


class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
    def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any):
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        if self._fixed_separator:
            chunks = text.split(self._fixed_separator)
        else:
            chunks = [text]

        final_chunks = []
        chunks_lengths = self._length_function(chunks)
        for chunk, chunk_length in zip(chunks, chunks_lengths):
            if chunk_length > self._chunk_size:
                # Pieces that still exceed chunk_size are re-split recursively.
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)

        return final_chunks
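
    # Behaviour sketch (values assumed for illustration): with the default
    # fixed_separator="\n\n" and chunk_size=10, the text is first cut at every
    # "\n\n"; only paragraphs longer than 10 tokens are handed to
    # recursive_split_text below, so short paragraphs pass through unchanged.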

    def recursive_split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Pick the first separator that occurs in the text; the remaining,
        # finer-grained separators are kept as fallbacks for recursion.
        separator = self._separators[-1]
        new_separators = []
        for i, _s in enumerate(self._separators):
            if _s == "":
                separator = _s
                break
            if _s in text:
                separator = _s
                new_separators = self._separators[i + 1 :]
                break

        # Now that we have the separator, split the text
        if separator:
            if separator == " ":
                splits = text.split()
            else:
                splits = text.split(separator)
        else:
            splits = list(text)
        splits = [s for s in splits if (s not in {"", "\n"})]

        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
        _separator = "" if self._keep_separator else separator
        s_lens = self._length_function(splits)
        if _separator != "":
            for s, s_len in zip(splits, s_lens):
                if s_len < self._chunk_size:
                    _good_splits.append(s)
                    _good_splits_lengths.append(s_len)
                else:
                    if _good_splits:
                        merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                        final_chunks.extend(merged_text)
                        _good_splits = []
                        _good_splits_lengths = []
                    if not new_separators:
                        final_chunks.append(s)
                    else:
                        other_info = self._split_text(s, new_separators)
                        final_chunks.extend(other_info)
            if _good_splits:
                merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                final_chunks.extend(merged_text)
        else:
            # Pack splits into chunks of at most chunk_size tokens, carrying
            # roughly chunk_overlap tokens forward into the next chunk.
            current_part = ""
            current_length = 0
            overlap_part = ""
            overlap_part_length = 0
            for s, s_len in zip(splits, s_lens):
                if current_length + s_len <= self._chunk_size - self._chunk_overlap:
                    current_part += s
                    current_length += s_len
                elif current_length + s_len <= self._chunk_size:
                    # Still fits in the chunk; also record it as overlap so it
                    # is repeated at the start of the next chunk.
                    current_part += s
                    current_length += s_len
                    overlap_part += s
                    overlap_part_length += s_len
                else:
                    final_chunks.append(current_part)
                    current_part = overlap_part + s
                    current_length = s_len + overlap_part_length
                    overlap_part = ""
                    overlap_part_length = 0
            if current_part:
                final_chunks.append(current_part)

        return final_chunks
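

# A minimal, hedged demo (an addition for illustration; it assumes the Dify
# imports above resolve, and the sample text and sizes are arbitrary). With
# embedding_model_instance=None, token counts come from the GPT-2 fallback.
if __name__ == "__main__":
    demo_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
        embedding_model_instance=None,
        fixed_separator="\n\n",
        chunk_size=64,
        chunk_overlap=8,
    )
    sample = "First paragraph.\n\nA second, somewhat longer paragraph.\n\nThird."
    for i, chunk in enumerate(demo_splitter.split_text(sample)):
        print(i, repr(chunk))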