# indexing_runner.py

import concurrent.futures
import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, cast

from flask import Flask, current_app
from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError

from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.splitter.fixed_text_splitter import (
    EnhanceRecursiveCharacterTextSplitter,
    FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService


class IndexingRunner:

    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: list[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()
                index_type = dataset_document.doc_form
                index_processor = IndexProcessorFactory(index_type).init_index_processor()

                # extract
                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

                # transform
                documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                            processing_rule.to_dict())
                # save segment
                self._load_segments(dataset, dataset_document, documents)

                # load
                self._load(
                    index_processor=index_processor,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.commit()
            except ObjectDeletedError:
                logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get exist document_segment list and delete
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()

            # extract
            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

            # transform
            documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                        processing_rule.to_dict())
            # save segment
            self._load_segments(dataset, dataset_document, documents)

            # load
            self._load(
                index_processor=index_processor,
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get exist document_segment list and delete
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            self._load(
                index_processor=index_processor,
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()

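    # indexing_estimate walks the same extract/split pipeline without writing anything:
    # it counts segments, sums embedding tokens for 'high_quality' indexing, and prices
    # the run; for the 'qa_model' form it previews generated Q&A pairs and prices the
    # LLM input instead.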
    def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
                          doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                          indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing for the document.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(extract_settings)
            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == 'high_quality':
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        tokens = 0
        preview_texts = []
        total_segments = 0
        total_price = 0
        currency = 'USD'
        index_type = doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        all_text_docs = []
        for extract_setting in extract_settings:
            # extract
            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
            all_text_docs.extend(text_docs)

            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # get splitter
            splitter = self._get_splitter(processing_rule, embedding_model_instance)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)
                if indexing_technique == 'high_quality' or embedding_model_instance:
                    embedding_model_type_instance = embedding_model_instance.model_type_instance
                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
                    tokens += embedding_model_type_instance.get_num_tokens(
                        model=embedding_model_instance.model,
                        credentials=embedding_model_instance.credentials,
                        texts=[self.filter_string(document.page_content)]
                    )

        if doc_form and doc_form == 'qa_model':
            model_instance = self.model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM
            )

            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)

                price_info = model_type_instance.get_price(
                    model=model_instance.model,
                    credentials=model_instance.credentials,
                    price_type=PriceType.INPUT,
                    tokens=total_segments * 2000,
                )

                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(price_info.total_amount),
                    "currency": price_info.currency,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        if embedding_model_instance:
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
            embedding_price_info = embedding_model_type_instance.get_price(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                price_type=PriceType.INPUT,
                tokens=tokens
            )
            total_price = '{:f}'.format(embedding_price_info.total_amount)
            currency = embedding_price_info.currency

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": total_price,
            "currency": currency,
            "preview": preview_texts
        }

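    # _extract loads raw text for the document from its data source (an uploaded file
    # or a Notion page), marks the document as "splitting", records the parsed word
    # count, and stamps each extracted Document with the dataset/document ids.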
    def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
            -> list[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            if file_detail:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file",
                    upload_file=file_detail,
                    document_model=dataset_document.doc_form
                )
                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
        elif dataset_document.data_source_type == 'notion_import':
            if (not data_source_info or 'notion_workspace_id' not in data_source_info
                    or 'notion_page_id' not in data_source_info):
                raise ValueError("no notion import info found")
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": data_source_info['notion_workspace_id'],
                    "notion_obj_id": data_source_info['notion_page_id'],
                    "notion_page_type": data_source_info['type'],
                    "document": dataset_document,
                    "tenant_id": dataset_document.tenant_id
                },
                document_model=dataset_document.doc_form
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            }
        )

        # replace doc id to document model id
        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs

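    # filter_string strips special delimiter tokens ("<|", "|>") and control/invalid
    # characters from text before it is passed to the embedding tokenizer.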
    def filter_string(self, text):
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
        # Unicode U+FFFE
        text = re.sub('\uFFFE', '', text)
        return text

    def _get_splitter(self, processing_rule: DatasetProcessRule,
                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
        """
        Get the TextSplitter object according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # The user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            max_segmentation_tokens_length = int(current_app.config['INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH'])
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
                raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            if segmentation.get('chunk_overlap'):
                chunk_overlap = segmentation['chunk_overlap']
            else:
                chunk_overlap = 0

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=chunk_overlap,
                fixed_separator=separator,
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance
            )
        else:
            # Automatic segmentation
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance
            )

        return character_splitter

    def _step_split(self, text_docs: list[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> list[Document]:
        """
        Split the text documents into documents and save them to the document segment.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language
        )

        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            }
        )

        return documents

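    # _split_to_documents cleans and chunks the extracted text; for the 'qa_model'
    # form it then fans the chunks out to worker threads (10 at a time) that call the
    # LLM to generate question/answer documents.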
    def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str,
                            document_form: str, document_language: str) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata['doc_id'] = doc_id
                    document_node.metadata['doc_hash'] = hash
                    # delete leading splitter character
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    document_node.page_content = page_content

                    if document_node.page_content:
                        split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document
        if document_form == 'qa_model':
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
                        'document_language': document_language})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents

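    # format_qa_document runs inside a worker thread: it pushes the Flask app context,
    # asks the LLM to turn one chunk into Q&A pairs, and appends the resulting
    # Documents (question as content, answer in metadata) to the shared list.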
    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result['question'])
                    qa_document.metadata['answer'] = result['answer']
                    qa_document.metadata['doc_id'] = doc_id
                    qa_document.metadata['doc_hash'] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)

            all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])

            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)

                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text

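    # format_split_text parses LLM output of the form "Q1: ... A1: ..." into a list of
    # {"question", "answer"} dicts, dropping pairs where either side is empty.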
    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [
            {
                "question": q,
                "answer": re.sub(r"\n\s*", "\n", a.strip())
            }
            for q, a in matches if q and a
        ]

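    # _load builds the indexes: the keyword index is created in a background thread,
    # and for 'high_quality' datasets the vector index is written in chunks of 10
    # documents via a thread pool while embedding tokens are tallied for billing.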
    def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
              dataset_document: DatasetDocument, documents: list[Document]) -> None:
        """
        insert index and update document/segment status to completed
        """
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 10

        embedding_model_type_instance = None
        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

        # create keyword index
        create_keyword_thread = threading.Thread(target=self._process_keyword_index,
                                                 args=(current_app._get_current_object(),
                                                       dataset.id, dataset_document.id, documents))
        create_keyword_thread.start()

        if dataset.indexing_technique == 'high_quality':
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                futures = []
                for i in range(0, len(documents), chunk_size):
                    chunk_documents = documents[i:i + chunk_size]
                    futures.append(executor.submit(self._process_chunk, current_app._get_current_object(),
                                                   index_processor,
                                                   chunk_documents, dataset,
                                                   dataset_document, embedding_model_instance,
                                                   embedding_model_type_instance))

                for future in futures:
                    tokens += future.result()

        create_keyword_thread.join()
        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )

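    # _process_keyword_index runs in its own thread: it creates the keyword index and,
    # for datasets that do not use 'high_quality' indexing (so no separate vector
    # step), marks the affected segments as completed.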
    def _process_keyword_index(self, flask_app, dataset_id, document_id, documents):
        with flask_app.app_context():
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")
            keyword = Keyword(dataset)
            keyword.create(documents)
            if dataset.indexing_technique != 'high_quality':
                document_ids = [document.metadata['doc_id'] for document in documents]
                db.session.query(DocumentSegment).filter(
                    DocumentSegment.document_id == document_id,
                    DocumentSegment.index_node_id.in_(document_ids),
                    DocumentSegment.status == "indexing"
                ).update({
                    DocumentSegment.status: "completed",
                    DocumentSegment.enabled: True,
                    DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                })

                db.session.commit()

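    # _process_chunk handles one batch of documents inside the thread pool: it aborts
    # if the document was paused, counts embedding tokens, writes the vector index,
    # and flips the matching segments to "completed".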
    def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
                       embedding_model_instance, embedding_model_type_instance):
        with flask_app.app_context():
            # check document is paused
            self._check_document_paused_status(dataset_document.id)

            tokens = 0
            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                tokens += sum(
                    embedding_model_type_instance.get_num_tokens(
                        embedding_model_instance.model,
                        embedding_model_instance.credentials,
                        [document.page_content]
                    )
                    for document in chunk_documents
                )

            # load index
            index_processor.load(dataset, chunk_documents, with_keywords=False)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            })

            db.session.commit()

            return tokens

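    # _check_document_paused_status raises if a pause flag for the document is present
    # in Redis, which stops in-flight chunk processing.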
    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segment by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

    def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset):
        """
        Batch add segments index processing
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            documents.append(document)

        # save vector index
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, documents)

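    # _transform hands the extracted documents to the index processor, which cleans
    # and splits them; for 'high_quality' datasets the dataset's embedding model (or
    # the tenant default) is passed through for token-aware splitting.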
    def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
                   text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]:
        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
                                              process_rule=process_rule, tenant_id=dataset.tenant_id,
                                              doc_language=doc_language)

        return documents

    def _load_segments(self, dataset, dataset_document, documents):
        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            }
        )


class DocumentIsPausedException(Exception):
    pass


class DocumentIsDeletedPausedException(Exception):
    pass
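
# A minimal usage sketch (not part of the original module; names are illustrative):
# the runner is normally driven from a background worker inside a Flask app context,
# roughly:
#
#     dataset_documents = DatasetDocument.query.filter_by(batch=batch_id).all()
#     IndexingRunner().run(dataset_documents)
#
# run() expects persisted DatasetDocument rows plus working DB/Redis connections;
# `batch_id` above is a hypothetical filter, shown only to indicate the call shape.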