indexing_runner.py

import concurrent.futures
import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, cast

from flask import Flask, current_app
from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError

from configs import dify_config
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.splitter.fixed_text_splitter import (
    EnhanceRecursiveCharacterTextSplitter,
    FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService


class IndexingRunner:
    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: list[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
                if not dataset:
                    raise ValueError("no dataset found")
                # get the process rule
                processing_rule = (
                    db.session.query(DatasetProcessRule)
                    .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                    .first()
                )
                index_type = dataset_document.doc_form
                index_processor = IndexProcessorFactory(index_type).init_index_processor()
                # extract
                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
                # transform
                documents = self._transform(
                    index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
                )
                # save segment
                self._load_segments(dataset, dataset_document, documents)
                # load
                self._load(
                    index_processor=index_processor,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents,
                )
            except DocumentIsPausedError:
                raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = "error"
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.commit()
            except ObjectDeletedError:
                logging.warning("Document deleted, document id: {}".format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = "error"
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")
            # get exist document_segment list and delete
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id, document_id=dataset_document.id
            ).all()
            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()
            # get the process rule
            processing_rule = (
                db.session.query(DatasetProcessRule)
                .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                .first()
            )
            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            # extract
            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())
            # transform
            documents = self._transform(
                index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
            )
            # save segment
            self._load_segments(dataset, dataset_document, documents)
            # load
            self._load(
                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
            )
        except DocumentIsPausedError:
            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")
            # get exist document_segment list and delete
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id, document_id=dataset_document.id
            ).all()
            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            },
                        )
                        documents.append(document)
            # build index
            # get the process rule
            processing_rule = (
                db.session.query(DatasetProcessRule)
                .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                .first()
            )
            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            self._load(
                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
            )
        except DocumentIsPausedError:
            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()

    def indexing_estimate(
        self,
        tenant_id: str,
        extract_settings: list[ExtractSetting],
        tmp_processing_rule: dict,
        doc_form: Optional[str] = None,
        doc_language: str = "English",
        dataset_id: Optional[str] = None,
        indexing_technique: str = "economy",
    ) -> dict:
        """
        Estimate the indexing for the document.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(extract_settings)
            batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("Dataset not found.")
            if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model,
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == "high_quality":
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        preview_texts = []
        total_segments = 0
        index_type = doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        all_text_docs = []
        for extract_setting in extract_settings:
            # extract
            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
            all_text_docs.extend(text_docs)
            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
            )
            # get splitter
            splitter = self._get_splitter(processing_rule, embedding_model_instance)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs, splitter=splitter, processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

                # delete image files and related db records
                image_upload_file_ids = get_image_upload_file_ids(document.page_content)
                for upload_file_id in image_upload_file_ids:
                    image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
                    # skip records whose upload file row no longer exists
                    if image_file is None:
                        continue
                    try:
                        storage.delete(image_file.key)
                    except Exception:
                        logging.exception(
                            "Delete image_files failed while indexing_estimate, image_upload_file_id: {}".format(
                                upload_file_id
                            )
                        )
                    db.session.delete(image_file)

        if doc_form and doc_form == "qa_model":
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(
                    current_user.current_tenant_id, preview_texts[0], doc_language
                )
                document_qa_list = self.format_split_text(response)

                return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts}

        return {"total_segments": total_segments, "preview": preview_texts}

    def _extract(
        self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
    ) -> list[Document]:
        # load file
        if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == "upload_file":
            if not data_source_info or "upload_file_id" not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = (
                db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
            )

            if file_detail:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
                )
                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
        elif dataset_document.data_source_type == "notion_import":
            if (
                not data_source_info
                or "notion_workspace_id" not in data_source_info
                or "notion_page_id" not in data_source_info
            ):
                raise ValueError("no notion import info found")
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": data_source_info["notion_workspace_id"],
                    "notion_obj_id": data_source_info["notion_page_id"],
                    "notion_page_type": data_source_info["type"],
                    "document": dataset_document,
                    "tenant_id": dataset_document.tenant_id,
                },
                document_model=dataset_document.doc_form,
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
        elif dataset_document.data_source_type == "website_crawl":
            if (
                not data_source_info
                or "provider" not in data_source_info
                or "url" not in data_source_info
                or "job_id" not in data_source_info
            ):
                raise ValueError("no website import info found")
            extract_setting = ExtractSetting(
                datasource_type="website_crawl",
                website_info={
                    "provider": data_source_info["provider"],
                    "job_id": data_source_info["job_id"],
                    "tenant_id": dataset_document.tenant_id,
                    "url": data_source_info["url"],
                    "mode": data_source_info["mode"],
                    "only_main_content": data_source_info["only_main_content"],
                },
                document_model=dataset_document.doc_form,
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
            },
        )

        # replace doc id to document model id
        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
            text_doc.metadata["document_id"] = dataset_document.id
            text_doc.metadata["dataset_id"] = dataset_document.dataset_id

        return text_docs

    @staticmethod
    def filter_string(text):
        text = re.sub(r"<\|", "<", text)
        text = re.sub(r"\|>", ">", text)
        text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text)
        # Unicode U+FFFE
        text = re.sub("\ufffe", "", text)
        return text

    @staticmethod
    def _get_splitter(
        processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]
    ) -> TextSplitter:
        """
        Get the NodeParser object according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # The user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
                raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace("\\n", "\n")

            if segmentation.get("chunk_overlap"):
                chunk_overlap = segmentation["chunk_overlap"]
            else:
                chunk_overlap = 0

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=chunk_overlap,
                fixed_separator=separator,
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance,
            )
        else:
            # Automatic segmentation
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"],
                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"],
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance,
            )

        return character_splitter

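    # Note: a sketch of the "custom" rule payload consumed by _get_splitter above,
    # inferred from the keys it reads ("segmentation", "max_tokens", "chunk_overlap",
    # "separator"); the concrete values below are illustrative only.
    #
    #     {
    #         "segmentation": {
    #             "max_tokens": 500,
    #             "chunk_overlap": 50,
    #             "separator": "\\n"
    #         }
    #     }
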
    def _step_split(
        self,
        text_docs: list[Document],
        splitter: TextSplitter,
        dataset: Dataset,
        dataset_document: DatasetDocument,
        processing_rule: DatasetProcessRule,
    ) -> list[Document]:
        """
        Split the text documents into documents and save them to the document segment.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language,
        )

        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            },
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
            },
        )

        return documents

    def _split_to_documents(
        self,
        text_docs: list[Document],
        splitter: TextSplitter,
        processing_rule: DatasetProcessRule,
        tenant_id: str,
        document_form: str,
        document_language: str,
    ) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata["doc_id"] = doc_id
                    document_node.metadata["doc_hash"] = hash
                    # delete Splitter character
                    page_content = document_node.page_content
                    document_node.page_content = remove_leading_symbols(page_content)

                    if document_node.page_content:
                        split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document
        if document_form == "qa_model":
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i : i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(
                        target=self.format_qa_document,
                        kwargs={
                            "flask_app": current_app._get_current_object(),
                            "tenant_id": tenant_id,
                            "document_node": doc,
                            "all_qa_documents": all_qa_documents,
                            "document_language": document_language,
                        },
                    )
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents

    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(
                        page_content=result["question"], metadata=document_node.metadata.model_copy()
                    )
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result["question"])
                    qa_document.metadata["answer"] = result["answer"]
                    qa_document.metadata["doc_id"] = doc_id
                    qa_document.metadata["doc_hash"] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception("Failed to format qa document")

            all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(
        self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
    ) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)
                document.metadata["doc_id"] = doc_id
                document.metadata["doc_hash"] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    @staticmethod
    def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
        document_text = CleanProcessor.clean(text, {"rules": rules})

        return document_text

    @staticmethod
    def format_split_text(text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a]

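    # Example of format_split_text on typical Q/A-style LLM output (the input text
    # below is illustrative, not taken from a real model response):
    #
    #     format_split_text("Q1: What is indexed?\nA1: Document segments.\nQ2: How?\nA2: Via the splitter.")
    #     # -> [{"question": "What is indexed?", "answer": "Document segments."},
    #     #     {"question": "How?", "answer": "Via the splitter."}]
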
    def _load(
        self,
        index_processor: BaseIndexProcessor,
        dataset: Dataset,
        dataset_document: DatasetDocument,
        documents: list[Document],
    ) -> None:
        """
        insert index and update document/segment status to completed
        """
        embedding_model_instance = None
        if dataset.indexing_technique == "high_quality":
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 10

        # create keyword index
        create_keyword_thread = threading.Thread(
            target=self._process_keyword_index,
            args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
        )
        create_keyword_thread.start()

        if dataset.indexing_technique == "high_quality":
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                futures = []
                for i in range(0, len(documents), chunk_size):
                    chunk_documents = documents[i : i + chunk_size]
                    futures.append(
                        executor.submit(
                            self._process_chunk,
                            current_app._get_current_object(),
                            index_processor,
                            chunk_documents,
                            dataset,
                            dataset_document,
                            embedding_model_instance,
                        )
                    )

                for future in futures:
                    tokens += future.result()

        create_keyword_thread.join()
        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
                DatasetDocument.error: None,
            },
        )

    @staticmethod
    def _process_keyword_index(flask_app, dataset_id, document_id, documents):
        with flask_app.app_context():
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")
            keyword = Keyword(dataset)
            keyword.create(documents)
            if dataset.indexing_technique != "high_quality":
                document_ids = [document.metadata["doc_id"] for document in documents]
                db.session.query(DocumentSegment).filter(
                    DocumentSegment.document_id == document_id,
                    DocumentSegment.dataset_id == dataset_id,
                    DocumentSegment.index_node_id.in_(document_ids),
                    DocumentSegment.status == "indexing",
                ).update(
                    {
                        DocumentSegment.status: "completed",
                        DocumentSegment.enabled: True,
                        DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                    }
                )

                db.session.commit()

    def _process_chunk(
        self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
    ):
        with flask_app.app_context():
            # check document is paused
            self._check_document_paused_status(dataset_document.id)

            tokens = 0
            if embedding_model_instance:
                tokens += sum(
                    embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
                    for document in chunk_documents
                )

            # load index
            index_processor.load(dataset, chunk_documents, with_keywords=False)

            document_ids = [document.metadata["doc_id"] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.dataset_id == dataset.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing",
            ).update(
                {
                    DocumentSegment.status: "completed",
                    DocumentSegment.enabled: True,
                    DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                }
            )

            db.session.commit()

            return tokens

    @staticmethod
    def _check_document_paused_status(document_id: str):
        indexing_cache_key = "document_{}_is_paused".format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedError()

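    # Pause signalling sketch: the runner only reads this flag; the assumption here is
    # that it is set elsewhere (e.g. by an API handler) with something along the lines of:
    #
    #     redis_client.setnx("document_{}_is_paused".format(document_id), "True")
    #
    # Any truthy value makes the next _check_document_paused_status call raise
    # DocumentIsPausedError, which aborts the current chunk.
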
    @staticmethod
    def _update_document_index_status(
        document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None
    ) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedError()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedError()

        update_params = {DatasetDocument.indexing_status: after_indexing_status}

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    @staticmethod
    def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segment by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)

        db.session.commit()

    @staticmethod
    def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
        """
        Batch add segments index processing
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                },
            )
            documents.append(document)

        # save vector index
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, documents)

    def _transform(
        self,
        index_processor: BaseIndexProcessor,
        dataset: Dataset,
        text_docs: list[Document],
        doc_language: str,
        process_rule: dict,
    ) -> list[Document]:
        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == "high_quality":
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        documents = index_processor.transform(
            text_docs,
            embedding_model_instance=embedding_model_instance,
            process_rule=process_rule,
            tenant_id=dataset.tenant_id,
            doc_language=doc_language,
        )

        return documents

    def _load_segments(self, dataset, dataset_document, documents):
        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            },
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
            },
        )


class DocumentIsPausedError(Exception):
    pass


class DocumentIsDeletedPausedError(Exception):
    pass
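

# Usage sketch (illustrative, not part of the original module): how an async worker
# might drive the runner, assuming an active Flask app context and an existing
# DatasetDocument row identified by `document_id` (a hypothetical variable here).
# The entry point mirrors the document's indexing_status handled by the methods above.
#
#     dataset_document = DatasetDocument.query.filter_by(id=document_id).first()
#     if dataset_document:
#         runner = IndexingRunner()
#         if dataset_document.indexing_status == "splitting":
#             runner.run_in_splitting_status(dataset_document)
#         elif dataset_document.indexing_status == "indexing":
#             runner.run_in_indexing_status(dataset_document)
#         else:
#             runner.run([dataset_document])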