# fixed_text_splitter.py
  1. """Functionality for splitting text."""
  2. from __future__ import annotations
  3. from typing import Any, Optional, cast
  4. from langchain.text_splitter import (
  5. TS,
  6. AbstractSet,
  7. Collection,
  8. Literal,
  9. RecursiveCharacterTextSplitter,
  10. TokenTextSplitter,
  11. Type,
  12. Union,
  13. )
  14. from core.model_manager import ModelInstance
  15. from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
  16. from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
  17. class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
  18. """
  19. This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
  20. """
  21. @classmethod
  22. def from_encoder(
  23. cls: Type[TS],
  24. embedding_model_instance: Optional[ModelInstance],
  25. allowed_special: Union[Literal[all], AbstractSet[str]] = set(),
  26. disallowed_special: Union[Literal[all], Collection[str]] = "all",
  27. **kwargs: Any,
  28. ):
  29. def _token_encoder(text: str) -> int:
  30. if not text:
  31. return 0
  32. if embedding_model_instance:
  33. embedding_model_type_instance = embedding_model_instance.model_type_instance
  34. embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
  35. return embedding_model_type_instance.get_num_tokens(
  36. model=embedding_model_instance.model,
  37. credentials=embedding_model_instance.credentials,
  38. texts=[text]
  39. )
  40. else:
  41. return GPT2Tokenizer.get_num_tokens(text)
  42. if issubclass(cls, TokenTextSplitter):
  43. extra_kwargs = {
  44. "model_name": embedding_model_instance.model if embedding_model_instance else 'gpt2',
  45. "allowed_special": allowed_special,
  46. "disallowed_special": disallowed_special,
  47. }
  48. kwargs = {**kwargs, **extra_kwargs}
  49. return cls(length_function=_token_encoder, **kwargs)
  50. class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
  51. def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any):
  52. """Create a new TextSplitter."""
  53. super().__init__(**kwargs)
  54. self._fixed_separator = fixed_separator
  55. self._separators = separators or ["\n\n", "\n", " ", ""]
  56. def split_text(self, text: str) -> list[str]:
  57. """Split incoming text and return chunks."""
  58. if self._fixed_separator:
  59. chunks = text.split(self._fixed_separator)
  60. else:
  61. chunks = list(text)
  62. final_chunks = []
  63. for chunk in chunks:
  64. if self._length_function(chunk) > self._chunk_size:
  65. final_chunks.extend(self.recursive_split_text(chunk))
  66. else:
  67. final_chunks.append(chunk)
  68. return final_chunks
  69. def recursive_split_text(self, text: str) -> list[str]:
  70. """Split incoming text and return chunks."""
  71. final_chunks = []
  72. # Get appropriate separator to use
  73. separator = self._separators[-1]
  74. for _s in self._separators:
  75. if _s == "":
  76. separator = _s
  77. break
  78. if _s in text:
  79. separator = _s
  80. break
  81. # Now that we have the separator, split the text
  82. if separator:
  83. splits = text.split(separator)
  84. else:
  85. splits = list(text)
  86. # Now go merging things, recursively splitting longer texts.
  87. _good_splits = []
  88. for s in splits:
  89. if self._length_function(s) < self._chunk_size:
  90. _good_splits.append(s)
  91. else:
  92. if _good_splits:
  93. merged_text = self._merge_splits(_good_splits, separator)
  94. final_chunks.extend(merged_text)
  95. _good_splits = []
  96. other_info = self.recursive_split_text(s)
  97. final_chunks.extend(other_info)
  98. if _good_splits:
  99. merged_text = self._merge_splits(_good_splits, separator)
  100. final_chunks.extend(merged_text)
  101. return final_chunks