fixed_text_splitter.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. """Functionality for splitting text."""
  2. from __future__ import annotations
  3. from typing import Any, Optional
  4. from core.model_manager import ModelInstance
  5. from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
  6. from core.rag.splitter.text_splitter import (
  7. TS,
  8. Collection,
  9. Literal,
  10. RecursiveCharacterTextSplitter,
  11. Set,
  12. TokenTextSplitter,
  13. Union,
  14. )
  15. class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
  16. """
  17. This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
  18. """
  19. @classmethod
  20. def from_encoder(
  21. cls: type[TS],
  22. embedding_model_instance: Optional[ModelInstance],
  23. allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
  24. disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
  25. **kwargs: Any,
  26. ):
  27. def _token_encoder(texts: list[str]) -> list[int]:
  28. if not texts:
  29. return []
  30. if embedding_model_instance:
  31. return embedding_model_instance.get_text_embedding_num_tokens(texts=texts)
  32. else:
  33. return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
  34. if issubclass(cls, TokenTextSplitter):
  35. extra_kwargs = {
  36. "model_name": embedding_model_instance.model if embedding_model_instance else "gpt2",
  37. "allowed_special": allowed_special,
  38. "disallowed_special": disallowed_special,
  39. }
  40. kwargs = {**kwargs, **extra_kwargs}
  41. return cls(length_function=_token_encoder, **kwargs)
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
    """Splitter that first cuts text on a caller-chosen fixed separator, then
    recursively re-splits any piece still longer than the chunk size."""

    def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any):
        """Create a new TextSplitter.

        :param fixed_separator: primary delimiter applied first by ``split_text``.
        :param separators: fallback delimiters tried in priority order by
            ``recursive_split_text``; defaults to paragraph, line, word, character.
        """
        super().__init__(**kwargs)
        self._fixed_separator = fixed_separator
        self._separators = separators or ["\n\n", "\n", " ", ""]

    def split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        # First pass: cut only on the fixed separator (or keep the text whole).
        if self._fixed_separator:
            chunks = text.split(self._fixed_separator)
        else:
            chunks = [text]
        final_chunks = []
        # _length_function and _chunk_size are inherited from the base splitter
        # (not visible here); _length_function appears to return one length
        # (token count) per input string — confirm against the base class.
        chunks_lengths = self._length_function(chunks)
        for chunk, chunk_length in zip(chunks, chunks_lengths):
            if chunk_length > self._chunk_size:
                # Oversized piece: fall back to the recursive separator cascade.
                final_chunks.extend(self.recursive_split_text(chunk))
            else:
                final_chunks.append(chunk)
        return final_chunks

    def recursive_split_text(self, text: str) -> list[str]:
        """Split incoming text and return chunks."""
        final_chunks = []
        # Pick the first separator (in priority order) that occurs in the text;
        # the remaining finer separators are kept for deeper recursion.
        separator = self._separators[-1]
        new_separators = []
        for i, _s in enumerate(self._separators):
            if _s == "":
                separator = _s
                break
            if _s in text:
                separator = _s
                new_separators = self._separators[i + 1 :]
                break
        # Now that we have the separator, split the text
        if separator:
            if separator == " ":
                # Space separator uses str.split(), so runs of whitespace collapse.
                splits = text.split()
            else:
                splits = text.split(separator)
        else:
            # Empty separator: fall back to splitting into individual characters.
            splits = list(text)
        # Drop empty fragments and bare-newline fragments produced by the split.
        splits = [s for s in splits if (s not in {"", "\n"})]
        _good_splits = []
        _good_splits_lengths = []  # cache the lengths of the splits
        # _keep_separator is inherited; when separators are kept (or separator is ""),
        # _separator becomes "" and the manual packing path below is taken instead.
        _separator = "" if self._keep_separator else separator
        s_lens = self._length_function(splits)
        if _separator != "":
            # Merge path: accumulate pieces that fit, flush before each oversized one.
            for s, s_len in zip(splits, s_lens):
                if s_len < self._chunk_size:
                    _good_splits.append(s)
                    _good_splits_lengths.append(s_len)
                else:
                    if _good_splits:
                        # Flush accumulated small pieces before handling the big one.
                        # _merge_splits is inherited — presumably joins pieces up to
                        # the chunk size; verify against the base splitter.
                        merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                        final_chunks.extend(merged_text)
                        _good_splits = []
                        _good_splits_lengths = []
                    if not new_separators:
                        # No finer separator left: emit the oversized piece as-is.
                        final_chunks.append(s)
                    else:
                        # Recurse with the remaining, finer separators via the
                        # inherited _split_text — TODO confirm its contract.
                        other_info = self._split_text(s, new_separators)
                        final_chunks.extend(other_info)
            if _good_splits:
                # Flush whatever small pieces remain after the loop.
                merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths)
                final_chunks.extend(merged_text)
        else:
            # Manual packing path (no separator re-inserted): greedily fill each
            # chunk and carry an overlap tail into the start of the next chunk.
            current_part = ""
            current_length = 0
            overlap_part = ""
            overlap_part_length = 0
            for s, s_len in zip(splits, s_lens):
                if current_length + s_len <= self._chunk_size - self._chunk_overlap:
                    # Still inside the non-overlap budget of the current chunk.
                    current_part += s
                    current_length += s_len
                elif current_length + s_len <= self._chunk_size:
                    # Inside the overlap window: include it in the current chunk
                    # AND record it as the tail carried into the next chunk.
                    current_part += s
                    current_length += s_len
                    overlap_part += s
                    overlap_part_length += s_len
                else:
                    # Chunk full: emit it and seed the next chunk with the overlap tail.
                    final_chunks.append(current_part)
                    current_part = overlap_part + s
                    current_length = s_len + overlap_part_length
                    overlap_part = ""
                    overlap_part_length = 0
            if current_part:
                # Emit the final partially-filled chunk.
                final_chunks.append(current_part)
        return final_chunks