# dataset.py

import base64
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError

from flask import current_app
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB

from extensions.ext_database import db
from extensions.ext_storage import storage
from models import StringUUID
from models.account import Account
from models.model import App, Tag, TagBinding, UploadFile


class Dataset(db.Model):
    __tablename__ = 'datasets'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_pkey'),
        db.Index('dataset_tenant_idx', 'tenant_id'),
        db.Index('retrieval_model_idx', "retrieval_model", postgresql_using='gin')
    )

    INDEXING_TECHNIQUE_LIST = ['high_quality', 'economy', None]

    id = db.Column(StringUUID, server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False,
                         server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False,
                           server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)

    @property
    def dataset_keyword_table(self):
        dataset_keyword_table = db.session.query(DatasetKeywordTable).filter(
            DatasetKeywordTable.dataset_id == self.id).first()
        if dataset_keyword_table:
            return dataset_keyword_table

        return None

    @property
    def index_struct_dict(self):
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def created_by_account(self):
        return Account.query.get(self.created_by)

    @property
    def latest_process_rule(self):
        return DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) \
            .order_by(DatasetProcessRule.created_at.desc()).first()

    @property
    def app_count(self):
        return db.session.query(func.count(AppDatasetJoin.id)).filter(
            AppDatasetJoin.dataset_id == self.id,
            App.id == AppDatasetJoin.app_id
        ).scalar()

    @property
    def document_count(self):
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        return db.session.query(func.count(Document.id)).filter(
            Document.dataset_id == self.id,
            Document.indexing_status == 'completed',
            Document.enabled == True,
            Document.archived == False
        ).scalar()

    @property
    def available_segment_count(self):
        return db.session.query(func.count(DocumentSegment.id)).filter(
            DocumentSegment.dataset_id == self.id,
            DocumentSegment.status == 'completed',
            DocumentSegment.enabled == True
        ).scalar()

    @property
    def word_count(self):
        return Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) \
            .filter(Document.dataset_id == self.id).scalar()

    @property
    def doc_form(self):
        document = db.session.query(Document).filter(
            Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
        return None

    @property
    def retrieval_model_dict(self):
        default_retrieval_model = {
            'search_method': 'semantic_search',
            'reranking_enable': False,
            'reranking_model': {
                'reranking_provider_name': '',
                'reranking_model_name': ''
            },
            'top_k': 2,
            'score_threshold_enabled': False
        }
        return self.retrieval_model if self.retrieval_model else default_retrieval_model

    @property
    def tags(self):
        tags = db.session.query(Tag).join(
            TagBinding,
            Tag.id == TagBinding.tag_id
        ).filter(
            TagBinding.target_id == self.id,
            TagBinding.tenant_id == self.tenant_id,
            Tag.tenant_id == self.tenant_id,
            Tag.type == 'knowledge'
        ).all()

        return tags if tags else []

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f'Vector_index_{normalized_dataset_id}_Node'
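
# Illustrative usage sketch (not part of the original module): how the vector collection
# name is derived from a dataset id; the id below is a made-up example value.
#
#     Dataset.gen_collection_name_by_id('4d5b0a6e-3f2b-4c1e-9f7a-1234567890ab')
#     # -> 'Vector_index_4d5b0a6e_3f2b_4c1e_9f7a_1234567890ab_Node'
#
# retrieval_model_dict falls back to the semantic_search defaults above whenever the
# retrieval_model column is NULL or empty.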


class DatasetProcessRule(db.Model):
    __tablename__ = 'dataset_process_rules'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_process_rule_pkey'),
        db.Index('dataset_process_rule_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False,
                     server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))

    MODES = ['automatic', 'custom']
    PRE_PROCESSING_RULES = ['remove_stopwords', 'remove_extra_spaces', 'remove_urls_emails']
    AUTOMATIC_RULES = {
        'pre_processing_rules': [
            {'id': 'remove_extra_spaces', 'enabled': True},
            {'id': 'remove_urls_emails', 'enabled': False}
        ],
        'segmentation': {
            'delimiter': '\n',
            'max_tokens': 500,
            'chunk_overlap': 50
        }
    }

    def to_dict(self):
        return {
            'id': self.id,
            'dataset_id': self.dataset_id,
            'mode': self.mode,
            'rules': self.rules_dict,
            'created_by': self.created_by,
            'created_at': self.created_at,
        }

    @property
    def rules_dict(self):
        try:
            return json.loads(self.rules) if self.rules else None
        except JSONDecodeError:
            return None
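
# Illustrative sketch (assumption, not original code): an automatic-mode rule row stores
# its rules as a JSON string, so AUTOMATIC_RULES round-trips through rules_dict.
# `dataset` and `account` below are hypothetical objects.
#
#     rule = DatasetProcessRule(dataset_id=dataset.id, mode='automatic',
#                               rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES),
#                               created_by=account.id)
#     rule.rules_dict['segmentation']['max_tokens']  # -> 500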


class Document(db.Model):
    __tablename__ = 'documents'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='document_pkey'),
        db.Index('document_dataset_id_idx', 'dataset_id'),
        db.Index('document_is_paused_idx', 'is_paused'),
        db.Index('document_tenant_idx', 'tenant_id'),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)

    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)

    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)

    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text('false'))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    # basic fields
    indexing_status = db.Column(db.String(
        255), nullable=False, server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False,
                        server_default=db.text('true'))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False,
                         server_default=db.text('false'))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(
        255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ['upload_file', 'notion_import']

    @property
    def display_status(self):
        status = None
        if self.indexing_status == 'waiting':
            status = 'queuing'
        elif self.indexing_status not in ['completed', 'error', 'waiting'] and self.is_paused:
            status = 'paused'
        elif self.indexing_status in ['parsing', 'cleaning', 'splitting', 'indexing']:
            status = 'indexing'
        elif self.indexing_status == 'error':
            status = 'error'
        elif self.indexing_status == 'completed' and not self.archived and self.enabled:
            status = 'available'
        elif self.indexing_status == 'completed' and not self.archived and not self.enabled:
            status = 'disabled'
        elif self.indexing_status == 'completed' and self.archived:
            status = 'archived'
        return status

    @property
    def data_source_info_dict(self):
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}

            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        if self.data_source_info:
            if self.data_source_type == 'upload_file':
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = db.session.query(UploadFile). \
                    filter(UploadFile.id == data_source_info_dict['upload_file_id']). \
                    one_or_none()
                if file_detail:
                    return {
                        'upload_file': {
                            'id': file_detail.id,
                            'name': file_detail.name,
                            'size': file_detail.size,
                            'extension': file_detail.extension,
                            'mime_type': file_detail.mime_type,
                            'created_by': file_detail.created_by,
                            'created_at': file_detail.created_at.timestamp()
                        }
                    }
            elif self.data_source_type == 'notion_import':
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        if self.dataset_process_rule_id:
            return DatasetProcessRule.query.get(self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        return DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) \
            .filter(DocumentSegment.document_id == self.id).scalar()
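
# Reference note (derived from display_status above): a completed document surfaces as
# 'available' when enabled and not archived, 'disabled' when not enabled, and 'archived'
# when archived; 'waiting' maps to 'queuing', a paused in-flight document to 'paused',
# and the parsing/cleaning/splitting/indexing stages all display as 'indexing'.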


class DocumentSegment(db.Model):
    __tablename__ = 'document_segments'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='document_segment_pkey'),
        db.Index('document_segment_dataset_id_idx', 'dataset_id'),
        db.Index('document_segment_document_id_idx', 'document_id'),
        db.Index('document_segment_tenant_dataset_idx', 'dataset_id', 'tenant_id'),
        db.Index('document_segment_tenant_document_idx', 'document_id', 'tenant_id'),
        db.Index('document_segment_dataset_node_idx', 'dataset_id', 'index_node_id'),
        db.Index('document_segment_tenant_idx', 'tenant_id'),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False,
                   server_default=db.text('uuid_generate_v4()'))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)

    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)

    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False,
                        server_default=db.text('true'))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False,
                       server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False,
                           server_default=db.text('CURRENT_TIMESTAMP(0)'))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        return db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == self.document_id,
            DocumentSegment.position == self.position - 1
        ).first()

    @property
    def next_segment(self):
        return db.session.query(DocumentSegment).filter(
            DocumentSegment.document_id == self.document_id,
            DocumentSegment.position == self.position + 1
        ).first()

    def get_sign_content(self):
        pattern = r"/files/([a-f0-9\-]+)/image-preview"
        text = self.content
        match = re.search(pattern, text)

        if match:
            upload_file_id = match.group(1)
            nonce = os.urandom(16).hex()
            timestamp = str(int(time.time()))
            data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
            secret_key = current_app.config['SECRET_KEY'].encode()
            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
            encoded_sign = base64.urlsafe_b64encode(sign).decode()

            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
            replacement = r"\g<0>?{params}".format(params=params)
            text = re.sub(pattern, replacement, text)

        return text
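
# get_sign_content signs embedded image-preview links: it computes an HMAC-SHA256 over
# "image-preview|<upload_file_id>|<timestamp>|<nonce>" with the Flask SECRET_KEY and
# appends timestamp, nonce and the urlsafe-base64 signature as query parameters.
# A minimal verification sketch under those assumptions (the serving endpoint itself is
# defined elsewhere; this helper name is hypothetical):
#
#     def verify_image_preview_sign(upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
#         data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
#         secret_key = current_app.config['SECRET_KEY'].encode()
#         recalculated = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
#         return hmac.compare_digest(base64.urlsafe_b64encode(recalculated).decode(), sign)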


class AppDatasetJoin(db.Model):
    __tablename__ = 'app_dataset_joins'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='app_dataset_join_pkey'),
        db.Index('app_dataset_join_app_dataset_idx', 'dataset_id', 'app_id'),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        return App.query.get(self.app_id)


class DatasetQuery(db.Model):
    __tablename__ = 'dataset_queries'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_query_pkey'),
        db.Index('dataset_query_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


class DatasetKeywordTable(db.Model):
    __tablename__ = 'dataset_keyword_tables'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_keyword_table_pkey'),
        db.Index('dataset_keyword_table_dataset_id_idx', 'dataset_id'),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False,
                                 server_default=db.text("'database'::character varying"))

    @property
    def keyword_table_dict(self):
        class SetDecoder(json.JSONDecoder):
            def __init__(self, *args, **kwargs):
                super().__init__(object_hook=self.object_hook, *args, **kwargs)

            def object_hook(self, dct):
                if isinstance(dct, dict):
                    for keyword, node_idxs in dct.items():
                        if isinstance(node_idxs, list):
                            dct[keyword] = set(node_idxs)
                return dct

        # get dataset
        dataset = Dataset.query.filter_by(
            id=self.dataset_id
        ).first()
        if not dataset:
            return None
        if self.data_source_type == 'database':
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = 'keyword_files/' + dataset.tenant_id + '/' + self.dataset_id + '.txt'
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode('utf-8'), cls=SetDecoder)
                return None
            except Exception as e:
                logging.exception(str(e))
                return None
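
# keyword_table_dict yields a mapping of keyword -> set of index node ids: SetDecoder
# converts the JSON lists in the stored table back into Python sets, e.g. (illustrative
# shape only):
#
#     {'keyword': {'node-id-1', 'node-id-2'}}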


class Embedding(db.Model):
    __tablename__ = 'embeddings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='embedding_pkey'),
        db.UniqueConstraint('model_name', 'hash', 'provider_name', name='embedding_hash_idx')
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    model_name = db.Column(db.String(40), nullable=False,
                           server_default=db.text("'text-embedding-ada-002'::character varying"))
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    provider_name = db.Column(db.String(40), nullable=False,
                              server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        return pickle.loads(self.embedding)
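
# Illustrative round-trip (not original code; field values are example placeholders):
# set_embedding pickles the vector and get_embedding restores it unchanged.
#
#     cache_entry = Embedding(model_name='text-embedding-ada-002',
#                             hash='<content-hash>', provider_name='<provider>')
#     cache_entry.set_embedding([0.12, -0.03, 0.88])
#     cache_entry.get_embedding()  # -> [0.12, -0.03, 0.88]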


class DatasetCollectionBinding(db.Model):
    __tablename__ = 'dataset_collection_bindings'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='dataset_collection_bindings_pkey'),
        db.Index('provider_model_name_idx', 'provider_name', 'model_name')
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text('uuid_generate_v4()'))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(40), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))