from __future__ import annotations

import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import (
    Any,
    Literal,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)

from core.rag.models.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({re.escape(separator)})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
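

# Illustrative sketch (not part of the original module): how keep_separator
# changes the output of _split_text_with_regex. The helper name and sample
# inputs are assumptions chosen purely for demonstration.
def _demo_split_text_with_regex() -> None:
    text = "first\n\nsecond\n\nthird"
    # Separator dropped: ["first", "second", "third"]
    print(_split_text_with_regex(text, "\n\n", keep_separator=False))
    # Separator kept, glued onto the piece that follows it:
    # ["first", "\n\nsecond", "\n\nthird"]
    print(_split_text_with_regex(text, "\n\n", keep_separator=True))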


class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: list[str], metadatas: Optional[list[dict]] = None
    ) -> list[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: list[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate token lengths for splitting. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        raise NotImplementedError
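

# Illustrative sketch (not part of the original module): a minimal concrete
# TextSplitter subclass showing how split_text, _merge_splits and
# create_documents fit together. The class and function names below are
# assumptions used only for demonstration.
class _LineSplitterExample(TextSplitter):
    def split_text(self, text: str) -> list[str]:
        # Split on single newlines, then let the base class merge the pieces
        # back into chunks of at most self._chunk_size characters.
        return self._merge_splits(text.split("\n"), "\n")


def _demo_line_splitter() -> list[Document]:
    splitter = _LineSplitterExample(chunk_size=20, chunk_overlap=0, add_start_index=True)
    text = "line one\nline two\nline three\nline four"
    # Produces two documents ("line one\nline two" and "line three\nline four"),
    # each carrying a start_index entry (0 and 18) in its metadata.
    return splitter.create_documents([text], metadatas=[{"source": "demo"}])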


class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)
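

# Illustrative sketch (not part of the original module): typical
# CharacterTextSplitter usage. The sizes and sample text are assumptions
# chosen so that the overlap behaviour is visible.
def _demo_character_splitter() -> list[str]:
    splitter = CharacterTextSplitter(separator="\n\n", chunk_size=40, chunk_overlap=20)
    text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph."
    # The 20-character overlap makes the middle paragraph appear in both chunks:
    # ["First paragraph.\n\nSecond paragraph.", "Second paragraph.\n\nThird paragraph."]
    # A token-based length function can be plugged in via the classmethod
    # (requires the optional tiktoken dependency):
    # splitter = CharacterTextSplitter.from_tiktoken_encoder(chunk_size=40, chunk_overlap=20)
    return splitter.split_text(text)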


class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: dict[str, str]
    content: str


class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(
        self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
    ):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on,
        # (e.g., "#, ##, etc") order by length
        self.headers_to_split_on = sorted(
            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
        )

    def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
        """Combine lines with common metadata into chunks.

        Args:
            lines: Line of text / associated header metadata
        """
        aggregated_chunks: list[LineType] = []

        for line in lines:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == line["metadata"]
            ):
                # If the last line in the aggregated list
                # has the same metadata as the current line,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += " \n" + line["content"]
            else:
                # Otherwise, append the current line to the aggregated list
                aggregated_chunks.append(line)

        return [
            Document(page_content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]

    def split_text(self, text: str) -> list[Document]:
        """Split markdown file.

        Args:
            text: Markdown file
        """
        # Split the input text by newline character ("\n").
        lines = text.split("\n")
        # Final output
        lines_with_metadata: list[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: list[str] = []
        current_metadata: dict[str, str] = {}
        # Keep track of the nested header structure
        # header_stack: List[Dict[str, Union[int, str]]] = []
        header_stack: list[HeaderType] = []
        initial_metadata: dict[str, str] = {}

        for line in lines:
            stripped_line = line.strip()
            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                    # Header with no text OR header is followed by space
                    # Both are valid conditions that sep is being used as a header
                    len(stripped_line) == len(sep)
                    or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")
                        # Pop out headers of lower or same level from the stack
                        while (
                            header_stack
                            and header_stack[-1]["level"] >= current_header_level
                        ):
                            # We have encountered a new header
                            # at the same or higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the
                            # popped header in initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])
                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep):].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]
                    # Add the previous line to the lines_with_metadata
                    # only if current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()
                    break
            else:
                # No header matched on this line (for-else), so treat it as content.
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append(
                {"content": "\n".join(current_content), "metadata": current_metadata}
            )

        # lines_with_metadata has each line with associated header metadata
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"])
                for chunk in lines_with_metadata
            ]
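

# Illustrative sketch (not part of the original module): splitting a small
# markdown snippet on "#" and "##" headers. The header names and sample text
# are assumptions used only for demonstration.
def _demo_markdown_header_splitter() -> list[Document]:
    headers_to_split_on = [("#", "Header 1"), ("##", "Header 2")]
    splitter = MarkdownHeaderTextSplitter(headers_to_split_on)
    text = "# Intro\nSome intro text.\n## Details\nMore text here.\nAnd more."
    # Yields two Documents; the second one's metadata is
    # {"Header 1": "Intro", "Header 2": "Details"}.
    return splitter.split_text(text)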


# should be in newer Python versions (3.10+)
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
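

# Illustrative sketch (not part of the original module): driving
# split_text_on_tokens with a toy whitespace "tokenizer" instead of a real
# model tokenizer. All names below are assumptions used only for demonstration.
def _demo_split_text_on_tokens() -> list[str]:
    vocab: dict[int, str] = {}

    def _encode(text: str) -> list[int]:
        ids = []
        for i, word in enumerate(text.split()):
            vocab[i] = word
            ids.append(i)
        return ids

    def _decode(ids: list[int]) -> str:
        return " ".join(vocab[i] for i in ids)

    toy = Tokenizer(chunk_overlap=1, tokens_per_chunk=3, decode=_decode, encode=_encode)
    # Returns ["one two three", "three four five", "five"]: windows of three
    # tokens that advance by two, so one token overlaps between chunks.
    return split_text_on_tokens(text="one two three four five", tokenizer=toy)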


class TokenTextSplitter(TextSplitter):
    """Splitting text to tokens using model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
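

# Illustrative sketch (not part of the original module): TokenTextSplitter
# usage. This requires the optional tiktoken dependency; chunk_size and
# chunk_overlap are measured in tokens rather than characters here.
def _demo_token_splitter() -> list[str]:
    splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=20, chunk_overlap=5)
    return splitter.split_text(
        "A long passage of text that should be split on token boundaries."
    )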


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[list[str]] = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            if _s == "":
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)

        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)
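

# Illustrative sketch (not part of the original module): the recursive splitter
# tries "\n\n", then "\n", then " ", then single characters, so oversized
# pieces keep getting subdivided. The sizes and text are demonstration-only
# assumptions.
def _demo_recursive_splitter() -> list[str]:
    splitter = RecursiveCharacterTextSplitter(chunk_size=50, chunk_overlap=0)
    text = (
        "A first paragraph that is fairly long on its own.\n\n"
        "A second paragraph.\nWith a short extra line."
    )
    # Each paragraph fits within 50 characters, so each ends up as its own chunk.
    return splitter.split_text(text)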