Browse source

feat: support multi token count

Yeuoly 8 months ago
parent
commit
db726e02a0

+ 2 - 4
api/core/indexing_runner.py

@@ -720,10 +720,8 @@ class IndexingRunner:
 
             tokens = 0
             if embedding_model_instance:
-                tokens += sum(
-                    embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
-                    for document in chunk_documents
-                )
+                page_content_list = [document.page_content for document in chunk_documents]
+                tokens += sum(embedding_model_instance.get_text_embedding_num_tokens(page_content_list))
 
             # load index
             index_processor.load(dataset, chunk_documents, with_keywords=False)
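The hunk above replaces a per-document loop of single-text calls with one batched call whose per-text counts are summed. A minimal sketch of that pattern, assuming get_text_embedding_num_tokens now returns one count per input text (the whitespace counter below is a stand-in, not the real tokenizer):

def get_text_embedding_num_tokens(texts: list[str]) -> list[int]:
    # stand-in counter: one count per text, here just whitespace-separated words
    return [len(t.split()) for t in texts]

chunk_page_contents = ["first chunk of text", "second chunk", "third"]

# one call for the whole batch, then sum the per-text counts
tokens = sum(get_text_embedding_num_tokens(chunk_page_contents))
print(tokens)  # 7 with the stand-in counter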

+ 2 - 2
api/core/model_manager.py

@@ -175,7 +175,7 @@ class ModelInstance:
 
     def get_llm_num_tokens(
         self, prompt_messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None
-    ) -> int:
+    ) -> list[int]:
         """
         Get number of tokens for llm
 
@@ -235,7 +235,7 @@ class ModelInstance:
             model=self.model,
             credentials=self.credentials,
             texts=texts,
-        )[0]  # TODO: fix this, this is only for temporary compatibility with old
+        )
 
     def invoke_rerank(
         self,
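With the trailing [0] removed, the text-embedding token helper returns the full list of per-text counts, so callers that pass a single text index into the result themselves, as the dataset_service.py hunks below do. A rough sketch of both calling styles, using a stand-in counter (character length) in place of the model runtime:

def get_text_embedding_num_tokens(texts: list[str]) -> list[int]:
    # stand-in mirroring the new list-returning shape; real counts come from the provider
    return [len(t) for t in texts]

batch_counts = get_text_embedding_num_tokens(["alpha", "beta and gamma"])  # [5, 14]
single_count = get_text_embedding_num_tokens(["only one text"])[0]         # 13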

+ 7 - 7
api/core/rag/docstore/dataset_docstore.py

@@ -79,7 +79,13 @@ class DatasetDocumentStore:
                 model=self._dataset.embedding_model,
             )
 
-        for doc in docs:
+        if embedding_model:
+            page_content_list = [doc.page_content for doc in docs]
+            tokens_list = embedding_model.get_text_embedding_num_tokens(page_content_list)
+        else:
+            tokens_list = [0] * len(docs)
+
+        for doc, tokens in zip(docs, tokens_list):
             if not isinstance(doc, Document):
                 raise ValueError("doc must be a Document")
 
@@ -91,12 +97,6 @@ class DatasetDocumentStore:
                     f"doc_id {doc.metadata['doc_id']} already exists. Set allow_update to True to overwrite."
                 )
 
-            # calc embedding use tokens
-            if embedding_model:
-                tokens = embedding_model.get_text_embedding_num_tokens(texts=[doc.page_content])
-            else:
-                tokens = 0
-
             if not segment_document:
                 max_position += 1
 
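add_documents now hoists token counting out of the per-document loop: one batched call when an embedding model is configured, a list of zeros otherwise, then zip pairs each document with its count. A self-contained sketch of that shape, with a toy Document class and counter standing in for the real ones:

from dataclasses import dataclass

@dataclass
class Document:
    page_content: str

def count_tokens(texts: list[str]) -> list[int]:
    # stand-in for embedding_model.get_text_embedding_num_tokens
    return [len(t.split()) for t in texts]

docs = [Document("first segment"), Document("a second, longer segment")]
embedding_model_available = True

if embedding_model_available:
    tokens_list = count_tokens([doc.page_content for doc in docs])
else:
    tokens_list = [0] * len(docs)

for doc, tokens in zip(docs, tokens_list):
    print(doc.page_content, tokens)  # each document paired with its precomputed count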

+ 5 - 3
api/core/rag/splitter/fixed_text_splitter.py

@@ -65,8 +65,9 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
             chunks = [text]
 
         final_chunks = []
-        for chunk in chunks:
-            if self._length_function(chunk) > self._chunk_size:
+        chunks_lengths = self._length_function(chunks)
+        for chunk, chunk_length in zip(chunks, chunks_lengths):
+            if chunk_length > self._chunk_size:
                 final_chunks.extend(self.recursive_split_text(chunk))
             else:
                 final_chunks.append(chunk)
@@ -93,7 +94,8 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
         # Now go merging things, recursively splitting longer texts.
         _good_splits = []
         _good_splits_lengths = []  # cache the lengths of the splits
-        for s in splits:
+        s_lens = self._length_function(splits)
+        for s, s_len in zip(splits, s_lens):
             s_len = self._length_function(s)
             if s_len < self._chunk_size:
                 _good_splits.append(s)
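Both hunks in this file assume the length function now maps a list of strings to a list of lengths, so sizes are computed once per batch and zipped back onto the chunks. A small sketch of the batched length check, with a character-count stand-in for self._length_function:

def length_function(texts: list[str]) -> list[int]:
    return [len(t) for t in texts]  # stand-in; the real splitter may count tokens instead

chunk_size = 10
chunks = ["short", "a chunk that is clearly too long", "ok"]

final_chunks: list[str] = []
for chunk, chunk_length in zip(chunks, length_function(chunks)):
    if chunk_length > chunk_size:
        # placeholder for self.recursive_split_text(chunk): naive fixed-width split
        final_chunks.extend(chunk[i : i + chunk_size] for i in range(0, len(chunk), chunk_size))
    else:
        final_chunks.append(chunk)
print(final_chunks)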

+ 5 - 6
api/core/rag/splitter/text_splitter.py

@@ -45,7 +45,7 @@ class TextSplitter(BaseDocumentTransformer, ABC):
         self,
         chunk_size: int = 4000,
         chunk_overlap: int = 200,
-        length_function: Callable[[str], int] = len,
+        length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x],
         keep_separator: bool = False,
         add_start_index: bool = False,
     ) -> None:
@@ -224,8 +224,8 @@ class CharacterTextSplitter(TextSplitter):
         splits = _split_text_with_regex(text, self._separator, self._keep_separator)
         _separator = "" if self._keep_separator else self._separator
         _good_splits_lengths = []  # cache the lengths of the splits
-        for split in splits:
-            _good_splits_lengths.append(self._length_function(split))
+        if splits:
+            _good_splits_lengths.extend(self._length_function(splits))
         return self._merge_splits(splits, _separator, _good_splits_lengths)
 
 
@@ -478,9 +478,8 @@ class RecursiveCharacterTextSplitter(TextSplitter):
         _good_splits = []
         _good_splits_lengths = []  # cache the lengths of the splits
         _separator = "" if self._keep_separator else separator
-
-        for s in splits:
-            s_len = self._length_function(s)
+        s_lens = self._length_function(splits)
+        for s, s_len in zip(splits, s_lens):
             if s_len < self._chunk_size:
                 _good_splits.append(s)
                 _good_splits_lengths.append(s_len)
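The new default length_function maps list[str] to list[int]; in the lambda, the comprehension variable shadows the outer argument, which is why lambda x: [len(x) for x in x] returns one length per string. A short sketch of the default and of a hypothetical batched counter that could be passed in its place:

from typing import Callable

default_length_function: Callable[[list[str]], list[int]] = lambda x: [len(x) for x in x]
print(default_length_function(["ab", "cde"]))  # [2, 3]

def whitespace_token_lengths(texts: list[str]) -> list[int]:
    # hypothetical drop-in that counts whitespace-separated tokens per text
    return [len(t.split()) for t in texts]

print(whitespace_token_lengths(["one two", "three"]))  # [2, 1]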

+ 8 - 5
api/services/dataset_service.py

@@ -1390,7 +1390,7 @@ class SegmentService:
                 model=dataset.embedding_model,
             )
             # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
         lock_name = "add_segment_lock_document_id_{}".format(document.id)
         with redis_client.lock(lock_name, timeout=600):
             max_position = (
@@ -1467,9 +1467,12 @@ class SegmentService:
                 if dataset.indexing_technique == "high_quality" and embedding_model:
                     # calc embedding use tokens
                     if document.doc_form == "qa_model":
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment_item["answer"]])
+                        tokens = embedding_model.get_text_embedding_num_tokens(
+                            texts=[content + segment_item["answer"]]
+                        )[0]
                     else:
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
+
                 segment_document = DocumentSegment(
                     tenant_id=current_user.current_tenant_id,
                     dataset_id=document.dataset_id,
@@ -1577,9 +1580,9 @@ class SegmentService:
 
                     # calc embedding use tokens
                     if document.doc_form == "qa_model":
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
                     else:
-                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])
+                        tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
                 segment.content = content
                 segment.index_node_hash = segment_hash
                 segment.word_count = len(content)
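Each SegmentService branch still counts a single text, so the list returned by the helper is read back with [0]; for qa_model segments the question and answer are concatenated before counting, as in the hunks above. A sketch of that branch with a stand-in counter:

def count_tokens(texts: list[str]) -> list[int]:
    # stand-in for embedding_model.get_text_embedding_num_tokens
    return [len(t.split()) for t in texts]

content = "What is the refund window?"
answer = "Thirty days from delivery."
doc_form = "qa_model"

if doc_form == "qa_model":
    tokens = count_tokens([content + answer])[0]
else:
    tokens = count_tokens([content])[0]
print(tokens)  # 8 with the stand-in counter (the concatenation adds no separating space)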

+ 7 - 3
api/tasks/batch_create_segment_to_index_task.py

@@ -58,12 +58,16 @@ def batch_create_segment_to_index_task(
                 model=dataset.embedding_model,
             )
         word_count_change = 0
-        for segment in content:
+        if embedding_model:
+            tokens_list = embedding_model.get_text_embedding_num_tokens(
+                texts=[segment["content"] for segment in content]
+            )
+        else:
+            tokens_list = [0] * len(content)
+        for segment, tokens in zip(content, tokens_list):
             content = segment["content"]
             doc_id = str(uuid.uuid4())
             segment_hash = helper.generate_text_hash(content)
-            # calc embedding use tokens
-            tokens = embedding_model.get_text_embedding_num_tokens(texts=[content]) if embedding_model else 0
             max_position = (
                 db.session.query(func.max(DocumentSegment.position))
                 .filter(DocumentSegment.document_id == dataset_document.id)
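The task now gathers every segment's content, counts tokens in one call (or falls back to zeros without an embedding model) and zips the counts back while iterating. A sketch of that flow, assuming each item in content is a dict with a "content" key as in the hunk above:

def count_tokens(texts: list[str]) -> list[int]:
    return [len(t.split()) for t in texts]  # stand-in for the embedding model helper

content = [{"content": "first segment"}, {"content": "second segment body"}]
have_embedding_model = True

if have_embedding_model:
    tokens_list = count_tokens([segment["content"] for segment in content])
else:
    tokens_list = [0] * len(content)

for segment, tokens in zip(content, tokens_list):
    print(segment["content"], tokens)  # one precomputed count per segment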