indexing_runner.py

import concurrent.futures
import datetime
import json
import logging
import re
import threading
import time
import uuid
from typing import Optional, cast

from flask import Flask, current_app
from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError

from configs import dify_config
from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from core.rag.splitter.fixed_text_splitter import (
    EnhanceRecursiveCharacterTextSplitter,
    FixedRecursiveCharacterTextSplitter,
)
from core.rag.splitter.text_splitter import TextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import UploadFile
from services.feature_service import FeatureService


class IndexingRunner:

    def __init__(self):
        self.storage = storage
        self.model_manager = ModelManager()
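
    # High-level entry point: for each pending document, the runner extracts raw text,
    # transforms it into indexable chunks, persists the chunks as document segments,
    # and finally loads them into the keyword/vector index (see _extract, _transform,
    # _load_segments and _load below).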

    def run(self, dataset_documents: list[DatasetDocument]):
        """Run the indexing process."""
        for dataset_document in dataset_documents:
            try:
                # get dataset
                dataset = Dataset.query.filter_by(
                    id=dataset_document.dataset_id
                ).first()

                if not dataset:
                    raise ValueError("no dataset found")

                # get the process rule
                processing_rule = db.session.query(DatasetProcessRule). \
                    filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                    first()
                index_type = dataset_document.doc_form
                index_processor = IndexProcessorFactory(index_type).init_index_processor()

                # extract
                text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

                # transform
                documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                            processing_rule.to_dict())

                # save segment
                self._load_segments(dataset, dataset_document, documents)

                # load
                self._load(
                    index_processor=index_processor,
                    dataset=dataset,
                    dataset_document=dataset_document,
                    documents=documents
                )
            except DocumentIsPausedException:
                raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
            except ProviderTokenNotInitError as e:
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e.description)
                dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.commit()
            except ObjectDeletedError:
                logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
            except Exception as e:
                logging.exception("consume document failed")
                dataset_document.indexing_status = 'error'
                dataset_document.error = str(e)
                dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                db.session.commit()

    def run_in_splitting_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is splitting."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get existing document_segment list and delete them
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            for document_segment in document_segments:
                db.session.delete(document_segment)
            db.session.commit()

            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()
            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()

            # extract
            text_docs = self._extract(index_processor, dataset_document, processing_rule.to_dict())

            # transform
            documents = self._transform(index_processor, dataset, text_docs, dataset_document.doc_language,
                                        processing_rule.to_dict())

            # save segment
            self._load_segments(dataset, dataset_document, documents)

            # load
            self._load(
                index_processor=index_processor,
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()

    def run_in_indexing_status(self, dataset_document: DatasetDocument):
        """Run the indexing process when the index_status is indexing."""
        try:
            # get dataset
            dataset = Dataset.query.filter_by(
                id=dataset_document.dataset_id
            ).first()

            if not dataset:
                raise ValueError("no dataset found")

            # get existing document_segment list; re-index only the segments not yet completed
            document_segments = DocumentSegment.query.filter_by(
                dataset_id=dataset.id,
                document_id=dataset_document.id
            ).all()

            documents = []
            if document_segments:
                for document_segment in document_segments:
                    # transform segment to node
                    if document_segment.status != "completed":
                        document = Document(
                            page_content=document_segment.content,
                            metadata={
                                "doc_id": document_segment.index_node_id,
                                "doc_hash": document_segment.index_node_hash,
                                "document_id": document_segment.document_id,
                                "dataset_id": document_segment.dataset_id,
                            }
                        )
                        documents.append(document)

            # build index
            # get the process rule
            processing_rule = db.session.query(DatasetProcessRule). \
                filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id). \
                first()

            index_type = dataset_document.doc_form
            index_processor = IndexProcessorFactory(index_type).init_index_processor()
            self._load(
                index_processor=index_processor,
                dataset=dataset,
                dataset_document=dataset_document,
                documents=documents
            )
        except DocumentIsPausedException:
            raise DocumentIsPausedException('Document paused, document id: {}'.format(dataset_document.id))
        except ProviderTokenNotInitError as e:
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e.description)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
        except Exception as e:
            logging.exception("consume document failed")
            dataset_document.indexing_status = 'error'
            dataset_document.error = str(e)
            dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            db.session.commit()
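
    # Cost/size estimation: indexing_estimate() runs the same extract/split steps as the
    # real pipeline, but only counts segments and embedding tokens and prices them via
    # the model's price info; nothing is written to the database or the index. The
    # qa_model branch scales the totals by fixed factors (20x segments, 2000 tokens per
    # segment), presumably as a rough allowance for generated Q&A pairs.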

    def indexing_estimate(self, tenant_id: str, extract_settings: list[ExtractSetting], tmp_processing_rule: dict,
                          doc_form: Optional[str] = None, doc_language: str = 'English',
                          dataset_id: Optional[str] = None,
                          indexing_technique: str = 'economy') -> dict:
        """
        Estimate the indexing for the document.
        """
        # check document limit
        features = FeatureService.get_features(tenant_id)
        if features.billing.enabled:
            count = len(extract_settings)
            batch_upload_limit = dify_config.BATCH_UPLOAD_LIMIT
            if count > batch_upload_limit:
                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")

        embedding_model_instance = None
        if dataset_id:
            dataset = Dataset.query.filter_by(
                id=dataset_id
            ).first()
            if not dataset:
                raise ValueError('Dataset not found.')
            if dataset.indexing_technique == 'high_quality' or indexing_technique == 'high_quality':
                if dataset.embedding_model_provider:
                    embedding_model_instance = self.model_manager.get_model_instance(
                        tenant_id=tenant_id,
                        provider=dataset.embedding_model_provider,
                        model_type=ModelType.TEXT_EMBEDDING,
                        model=dataset.embedding_model
                    )
                else:
                    embedding_model_instance = self.model_manager.get_default_model_instance(
                        tenant_id=tenant_id,
                        model_type=ModelType.TEXT_EMBEDDING,
                    )
        else:
            if indexing_technique == 'high_quality':
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        tokens = 0
        preview_texts = []
        total_segments = 0
        total_price = 0
        currency = 'USD'
        index_type = doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        all_text_docs = []
        for extract_setting in extract_settings:
            # extract
            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
            all_text_docs.extend(text_docs)

            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"],
                rules=json.dumps(tmp_processing_rule["rules"])
            )

            # get splitter
            splitter = self._get_splitter(processing_rule, embedding_model_instance)

            # split to documents
            documents = self._split_to_documents_for_estimate(
                text_docs=text_docs,
                splitter=splitter,
                processing_rule=processing_rule
            )

            total_segments += len(documents)
            for document in documents:
                if len(preview_texts) < 5:
                    preview_texts.append(document.page_content)
                if indexing_technique == 'high_quality' or embedding_model_instance:
                    tokens += embedding_model_instance.get_text_embedding_num_tokens(
                        texts=[self.filter_string(document.page_content)]
                    )

        if doc_form and doc_form == 'qa_model':
            model_instance = self.model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM
            )

            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)
            if len(preview_texts) > 0:
                # qa model document
                response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
                                                             doc_language)
                document_qa_list = self.format_split_text(response)

                price_info = model_type_instance.get_price(
                    model=model_instance.model,
                    credentials=model_instance.credentials,
                    price_type=PriceType.INPUT,
                    tokens=total_segments * 2000,
                )

                return {
                    "total_segments": total_segments * 20,
                    "tokens": total_segments * 2000,
                    "total_price": '{:f}'.format(price_info.total_amount),
                    "currency": price_info.currency,
                    "qa_preview": document_qa_list,
                    "preview": preview_texts
                }

        if embedding_model_instance:
            embedding_model_type_instance = cast(TextEmbeddingModel, embedding_model_instance.model_type_instance)
            embedding_price_info = embedding_model_type_instance.get_price(
                model=embedding_model_instance.model,
                credentials=embedding_model_instance.credentials,
                price_type=PriceType.INPUT,
                tokens=tokens
            )
            total_price = '{:f}'.format(embedding_price_info.total_amount)
            currency = embedding_price_info.currency

        return {
            "total_segments": total_segments,
            "tokens": tokens,
            "total_price": total_price,
            "currency": currency,
            "preview": preview_texts
        }
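
    # Extraction dispatches on data_source_type ("upload_file", "notion_import",
    # "website_crawl") and builds the matching ExtractSetting before delegating to the
    # index processor; it also moves the document to the "splitting" status and records
    # the extracted word count and parsing completion time.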

    def _extract(self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict) \
            -> list[Document]:
        # load file
        if dataset_document.data_source_type not in ["upload_file", "notion_import", "website_crawl"]:
            return []

        data_source_info = dataset_document.data_source_info_dict
        text_docs = []
        if dataset_document.data_source_type == 'upload_file':
            if not data_source_info or 'upload_file_id' not in data_source_info:
                raise ValueError("no upload file found")

            file_detail = db.session.query(UploadFile). \
                filter(UploadFile.id == data_source_info['upload_file_id']). \
                one_or_none()

            if file_detail:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file",
                    upload_file=file_detail,
                    document_model=dataset_document.doc_form
                )
                text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
        elif dataset_document.data_source_type == 'notion_import':
            if (not data_source_info or 'notion_workspace_id' not in data_source_info
                    or 'notion_page_id' not in data_source_info):
                raise ValueError("no notion import info found")
            extract_setting = ExtractSetting(
                datasource_type="notion_import",
                notion_info={
                    "notion_workspace_id": data_source_info['notion_workspace_id'],
                    "notion_obj_id": data_source_info['notion_page_id'],
                    "notion_page_type": data_source_info['type'],
                    "document": dataset_document,
                    "tenant_id": dataset_document.tenant_id
                },
                document_model=dataset_document.doc_form
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])
        elif dataset_document.data_source_type == 'website_crawl':
            if (not data_source_info or 'provider' not in data_source_info
                    or 'url' not in data_source_info or 'job_id' not in data_source_info):
                raise ValueError("no website import info found")
            extract_setting = ExtractSetting(
                datasource_type="website_crawl",
                website_info={
                    "provider": data_source_info['provider'],
                    "job_id": data_source_info['job_id'],
                    "tenant_id": dataset_document.tenant_id,
                    "url": data_source_info['url'],
                    "mode": data_source_info['mode'],
                    "only_main_content": data_source_info['only_main_content']
                },
                document_model=dataset_document.doc_form
            )
            text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule['mode'])

        # update document status to splitting
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="splitting",
            extra_update_params={
                DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs),
                DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            }
        )

        # attach the document and dataset ids to each extracted doc's metadata
        text_docs = cast(list[Document], text_docs)
        for text_doc in text_docs:
            text_doc.metadata['document_id'] = dataset_document.id
            text_doc.metadata['dataset_id'] = dataset_document.dataset_id

        return text_docs

    @staticmethod
    def filter_string(text):
        text = re.sub(r'<\|', '<', text)
        text = re.sub(r'\|>', '>', text)
        text = re.sub(r'[\x00-\x08\x0B\x0C\x0E-\x1F\x7F\xEF\xBF\xBE]', '', text)
        # Unicode U+FFFE
        text = re.sub('\uFFFE', '', text)
        return text
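
    # Splitter selection: "custom" rules use FixedRecursiveCharacterTextSplitter with the
    # user-supplied max_tokens/chunk_overlap/separator, while any other mode falls back
    # to EnhanceRecursiveCharacterTextSplitter driven by DatasetProcessRule.AUTOMATIC_RULES.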

    @staticmethod
    def _get_splitter(processing_rule: DatasetProcessRule,
                      embedding_model_instance: Optional[ModelInstance]) -> TextSplitter:
        """
        Get the TextSplitter object according to the processing rule.
        """
        if processing_rule.mode == "custom":
            # the user-defined segmentation rule
            rules = json.loads(processing_rule.rules)
            segmentation = rules["segmentation"]
            max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
            if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length:
                raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")

            separator = segmentation["separator"]
            if separator:
                separator = separator.replace('\\n', '\n')

            if segmentation.get('chunk_overlap'):
                chunk_overlap = segmentation['chunk_overlap']
            else:
                chunk_overlap = 0

            character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=segmentation["max_tokens"],
                chunk_overlap=chunk_overlap,
                fixed_separator=separator,
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance
            )
        else:
            # automatic segmentation
            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
                chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['chunk_overlap'],
                separators=["\n\n", "。", ". ", " ", ""],
                embedding_model_instance=embedding_model_instance
            )

        return character_splitter

    def _step_split(self, text_docs: list[Document], splitter: TextSplitter,
                    dataset: Dataset, dataset_document: DatasetDocument, processing_rule: DatasetProcessRule) \
            -> list[Document]:
        """
        Split the text documents into documents and save them to the document segment.
        """
        documents = self._split_to_documents(
            text_docs=text_docs,
            splitter=splitter,
            processing_rule=processing_rule,
            tenant_id=dataset.tenant_id,
            document_form=dataset_document.doc_form,
            document_language=dataset_document.doc_language
        )

        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            }
        )

        return documents
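
    # Splitting assigns each chunk a UUID doc_id and a content hash; in qa_model form the
    # chunks are additionally expanded into question/answer documents by LLMGenerator,
    # ten chunks at a time, each chunk processed in its own thread (see format_qa_document).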

    def _split_to_documents(self, text_docs: list[Document], splitter: TextSplitter,
                            processing_rule: DatasetProcessRule, tenant_id: str,
                            document_form: str, document_language: str) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        all_qa_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])
            split_documents = []
            for document_node in documents:
                if document_node.page_content.strip():
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(document_node.page_content)
                    document_node.metadata['doc_id'] = doc_id
                    document_node.metadata['doc_hash'] = hash
                    # remove leading splitter character
                    page_content = document_node.page_content
                    if page_content.startswith(".") or page_content.startswith("。"):
                        page_content = page_content[1:]
                    document_node.page_content = page_content

                    if document_node.page_content:
                        split_documents.append(document_node)
            all_documents.extend(split_documents)

        # processing qa document
        if document_form == 'qa_model':
            for i in range(0, len(all_documents), 10):
                threads = []
                sub_documents = all_documents[i:i + 10]
                for doc in sub_documents:
                    document_format_thread = threading.Thread(target=self.format_qa_document, kwargs={
                        'flask_app': current_app._get_current_object(),
                        'tenant_id': tenant_id, 'document_node': doc, 'all_qa_documents': all_qa_documents,
                        'document_language': document_language})
                    threads.append(document_format_thread)
                    document_format_thread.start()
                for thread in threads:
                    thread.join()
            return all_qa_documents
        return all_documents

    def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
        format_documents = []
        if document_node.page_content is None or not document_node.page_content.strip():
            return
        with flask_app.app_context():
            try:
                # qa model document
                response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
                document_qa_list = self.format_split_text(response)
                qa_documents = []
                for result in document_qa_list:
                    qa_document = Document(page_content=result['question'],
                                           metadata=document_node.metadata.model_copy())
                    doc_id = str(uuid.uuid4())
                    hash = helper.generate_text_hash(result['question'])
                    qa_document.metadata['answer'] = result['answer']
                    qa_document.metadata['doc_id'] = doc_id
                    qa_document.metadata['doc_hash'] = hash
                    qa_documents.append(qa_document)
                format_documents.extend(qa_documents)
            except Exception as e:
                logging.exception(e)

            all_qa_documents.extend(format_documents)

    def _split_to_documents_for_estimate(self, text_docs: list[Document], splitter: TextSplitter,
                                         processing_rule: DatasetProcessRule) -> list[Document]:
        """
        Split the text documents into nodes.
        """
        all_documents = []
        for text_doc in text_docs:
            # document clean
            document_text = self._document_clean(text_doc.page_content, processing_rule)
            text_doc.page_content = document_text

            # parse document to nodes
            documents = splitter.split_documents([text_doc])

            split_documents = []
            for document in documents:
                if document.page_content is None or not document.page_content.strip():
                    continue
                doc_id = str(uuid.uuid4())
                hash = helper.generate_text_hash(document.page_content)
                document.metadata['doc_id'] = doc_id
                document.metadata['doc_hash'] = hash

                split_documents.append(document)

            all_documents.extend(split_documents)

        return all_documents

    @staticmethod
    def _document_clean(text: str, processing_rule: DatasetProcessRule) -> str:
        """
        Clean the document text according to the processing rules.
        """
        if processing_rule.mode == "automatic":
            rules = DatasetProcessRule.AUTOMATIC_RULES
        else:
            rules = json.loads(processing_rule.rules) if processing_rule.rules else {}

        if 'pre_processing_rules' in rules:
            pre_processing_rules = rules["pre_processing_rules"]
            for pre_processing_rule in pre_processing_rules:
                if pre_processing_rule["id"] == "remove_extra_spaces" and pre_processing_rule["enabled"] is True:
                    # Remove extra spaces
                    pattern = r'\n{3,}'
                    text = re.sub(pattern, '\n\n', text)
                    pattern = r'[\t\f\r\x20\u00a0\u1680\u180e\u2000-\u200a\u202f\u205f\u3000]{2,}'
                    text = re.sub(pattern, ' ', text)
                elif pre_processing_rule["id"] == "remove_urls_emails" and pre_processing_rule["enabled"] is True:
                    # Remove email
                    pattern = r'([a-zA-Z0-9_.+-]+@[a-zA-Z0-9-]+\.[a-zA-Z0-9-.]+)'
                    text = re.sub(pattern, '', text)

                    # Remove URL
                    pattern = r'https?://[^\s]+'
                    text = re.sub(pattern, '', text)

        return text

    @staticmethod
    def format_split_text(text):
        regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
        matches = re.findall(regex, text, re.UNICODE)

        return [
            {
                "question": q,
                "answer": re.sub(r"\n\s*", "\n", a.strip())
            }
            for q, a in matches if q and a
        ]
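
    # Index loading: keyword indexing runs in a dedicated thread, while high-quality
    # (vector) indexing fans out over a thread pool in chunks of ten documents; token
    # counts returned by the chunk workers are summed to record the document's
    # embedding cost and indexing latency.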

    def _load(self, index_processor: BaseIndexProcessor, dataset: Dataset,
              dataset_document: DatasetDocument, documents: list[Document]) -> None:
        """
        Insert index and update document/segment status to completed.
        """
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            embedding_model_instance = self.model_manager.get_model_instance(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model
            )

        # chunk nodes by chunk size
        indexing_start_at = time.perf_counter()
        tokens = 0
        chunk_size = 10

        # create keyword index
        create_keyword_thread = threading.Thread(target=self._process_keyword_index,
                                                 args=(current_app._get_current_object(),
                                                       dataset.id, dataset_document.id, documents))
        create_keyword_thread.start()
        if dataset.indexing_technique == 'high_quality':
            with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
                futures = []
                for i in range(0, len(documents), chunk_size):
                    chunk_documents = documents[i:i + chunk_size]
                    futures.append(executor.submit(self._process_chunk, current_app._get_current_object(),
                                                   index_processor, chunk_documents, dataset,
                                                   dataset_document, embedding_model_instance))

                for future in futures:
                    tokens += future.result()

        create_keyword_thread.join()
        indexing_end_at = time.perf_counter()

        # update document status to completed
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="completed",
            extra_update_params={
                DatasetDocument.tokens: tokens,
                DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
                DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at,
                DatasetDocument.error: None,
            }
        )

    @staticmethod
    def _process_keyword_index(flask_app, dataset_id, document_id, documents):
        with flask_app.app_context():
            dataset = Dataset.query.filter_by(id=dataset_id).first()
            if not dataset:
                raise ValueError("no dataset found")
            keyword = Keyword(dataset)
            keyword.create(documents)
            if dataset.indexing_technique != 'high_quality':
                document_ids = [document.metadata['doc_id'] for document in documents]
                db.session.query(DocumentSegment).filter(
                    DocumentSegment.document_id == document_id,
                    DocumentSegment.index_node_id.in_(document_ids),
                    DocumentSegment.status == "indexing"
                ).update({
                    DocumentSegment.status: "completed",
                    DocumentSegment.enabled: True,
                    DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
                })

                db.session.commit()

    def _process_chunk(self, flask_app, index_processor, chunk_documents, dataset, dataset_document,
                       embedding_model_instance):
        with flask_app.app_context():
            # check whether the document is paused
            self._check_document_paused_status(dataset_document.id)

            tokens = 0
            if embedding_model_instance:
                tokens += sum(
                    embedding_model_instance.get_text_embedding_num_tokens(
                        [document.page_content]
                    )
                    for document in chunk_documents
                )

            # load index
            index_processor.load(dataset, chunk_documents, with_keywords=False)

            document_ids = [document.metadata['doc_id'] for document in chunk_documents]
            db.session.query(DocumentSegment).filter(
                DocumentSegment.document_id == dataset_document.id,
                DocumentSegment.index_node_id.in_(document_ids),
                DocumentSegment.status == "indexing"
            ).update({
                DocumentSegment.status: "completed",
                DocumentSegment.enabled: True,
                DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            })

            db.session.commit()

            return tokens
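
    # Pause handling: a "document_{id}_is_paused" flag in Redis is checked before each
    # chunk is indexed, so a user-initiated pause interrupts the run between chunks.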

    @staticmethod
    def _check_document_paused_status(document_id: str):
        indexing_cache_key = 'document_{}_is_paused'.format(document_id)
        result = redis_client.get(indexing_cache_key)
        if result:
            raise DocumentIsPausedException()

    @staticmethod
    def _update_document_index_status(document_id: str, after_indexing_status: str,
                                      extra_update_params: Optional[dict] = None) -> None:
        """
        Update the document indexing status.
        """
        count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
        if count > 0:
            raise DocumentIsPausedException()
        document = DatasetDocument.query.filter_by(id=document_id).first()
        if not document:
            raise DocumentIsDeletedPausedException()

        update_params = {
            DatasetDocument.indexing_status: after_indexing_status
        }

        if extra_update_params:
            update_params.update(extra_update_params)

        DatasetDocument.query.filter_by(id=document_id).update(update_params)
        db.session.commit()

    @staticmethod
    def _update_segments_by_document(dataset_document_id: str, update_params: dict) -> None:
        """
        Update the document segments by document id.
        """
        DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)

        db.session.commit()

    @staticmethod
    def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
        """
        Batch add segments index processing.
        """
        documents = []
        for segment in segments:
            document = Document(
                page_content=segment.content,
                metadata={
                    "doc_id": segment.index_node_id,
                    "doc_hash": segment.index_node_hash,
                    "document_id": segment.document_id,
                    "dataset_id": segment.dataset_id,
                }
            )
            documents.append(document)

        # save vector index
        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.load(dataset, documents)

    def _transform(self, index_processor: BaseIndexProcessor, dataset: Dataset,
                   text_docs: list[Document], doc_language: str, process_rule: dict) -> list[Document]:
        # get embedding model instance
        embedding_model_instance = None
        if dataset.indexing_technique == 'high_quality':
            if dataset.embedding_model_provider:
                embedding_model_instance = self.model_manager.get_model_instance(
                    tenant_id=dataset.tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model
                )
            else:
                embedding_model_instance = self.model_manager.get_default_model_instance(
                    tenant_id=dataset.tenant_id,
                    model_type=ModelType.TEXT_EMBEDDING,
                )

        documents = index_processor.transform(text_docs, embedding_model_instance=embedding_model_instance,
                                              process_rule=process_rule, tenant_id=dataset.tenant_id,
                                              doc_language=doc_language)

        return documents

    def _load_segments(self, dataset, dataset_document, documents):
        # save node to document segment
        doc_store = DatasetDocumentStore(
            dataset=dataset,
            user_id=dataset_document.created_by,
            document_id=dataset_document.id
        )

        # add document segments
        doc_store.add_documents(documents)

        # update document status to indexing
        cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        self._update_document_index_status(
            document_id=dataset_document.id,
            after_indexing_status="indexing",
            extra_update_params={
                DatasetDocument.cleaning_completed_at: cur_time,
                DatasetDocument.splitting_completed_at: cur_time,
            }
        )

        # update segment status to indexing
        self._update_segments_by_document(
            dataset_document_id=dataset_document.id,
            update_params={
                DocumentSegment.status: "indexing",
                DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            }
        )


class DocumentIsPausedException(Exception):
    pass


class DocumentIsDeletedPausedException(Exception):
    pass
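

# A minimal usage sketch (assumes a Flask app context and documents already queued in
# the database; the query and invocation below are illustrative only, the real callers
# are Dify's async indexing tasks rather than this module):
#
#     dataset_documents = DatasetDocument.query.filter_by(indexing_status="waiting").all()
#     IndexingRunner().run(dataset_documents)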