# text_splitter.py

from __future__ import annotations

import copy
import logging
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Collection, Iterable, Sequence, Set
from dataclasses import dataclass
from typing import (
    Any,
    Literal,
    Optional,
    TypedDict,
    TypeVar,
    Union,
)

from core.rag.models.document import BaseDocumentTransformer, Document

logger = logging.getLogger(__name__)

TS = TypeVar("TS", bound="TextSplitter")


def _split_text_with_regex(
    text: str, separator: str, keep_separator: bool
) -> list[str]:
    # Now that we have the separator, split the text
    if separator:
        if keep_separator:
            # The parentheses in the pattern keep the delimiters in the result.
            _splits = re.split(f"({re.escape(separator)})", text)
            splits = [_splits[i] + _splits[i + 1] for i in range(1, len(_splits), 2)]
            if len(_splits) % 2 == 0:
                splits += _splits[-1:]
            splits = [_splits[0]] + splits
        else:
            splits = re.split(separator, text)
    else:
        splits = list(text)
    return [s for s in splits if s != ""]
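
# Worked example for _split_text_with_regex (illustrative note, not from the
# original module): with keep_separator=True the captured delimiter is glued
# onto the piece that follows it, e.g.
#     _split_text_with_regex("a. b. c", ". ", keep_separator=True)
#     -> ["a", ". b", ". c"]
# With keep_separator=False the separator string is treated as a regular
# expression and dropped from the result.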


class TextSplitter(BaseDocumentTransformer, ABC):
    """Interface for splitting text into chunks."""

    def __init__(
        self,
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        length_function: Callable[[str], int] = len,
        keep_separator: bool = False,
        add_start_index: bool = False,
    ) -> None:
        """Create a new TextSplitter.

        Args:
            chunk_size: Maximum size of chunks to return
            chunk_overlap: Overlap in characters between chunks
            length_function: Function that measures the length of given chunks
            keep_separator: Whether to keep the separator in the chunks
            add_start_index: If `True`, includes chunk's start index in metadata
        """
        if chunk_overlap > chunk_size:
            raise ValueError(
                f"Got a larger chunk overlap ({chunk_overlap}) than chunk size "
                f"({chunk_size}), should be smaller."
            )
        self._chunk_size = chunk_size
        self._chunk_overlap = chunk_overlap
        self._length_function = length_function
        self._keep_separator = keep_separator
        self._add_start_index = add_start_index

    @abstractmethod
    def split_text(self, text: str) -> list[str]:
        """Split text into multiple components."""

    def create_documents(
        self, texts: list[str], metadatas: Optional[list[dict]] = None
    ) -> list[Document]:
        """Create documents from a list of texts."""
        _metadatas = metadatas or [{}] * len(texts)
        documents = []
        for i, text in enumerate(texts):
            index = -1
            for chunk in self.split_text(text):
                metadata = copy.deepcopy(_metadatas[i])
                if self._add_start_index:
                    index = text.find(chunk, index + 1)
                    metadata["start_index"] = index
                new_doc = Document(page_content=chunk, metadata=metadata)
                documents.append(new_doc)
        return documents

    def split_documents(self, documents: Iterable[Document]) -> list[Document]:
        """Split documents."""
        texts, metadatas = [], []
        for doc in documents:
            texts.append(doc.page_content)
            metadatas.append(doc.metadata)
        return self.create_documents(texts, metadatas=metadatas)

    def _join_docs(self, docs: list[str], separator: str) -> Optional[str]:
        text = separator.join(docs)
        text = text.strip()
        if text == "":
            return None
        else:
            return text

    def _merge_splits(self, splits: Iterable[str], separator: str) -> list[str]:
        # We now want to combine these smaller pieces into medium size
        # chunks to send to the LLM.
        separator_len = self._length_function(separator)

        docs = []
        current_doc: list[str] = []
        total = 0
        for d in splits:
            _len = self._length_function(d)
            if (
                total + _len + (separator_len if len(current_doc) > 0 else 0)
                > self._chunk_size
            ):
                if total > self._chunk_size:
                    logger.warning(
                        f"Created a chunk of size {total}, "
                        f"which is longer than the specified {self._chunk_size}"
                    )
                if len(current_doc) > 0:
                    doc = self._join_docs(current_doc, separator)
                    if doc is not None:
                        docs.append(doc)
                    # Keep on popping if:
                    # - we have a larger chunk than in the chunk overlap
                    # - or if we still have any chunks and the length is long
                    while total > self._chunk_overlap or (
                        total + _len + (separator_len if len(current_doc) > 0 else 0)
                        > self._chunk_size
                        and total > 0
                    ):
                        total -= self._length_function(current_doc[0]) + (
                            separator_len if len(current_doc) > 1 else 0
                        )
                        current_doc = current_doc[1:]
            current_doc.append(d)
            total += _len + (separator_len if len(current_doc) > 1 else 0)
        doc = self._join_docs(current_doc, separator)
        if doc is not None:
            docs.append(doc)
        return docs

    @classmethod
    def from_huggingface_tokenizer(cls, tokenizer: Any, **kwargs: Any) -> TextSplitter:
        """Text splitter that uses HuggingFace tokenizer to count length."""
        try:
            from transformers import PreTrainedTokenizerBase

            if not isinstance(tokenizer, PreTrainedTokenizerBase):
                raise ValueError(
                    "Tokenizer received was not an instance of PreTrainedTokenizerBase"
                )

            def _huggingface_tokenizer_length(text: str) -> int:
                return len(tokenizer.encode(text))

        except ImportError:
            raise ValueError(
                "Could not import transformers python package. "
                "Please install it with `pip install transformers`."
            )
        return cls(length_function=_huggingface_tokenizer_length, **kwargs)

    @classmethod
    def from_tiktoken_encoder(
        cls: type[TS],
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> TS:
        """Text splitter that uses tiktoken encoder to count length."""
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed in order to calculate token lengths. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)

        def _tiktoken_encoder(text: str) -> int:
            return len(
                enc.encode(
                    text,
                    allowed_special=allowed_special,
                    disallowed_special=disallowed_special,
                )
            )

        if issubclass(cls, TokenTextSplitter):
            extra_kwargs = {
                "encoding_name": encoding_name,
                "model_name": model_name,
                "allowed_special": allowed_special,
                "disallowed_special": disallowed_special,
            }
            kwargs = {**kwargs, **extra_kwargs}

        return cls(length_function=_tiktoken_encoder, **kwargs)

    def transform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Transform sequence of documents by splitting them."""
        return self.split_documents(list(documents))

    async def atransform_documents(
        self, documents: Sequence[Document], **kwargs: Any
    ) -> Sequence[Document]:
        """Asynchronously transform a sequence of documents by splitting them."""
        raise NotImplementedError
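

# Usage note (illustrative, not from the original module): the classmethods above
# swap the default character-count length function for a token-based one, e.g.
#
#     splitter = CharacterTextSplitter.from_tiktoken_encoder(
#         encoding_name="cl100k_base", chunk_size=500, chunk_overlap=50
#     )
#
# This requires the optional `tiktoken` package; the encoding name and sizes here
# are arbitrary placeholders. `CharacterTextSplitter` is defined below.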


class CharacterTextSplitter(TextSplitter):
    """Splitting text that looks at characters."""

    def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        self._separator = separator

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First we naively split the large input into a bunch of smaller ones.
        splits = _split_text_with_regex(text, self._separator, self._keep_separator)
        _separator = "" if self._keep_separator else self._separator
        return self._merge_splits(splits, _separator)
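

def _example_character_split() -> list[Document]:
    """Illustrative usage sketch, not part of the original module.

    Splits on blank lines into chunks of at most 100 characters with 20
    characters of overlap, recording each chunk's start offset in metadata.
    The sample text and sizes are arbitrary.
    """
    splitter = CharacterTextSplitter(
        separator="\n\n", chunk_size=100, chunk_overlap=20, add_start_index=True
    )
    return splitter.create_documents(
        ["First paragraph.\n\nSecond paragraph.\n\nThird paragraph."]
    )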


class LineType(TypedDict):
    """Line type as typed dict."""

    metadata: dict[str, str]
    content: str


class HeaderType(TypedDict):
    """Header type as typed dict."""

    level: int
    name: str
    data: str


class MarkdownHeaderTextSplitter:
    """Splitting markdown files based on specified headers."""

    def __init__(
        self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False
    ):
        """Create a new MarkdownHeaderTextSplitter.

        Args:
            headers_to_split_on: Headers we want to track
            return_each_line: Return each line w/ associated headers
        """
        # Output line-by-line or aggregated into chunks w/ common headers
        self.return_each_line = return_each_line
        # Given the headers we want to split on (e.g., "#", "##"),
        # order them by length
        self.headers_to_split_on = sorted(
            headers_to_split_on, key=lambda split: len(split[0]), reverse=True
        )

    def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
        """Combine lines with common metadata into chunks.

        Args:
            lines: Line of text / associated header metadata
        """
        aggregated_chunks: list[LineType] = []

        for line in lines:
            if (
                aggregated_chunks
                and aggregated_chunks[-1]["metadata"] == line["metadata"]
            ):
                # If the last line in the aggregated list
                # has the same metadata as the current line,
                # append the current content to the last line's content
                aggregated_chunks[-1]["content"] += " \n" + line["content"]
            else:
                # Otherwise, append the current line to the aggregated list
                aggregated_chunks.append(line)

        return [
            Document(page_content=chunk["content"], metadata=chunk["metadata"])
            for chunk in aggregated_chunks
        ]

    def split_text(self, text: str) -> list[Document]:
        """Split markdown file.

        Args:
            text: Markdown file
        """
        # Split the input text by newline character ("\n").
        lines = text.split("\n")
        # Final output
        lines_with_metadata: list[LineType] = []
        # Content and metadata of the chunk currently being processed
        current_content: list[str] = []
        current_metadata: dict[str, str] = {}
        # Keep track of the nested header structure
        # header_stack: List[Dict[str, Union[int, str]]] = []
        header_stack: list[HeaderType] = []
        initial_metadata: dict[str, str] = {}

        for line in lines:
            stripped_line = line.strip()
            # Check each line against each of the header types (e.g., #, ##)
            for sep, name in self.headers_to_split_on:
                # Check if line starts with a header that we intend to split on
                if stripped_line.startswith(sep) and (
                    # Header with no text OR header followed by a space:
                    # both are valid signs that sep is being used as a header
                    len(stripped_line) == len(sep)
                    or stripped_line[len(sep)] == " "
                ):
                    # Ensure we are tracking the header as metadata
                    if name is not None:
                        # Get the current header level
                        current_header_level = sep.count("#")
                        # Pop out headers of lower or same level from the stack
                        while (
                            header_stack
                            and header_stack[-1]["level"] >= current_header_level
                        ):
                            # We have encountered a new header
                            # at the same or higher level
                            popped_header = header_stack.pop()
                            # Clear the metadata for the
                            # popped header in initial_metadata
                            if popped_header["name"] in initial_metadata:
                                initial_metadata.pop(popped_header["name"])
                        # Push the current header to the stack
                        header: HeaderType = {
                            "level": current_header_level,
                            "name": name,
                            "data": stripped_line[len(sep):].strip(),
                        }
                        header_stack.append(header)
                        # Update initial_metadata with the current header
                        initial_metadata[name] = header["data"]
                    # Add the previous line to the lines_with_metadata
                    # only if current_content is not empty
                    if current_content:
                        lines_with_metadata.append(
                            {
                                "content": "\n".join(current_content),
                                "metadata": current_metadata.copy(),
                            }
                        )
                        current_content.clear()
                    break
            else:
                if stripped_line:
                    current_content.append(stripped_line)
                elif current_content:
                    lines_with_metadata.append(
                        {
                            "content": "\n".join(current_content),
                            "metadata": current_metadata.copy(),
                        }
                    )
                    current_content.clear()

            current_metadata = initial_metadata.copy()

        if current_content:
            lines_with_metadata.append(
                {"content": "\n".join(current_content), "metadata": current_metadata}
            )

        # lines_with_metadata has each line with associated header metadata
        # aggregate these into chunks based on common metadata
        if not self.return_each_line:
            return self.aggregate_lines_to_chunks(lines_with_metadata)
        else:
            return [
                Document(page_content=chunk["content"], metadata=chunk["metadata"])
                for chunk in lines_with_metadata
            ]
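

def _example_markdown_header_split() -> list[Document]:
    """Illustrative usage sketch, not part of the original module.

    Splits a small markdown string on `#` and `##` headers; the header text is
    carried along in each returned Document's metadata under the given names.
    """
    headers = [("#", "Header 1"), ("##", "Header 2")]
    splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers)
    return splitter.split_text("# Intro\nSome text.\n## Details\nMore text.")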


# In Python 3.10+ this could be:
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
    chunk_overlap: int
    tokens_per_chunk: int
    decode: Callable[[list[int]], str]
    encode: Callable[[str], list[int]]


def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
    """Split incoming text and return chunks using tokenizer."""
    splits: list[str] = []
    input_ids = tokenizer.encode(text)
    start_idx = 0
    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
    chunk_ids = input_ids[start_idx:cur_idx]
    while start_idx < len(input_ids):
        splits.append(tokenizer.decode(chunk_ids))
        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
        chunk_ids = input_ids[start_idx:cur_idx]
    return splits
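

# Worked example for split_text_on_tokens (illustrative note, not from the
# original module): with tokens_per_chunk=100 and chunk_overlap=20, the windows
# over input_ids are [0:100], [80:180], [160:260], and so on: the start index
# advances by tokens_per_chunk - chunk_overlap = 80 tokens per chunk, so
# consecutive chunks share 20 tokens.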


class TokenTextSplitter(TextSplitter):
    """Splitting text into tokens using a model tokenizer."""

    def __init__(
        self,
        encoding_name: str = "gpt2",
        model_name: Optional[str] = None,
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(**kwargs)
        try:
            import tiktoken
        except ImportError:
            raise ImportError(
                "Could not import tiktoken python package. "
                "This is needed for TokenTextSplitter. "
                "Please install it with `pip install tiktoken`."
            )

        if model_name is not None:
            enc = tiktoken.encoding_for_model(model_name)
        else:
            enc = tiktoken.get_encoding(encoding_name)
        self._tokenizer = enc
        self._allowed_special = allowed_special
        self._disallowed_special = disallowed_special

    def split_text(self, text: str) -> list[str]:
        def _encode(_text: str) -> list[int]:
            return self._tokenizer.encode(
                _text,
                allowed_special=self._allowed_special,
                disallowed_special=self._disallowed_special,
            )

        tokenizer = Tokenizer(
            chunk_overlap=self._chunk_overlap,
            tokens_per_chunk=self._chunk_size,
            decode=self._tokenizer.decode,
            encode=_encode,
        )

        return split_text_on_tokens(text=text, tokenizer=tokenizer)
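

def _example_token_split() -> list[str]:
    """Illustrative usage sketch, not part of the original module.

    Splits text into chunks of at most 256 tokens with a 32-token overlap.
    Requires the optional `tiktoken` package; the encoding name and sizes are
    arbitrary placeholders.
    """
    splitter = TokenTextSplitter(
        encoding_name="cl100k_base", chunk_size=256, chunk_overlap=32
    )
    return splitter.split_text("A long passage of text to be split by token count.")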


class RecursiveCharacterTextSplitter(TextSplitter):
    """Splitting text by recursively looking at characters.

    Recursively tries to split by different characters to find one
    that works.
    """

    def __init__(
        self,
        separators: Optional[list[str]] = None,
        keep_separator: bool = True,
        **kwargs: Any,
    ) -> None:
        """Create a new TextSplitter."""
        super().__init__(keep_separator=keep_separator, **kwargs)
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def _split_text(self, text: str, separators: list[str]) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Get appropriate separator to use
        separator = separators[-1]
        new_separators = []
        for i, _s in enumerate(separators):
            if _s == "":
                separator = _s
                break
            if re.search(_s, text):
                separator = _s
                new_separators = separators[i + 1:]
                break

        splits = _split_text_with_regex(text, separator, self._keep_separator)
        # Now go merging things, recursively splitting longer texts.
        _good_splits = []
        _separator = "" if self._keep_separator else separator
        for s in splits:
            if self._length_function(s) < self._chunk_size:
                _good_splits.append(s)
            else:
                if _good_splits:
                    merged_text = self._merge_splits(_good_splits, _separator)
                    final_chunks.extend(merged_text)
                    _good_splits = []
                if not new_separators:
                    final_chunks.append(s)
                else:
                    other_info = self._split_text(s, new_separators)
                    final_chunks.extend(other_info)
        if _good_splits:
            merged_text = self._merge_splits(_good_splits, _separator)
            final_chunks.extend(merged_text)
        return final_chunks

    def split_text(self, text: str) -> list[str]:
        return self._split_text(text, self._separators)
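

# Minimal self-check when the module is run directly (illustrative only; the
# sample text and parameters are arbitrary, not part of the original module).
if __name__ == "__main__":
    _sample = (
        "First paragraph of the sample.\n\n"
        "Second paragraph, which runs a little longer than the first one.\n\n"
        "Third paragraph."
    )
    _splitter = RecursiveCharacterTextSplitter(chunk_size=40, chunk_overlap=0)
    for _chunk in _splitter.split_text(_sample):
        print(repr(_chunk))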