indexing_runner.py

import concurrent.futures
import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, cast

from flask import Flask, current_app
from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError

from configs import dify_config
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.splitter.fixed_text_splitter import (
    EnhanceRecursiveCharacterTextSplitter,
    FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService


class IndexingRunner:
    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

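    # run() drives the full pipeline for each document: extract the source text, transform it
    # into chunked Documents, persist them as DocumentSegments, then load them into the
    # keyword/vector indexes. The run_in_*_status variants resume that pipeline from the
    # corresponding stage.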
    def run(self, dataset_documents: list[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = (
                    db.session.query(DatasetProcessRule)
                    .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                    .first()
                )
                index_type = dataset_document.doc_form
                index_processor = IndexProcessorFactory(index_type).init_index_processor()

                # extract
                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

                # transform
                documents = self._transform(
                    index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
                )
                # save segment
                self._load_segments(dataset, dataset_document, documents)

                # load
                self._load(
                    index_processor=index_processor,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents,
                )
            except DocumentIsPausedError:
                raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = "error"
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.commit()
            except ObjectDeletedError:
                logging.warning("Document deleted, document id: {}".format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = "error"
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")

            # get existing document segments and delete them
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id, document_id=dataset_document.id
            ).all()
            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # get the process rule
            processing_rule = (
                db.session.query(DatasetProcessRule)
                .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                .first()
            )
            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()

            # extract
            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

            # transform
            documents = self._transform(
                index_processor, dataset, text_docs, dataset_document.doc_language, processing_rule.to_dict()
            )
            # save segment
            self._load_segments(dataset, dataset_document, documents)

            # load
            self._load(
                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
            )
        except DocumentIsPausedError:
            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")

            # get existing document segments; only rebuild the ones not yet completed
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id, document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            },
                        )
                        documents.append(document)

            # build index
            # get the process rule
            processing_rule = (
                db.session.query(DatasetProcessRule)
                .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
                .first()
            )

            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            self._load(
                index_processor=index_processor, dataset=dataset, dataset_document=dataset_document, documents=documents
            )
        except DocumentIsPausedError:
            raise DocumentIsPausedError("Document paused, document id: {}".format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = "error"
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()

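    # indexing_estimate runs extraction and splitting without persisting anything: it returns
    # the estimated segment count plus up to five preview chunks, and for the "qa_model" form it
    # also generates a QA preview from the first chunk (the segment estimate is multiplied by 20
    # in that case).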
    def indexing_estimate(
        self,
        tenant_id: str,
        extract_settings: list[ExtractSetting],
        tmp_processing_rule: dict,
        doc_form: Optional[str] = None,
        doc_language: str = "English",
        dataset_id: Optional[str] = None,
        indexing_technique: str = "economy",
    ) -> dict:
        """
        Estimate the indexing for the document.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(extract_settings)
            batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("Dataset not found.")
            if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model,
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == "high_quality":
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        preview_texts = []
        total_segments = 0
        index_type = doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        all_text_docs = []
        for extract_setting in extract_settings:
            # extract
            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
            all_text_docs.extend(text_docs)

            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
            )

            # get splitter
            splitter = self._get_splitter(processing_rule, embedding_model_instance)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs, splitter=splitter, processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)

        if doc_form and doc_form == "qa_model":
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(
                    current_user.current_tenant_id, preview_texts[0], doc_language
                )
                document_qa_list = self.format_split_text(response)
                return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts}

        return {"total_segments": total_segments, "preview": preview_texts}

    def _extract(
        self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
    ) -> list[Document]:
        # load file
        if dataset_document.data_source_type not in {"upload_file", "notion_import", "website_crawl"}:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == "upload_file":
            if not data_source_info or "upload_file_id" not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = (
                db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
            )

            if file_detail:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file", upload_file=file_detail, document_model=dataset_document.doc_form
                )
                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
        elif dataset_document.data_source_type == "notion_import":
            if (
                not data_source_info
                or "notion_workspace_id" not in data_source_info
                or "notion_page_id" not in data_source_info
            ):
                raise ValueError("no notion import info found")
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": data_source_info["notion_workspace_id"],
                    "notion_obj_id": data_source_info["notion_page_id"],
                    "notion_page_type": data_source_info["type"],
                    "document": dataset_document,
                    "tenant_id": dataset_document.tenant_id,
                },
                document_model=dataset_document.doc_form,
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])
        elif dataset_document.data_source_type == "website_crawl":
            if (
                not data_source_info
                or "provider" not in data_source_info
                or "url" not in data_source_info
                or "job_id" not in data_source_info
            ):
                raise ValueError("no website import info found")
            extract_setting = ExtractSetting(
                datasource_type="website_crawl",
                website_info={
                    "provider": data_source_info["provider"],
                    "job_id": data_source_info["job_id"],
                    "tenant_id": dataset_document.tenant_id,
                    "url": data_source_info["url"],
                    "mode": data_source_info["mode"],
                    "only_main_content": data_source_info["only_main_content"],
                },
                document_model=dataset_document.doc_form,
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"])

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
            },
        )

        # attach the document and dataset ids to the extracted docs
        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
            text_doc.metadata["document_id"] = dataset_document.id
            text_doc.metadata["dataset_id"] = dataset_document.dataset_id

        return text_docs

    @staticmethod
    def filter_string(text):
        text = re.sub(r"<\|", "<", text)
        text = re.sub(r"\|>", ">", text)
        text = re.sub(r"[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]", "", text)
        # Unicode U+FFFE
        text = re.sub("\ufffe", "", text)
        return text

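    # Splitter selection: "custom" rules use the caller-supplied max_tokens/separator/chunk_overlap
    # (bounds-checked against INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH), while any other mode falls
    # back to DatasetProcessRule.AUTOMATIC_RULES; both variants are created via from_encoder with
    # the optional embedding model instance.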
    @staticmethod
    def _get_splitter(
        processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance]
    ) -> TextSplitter:
        """
        Get the text splitter according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # The user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]

            max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
                raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace("\\n", "\n")

            if segmentation.get("chunk_overlap"):
                chunk_overlap = segmentation["chunk_overlap"]
            else:
                chunk_overlap = 0

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=chunk_overlap,
                fixed_separator=separator,
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance,
            )
        else:
            # Automatic segmentation
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"],
                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"],
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance,
            )

        return character_splitter

    def _step_split(
        self,
        text_docs: list[Document],
        splitter: TextSplitter,
        dataset: Dataset,
        dataset_document: DatasetDocument,
        processing_rule: DatasetProcessRule,
    ) -> list[Document]:
        """
        Split the text documents into documents and save them to the document segment.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language,
        )

        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            },
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
            },
        )

        return documents

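    # For the "qa_model" form, raw chunks are not indexed directly: each chunk is sent to the LLM
    # (in batches of up to 10 concurrent threads) to generate question/answer pairs, and the
    # generated questions become the indexed documents, with the answer kept in metadata["answer"].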
    def _split_to_documents(
        self,
        text_docs: list[Document],
        splitter: TextSplitter,
        processing_rule: DatasetProcessRule,
        tenant_id: str,
        document_form: str,
        document_language: str,
    ) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata["doc_id"] = doc_id
                    document_node.metadata["doc_hash"] = hash
                    # drop a leading separator character left over from splitting
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    document_node.page_content = page_content

                    if document_node.page_content:
                        split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document
        if document_form == "qa_model":
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i : i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(
                        target=self.format_qa_document,
                        kwargs={
                            "flask_app": current_app._get_current_object(),
                            "tenant_id": tenant_id,
                            "document_node": doc,
                            "all_qa_documents": all_qa_documents,
                            "document_language": document_language,
                        },
                    )
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents

    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(
                        page_content=result["question"], metadata=document_node.metadata.model_copy()
                    )
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result["question"])
                    qa_document.metadata["answer"] = result["answer"]
                    qa_document.metadata["doc_id"] = doc_id
                    qa_document.metadata["doc_hash"] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)

            all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(
        self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
    ) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)
                document.metadata["doc_id"] = doc_id
                document.metadata["doc_hash"] = hash
                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    @staticmethod
    def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}
        document_text = CleanProcessor.clean(text, {"rules": rules})

        return document_text

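    # format_split_text parses the QA generator's LLM response, expected to look roughly like
    # "Q1: ...\nA1: ...\nQ2: ...\nA2: ...", into a list of {"question", "answer"} dicts.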
    @staticmethod
    def format_split_text(text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a]

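    # _load builds the indexes concurrently: the keyword index is created in a dedicated
    # background thread, while for "high_quality" datasets the documents are embedded and loaded
    # into the vector index in chunks of 10 via a ThreadPoolExecutor; token counts reported by
    # the chunk workers are summed for the document statistics. For other datasets only the
    # keyword thread runs, and it also marks the segments completed.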
    def _load(
        self,
        index_processor: BaseIndexProcessor,
        dataset: Dataset,
        dataset_document: DatasetDocument,
        documents: list[Document],
    ) -> None:
        """
        insert index and update document/segment status to completed
        """
        embedding_model_instance = None
        if dataset.indexing_technique == "high_quality":
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 10

        # create keyword index
        create_keyword_thread = threading.Thread(
            target=self._process_keyword_index,
            args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents),
        )
        create_keyword_thread.start()

        if dataset.indexing_technique == "high_quality":
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                futures = []
                for i in range(0, len(documents), chunk_size):
                    chunk_documents = documents[i : i + chunk_size]
                    futures.append(
                        executor.submit(
                            self._process_chunk,
                            current_app._get_current_object(),
                            index_processor,
                            chunk_documents,
                            dataset,
                            dataset_document,
                            embedding_model_instance,
                        )
                    )

                for future in futures:
                    tokens += future.result()

        create_keyword_thread.join()
        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
                DatasetDocument.error: None,
            },
        )

    @staticmethod
    def _process_keyword_index(flask_app, dataset_id, document_id, documents):
        with flask_app.app_context():
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")
            keyword = Keyword(dataset)
            keyword.create(documents)
            if dataset.indexing_technique != "high_quality":
                document_ids = [document.metadata["doc_id"] for document in documents]
                db.session.query(DocumentSegment).filter(
                    DocumentSegment.document_id == document_id,
                    DocumentSegment.dataset_id == dataset_id,
                    DocumentSegment.index_node_id.in_(document_ids),
                    DocumentSegment.status == "indexing",
                ).update(
                    {
                        DocumentSegment.status: "completed",
                        DocumentSegment.enabled: True,
                        DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                    }
                )
                db.session.commit()

    def _process_chunk(
        self, flask_app, index_processor, chunk_documents, dataset, dataset_document, embedding_model_instance
    ):
        with flask_app.app_context():
            # check document is paused
            self._check_document_paused_status(dataset_document.id)

            tokens = 0
            if embedding_model_instance:
                tokens += sum(
                    embedding_model_instance.get_text_embedding_num_tokens([document.page_content])
                    for document in chunk_documents
                )

            # load index
            index_processor.load(dataset, chunk_documents, with_keywords=False)

            document_ids = [document.metadata["doc_id"] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.dataset_id == dataset.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing",
            ).update(
                {
                    DocumentSegment.status: "completed",
                    DocumentSegment.enabled: True,
                    DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                }
            )

            db.session.commit()

            return tokens

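    # Pausing is signalled out-of-band through Redis: a "document_<id>_is_paused" key is checked
    # before each chunk is processed, and its presence aborts indexing via DocumentIsPausedError.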
    @staticmethod
    def _check_document_paused_status(document_id: str):
        indexing_cache_key = "document_{}_is_paused".format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedError()

    @staticmethod
    def _update_document_index_status(
        document_id: str, after_indexing_status: str, extra_update_params: Optional[dict] = None
    ) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedError()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedError()

        update_params = {DatasetDocument.indexing_status: after_indexing_status}

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    @staticmethod
    def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segment by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

    @staticmethod
    def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
        """
        Batch add segments index processing
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                },
            )
            documents.append(document)
        # save vector index
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, documents)

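    # An embedding model instance is only needed for "high_quality" datasets; when the dataset
    # does not pin a provider, the tenant's default text-embedding model is used instead.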
    def _transform(
        self,
        index_processor: BaseIndexProcessor,
        dataset: Dataset,
        text_docs: list[Document],
        doc_language: str,
        process_rule: dict,
    ) -> list[Document]:
        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == "high_quality":
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        documents = index_processor.transform(
            text_docs,
            embedding_model_instance=embedding_model_instance,
            process_rule=process_rule,
            tenant_id=dataset.tenant_id,
            doc_language=doc_language,
        )

        return documents

    def _load_segments(self, dataset, dataset_document, documents):
        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            },
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
            },
        )


class DocumentIsPausedError(Exception):
    pass


class DocumentIsDeletedPausedError(Exception):
    pass