indexing_runner.py

import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, cast

from flask import Flask, current_app
from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError

from core.docstore.dataset_docstore import DatasetDocumentStore
from core.errors.error import ProviderTokenNotInitError
from core.generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.splitter.fixed_text_splitter import (
    EnhanceRecursiveCharacterTextSplitter,
    FixedRecursiveCharacterTextSplitter,
)
from core.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService
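
# IndexingRunner drives the dataset document indexing pipeline: extract the source
# (an uploaded file or a Notion import), transform it into segment documents, persist
# the segments, and load them into the configured index, updating document and segment
# status at each stage.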


class IndexingRunner:

    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()

    def run(self, dataset_documents: list[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()
                index_type = dataset_document.doc_form
                index_processor = IndexProcessorFactory(index_type).init_index_processor()

                # extract
                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

                # transform
                documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                            processing_rule.to_dict())
                # save segment
                self._load_segments(dataset, dataset_document, documents)

                # load
                self._load(
                    index_processor=index_processor,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()
            except ObjectDeletedError:
                logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.utcnow()
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get exist document_segment list and delete
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()
            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()

            # extract
            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

            # transform
            documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                        processing_rule.to_dict())
            # save segment
            self._load_segments(dataset, dataset_document, documents)

            # load
            self._load(
                index_processor=index_processor,
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get exist document_segment list and delete
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            self._load(
                index_processor=index_processor,
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.utcnow()
            db.session.commit()
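
    # NOTE: for the 'qa_model' document form, indexing_estimate below returns heuristic
    # figures (total_segments * 20 segments and total_segments * 2000 tokens) and only
    # runs the QA generator on the first preview text, rather than estimating every segment.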

    def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
                          doc_form: str = None, doc_language: str = 'English', dataset_id: str = None,
                          indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing for the document.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(extract_settings)
            batch_upload_limit = int(current_app.config['BATCH_UPLOAD_LIMIT'])
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == 'high_quality':
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        tokens = 0
        preview_texts = []
        total_segments = 0
        total_price = 0
        currency = 'USD'
        index_type = doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        all_text_docs = []
        for extract_setting in extract_settings:
            # extract
            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
            all_text_docs.extend(text_docs)

            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # get splitter
            splitter = self._get_splitter(processing_rule, embedding_model_instance)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)
                if indexing_technique == 'high_quality' or embedding_model_instance:
                    embedding_model_type_instance = embedding_model_instance.model_type_instance
                    embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)
                    tokens += embedding_model_type_instance.get_num_tokens(
                        model=embedding_model_instance.model,
                        credentials=embedding_model_instance.credentials,
                        texts=[self.filter_string(document.page_content)]
                    )

        if doc_form and doc_form == 'qa_model':
            model_instance = self.model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM
            )
            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)
                price_info = model_type_instance.get_price(
                    model=model_instance.model,
                    credentials=model_instance.credentials,
                    price_type=PriceType.INPUT,
                    tokens=total_segments * 2000,
                )
                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(price_info.total_amount),
                    "currency": price_info.currency,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        if embedding_model_instance:
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
            embedding_price_info = embedding_model_type_instance.get_price(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                price_type=PriceType.INPUT,
                tokens=tokens
            )
            total_price = '{:f}'.format(embedding_price_info.total_amount)
            currency = embedding_price_info.currency

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": total_price,
            "currency": currency,
            "preview": preview_texts
        }

    def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
            -> list[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            if file_detail:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file",
                    upload_file=file_detail,
                    document_model=dataset_document.doc_form
                )
                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
        elif dataset_document.data_source_type == 'notion_import':
            if (not data_source_info or 'notion_workspace_id' not in data_source_info
                    or 'notion_page_id' not in data_source_info):
                raise ValueError("no notion import info found")
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": data_source_info['notion_workspace_id'],
                    "notion_obj_id": data_source_info['notion_page_id'],
                    "notion_page_type": data_source_info['type'],
                    "document": dataset_document,
                    "tenant_id": dataset_document.tenant_id
                },
                document_model=dataset_document.doc_form
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum([len(text_doc.page_content) for text_doc in text_docs]),
                DatasetDocument.parsing_completed_at: datetime.datetime.utcnow()
            }
        )

        # replace doc id to document model id
        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs
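
    # filter_string is applied to segment text before token counting in indexing_estimate:
    # it rewrites '<|' / '|>' sequences and strips control characters and U+FFFE, presumably
    # to keep special-token-like markers and invalid bytes away from the embedding tokenizer.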

    def filter_string(self, text):
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
        # Unicode U+FFFE
        text = re.sub('\uFFFE', '', text)
        return text
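
    # A custom processing rule is expected to carry a JSON-encoded "rules" payload shaped
    # roughly like the following (values are illustrative, not taken from this file):
    #
    #   {
    #       "segmentation": {
    #           "separator": "\\n",
    #           "max_tokens": 500,
    #           "chunk_overlap": 50
    #       }
    #   }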

    def _get_splitter(self, processing_rule: DatasetProcessRule,
                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
        """
        Get the NodeParser object according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # The user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > 1000:
                raise ValueError("Custom segment length should be between 50 and 1000.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            if 'chunk_overlap' in segmentation and segmentation['chunk_overlap']:
                chunk_overlap = segmentation['chunk_overlap']
            else:
                chunk_overlap = 0

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=chunk_overlap,
                fixed_separator=separator,
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
            )
        else:
            # Automatic segmentation
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
                separators=["\n\n", "。", ".", " ", ""],
                embedding_model_instance=embedding_model_instance
            )

        return character_splitter

    def _step_split(self, text_docs: list[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> list[Document]:
        """
        Split the text documents into documents and save them to the document segment.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language
        )

        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )

        return documents

    def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str,
                            document_form: str, document_language: str) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata['doc_id'] = doc_id
                    document_node.metadata['doc_hash'] = hash
                    # delete splitter character left at the start of the chunk
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    document_node.page_content = page_content

                    if document_node.page_content:
                        split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document
        if document_form == 'qa_model':
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
                        'document_language': document_language})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents

        return all_documents
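
    # format_qa_document runs on worker threads (batches of 10 above); each thread appends
    # its generated question/answer documents to the shared all_qa_documents list, and the
    # caller joins the whole batch before starting the next one.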

    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(page_content=result['question'], metadata=document_node.metadata.copy())
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result['question'])
                    qa_document.metadata['answer'] = result['answer']
                    qa_document.metadata['doc_id'] = doc_id
                    qa_document.metadata['doc_hash'] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)

            all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])

            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)

                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    def _document_clean(self, text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text
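
    # format_split_text parses the numbered Q/A transcript produced by the QA generator, e.g.
    #
    #   Q1: What is the refund policy?
    #   A1: Refunds are available within 30 days.
    #   Q2: ...
    #
    # (the example content above is illustrative, not taken from this file).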

    def format_split_text(self, text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [
            {
                "question": q,
                "answer": re.sub(r"\n\s*", "\n", a.strip())
            }
            for q, a in matches if q and a
        ]

    def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
              dataset_document: DatasetDocument, documents: list[Document]) -> None:
        """
        insert index and update document/segment status to completed
        """
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 100

        embedding_model_type_instance = None
        if embedding_model_instance:
            embedding_model_type_instance = embedding_model_instance.model_type_instance
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_type_instance)

        for i in range(0, len(documents), chunk_size):
            # check document is paused
            self._check_document_paused_status(dataset_document.id)
            chunk_documents = documents[i:i + chunk_size]
            if dataset.indexing_technique == 'high_quality' or embedding_model_type_instance:
                tokens += sum(
                    embedding_model_type_instance.get_num_tokens(
                        embedding_model_instance.model,
                        embedding_model_instance.credentials,
                        [document.page_content]
                    )
                    for document in chunk_documents
                )

            # load index
            index_processor.load(dataset, chunk_documents)

            db.session.add(dataset)
            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.utcnow()
            })

            db.session.commit()

        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.utcnow(),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
            }
        )
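
    # Pausing is signalled out of band: the Redis key 'document_{id}_is_paused' is checked
    # before every chunk in _load above, and DocumentIsPausedException aborts the run
    # whenever it is set.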

    def _check_document_paused_status(self, document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    def _update_document_index_status(self, document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    def _update_segments_by_document(self, dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segment by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
        db.session.commit()

    def batch_add_segments(self, segments: list[DocumentSegment], dataset: Dataset):
        """
        Batch add segments index processing
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            documents.append(document)

        # save vector index
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, documents)

    def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
                   text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]:
        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
                                              process_rule=process_rule, tenant_id=dataset.tenant_id,
                                              doc_language=doc_language)

        return documents

    def _load_segments(self, dataset, dataset_document, documents):
        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.utcnow()
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.utcnow()
            }
        )


class DocumentIsPausedException(Exception):
    pass


class DocumentIsDeletedPausedException(Exception):
    pass
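

# A minimal usage sketch (assumes a Flask app context and persisted DatasetDocument rows;
# `some_dataset_id` is a placeholder, and the task/queue wiring that normally drives the
# runner is not shown in this file):
#
#   documents = DatasetDocument.query.filter_by(dataset_id=some_dataset_id).all()
#   IndexingRunner().run(documents)
#
# Documents interrupted mid-way can be resumed with run_in_splitting_status() or
# run_in_indexing_status(), depending on the indexing status they stopped at.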