# dataset.py
import base64
import enum
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError
from typing import Any, cast

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped

from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule

from .account import Account
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID


class DatasetPermissionEnum(enum.StrEnum):
    ONLY_ME = "only_me"
    ALL_TEAM = "all_team_members"
    PARTIAL_TEAM = "partial_members"


class Dataset(db.Model):  # type: ignore[name-defined]
    __tablename__ = "datasets"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_pkey"),
        db.Index("dataset_tenant_idx", "tenant_id"),
        db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
    )

    INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
    PROVIDER_LIST = ["vendor", "external", None]

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)
    built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))

    @property
    def dataset_keyword_table(self):
        dataset_keyword_table = (
            db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first()
        )
        if dataset_keyword_table:
            return dataset_keyword_table
        return None

    @property
    def index_struct_dict(self):
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def external_retrieval_model(self):
        default_retrieval_model = {
            "top_k": 2,
            "score_threshold": 0.0,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def created_by_account(self):
        return db.session.get(Account, self.created_by)

    @property
    def latest_process_rule(self):
        return (
            DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
            .order_by(DatasetProcessRule.created_at.desc())
            .first()
        )

    @property
    def app_count(self):
        return (
            db.session.query(func.count(AppDatasetJoin.id))
            .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
            .scalar()
        )

    @property
    def document_count(self):
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        return (
            db.session.query(func.count(Document.id))
            .filter(
                Document.dataset_id == self.id,
                Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
            )
            .scalar()
        )

    @property
    def available_segment_count(self):
        return (
            db.session.query(func.count(DocumentSegment.id))
            .filter(
                DocumentSegment.dataset_id == self.id,
                DocumentSegment.status == "completed",
                DocumentSegment.enabled == True,
            )
            .scalar()
        )

    @property
    def word_count(self):
        # coalesce to 0 so an empty dataset reports 0 instead of None
        return (
            Document.query.with_entities(func.coalesce(func.sum(Document.word_count), 0))
            .filter(Document.dataset_id == self.id)
            .scalar()
        )

    @property
    def doc_form(self):
        document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
        return None

    @property
    def retrieval_model_dict(self):
        default_retrieval_model = {
            "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
            "reranking_enable": False,
            "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
            "top_k": 2,
            "score_threshold_enabled": False,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def tags(self):
        tags = (
            db.session.query(Tag)
            .join(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(
                TagBinding.target_id == self.id,
                TagBinding.tenant_id == self.tenant_id,
                Tag.tenant_id == self.tenant_id,
                Tag.type == "knowledge",
            )
            .all()
        )
        return tags or []

    @property
    def categories(self):
        categories = (
            db.session.query(Tag)
            .join(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(
                TagBinding.target_id == self.id,
                TagBinding.tenant_id == self.tenant_id,
                Tag.tenant_id == self.tenant_id,
                Tag.type == "knowledge_category",
            )
            .all()
        )
        return categories or []

    @property
    def external_knowledge_info(self):
        if self.provider != "external":
            return None
        external_knowledge_binding = (
            db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
        )
        if not external_knowledge_binding:
            return None
        external_knowledge_api = (
            db.session.query(ExternalKnowledgeApis)
            .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
            .first()
        )
        if not external_knowledge_api:
            return None
        return {
            "external_knowledge_id": external_knowledge_binding.external_knowledge_id,
            "external_knowledge_api_id": external_knowledge_api.id,
            "external_knowledge_api_name": external_knowledge_api.name,
            "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
        }

    @property
    def doc_metadata(self):
        dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all()
        doc_metadata = [
            {
                "id": dataset_metadata.id,
                "name": dataset_metadata.name,
                "type": dataset_metadata.type,
            }
            for dataset_metadata in dataset_metadatas
        ]
        if self.built_in_field_enabled:
            doc_metadata.append(
                {
                    "id": "built-in",
                    "name": BuiltInField.document_name.value,
                    "type": "string",
                }
            )
            doc_metadata.append(
                {
                    "id": "built-in",
                    "name": BuiltInField.uploader.value,
                    "type": "string",
                }
            )
            doc_metadata.append(
                {
                    "id": "built-in",
                    "name": BuiltInField.upload_date.value,
                    "type": "time",
                }
            )
            doc_metadata.append(
                {
                    "id": "built-in",
                    "name": BuiltInField.last_update_date.value,
                    "type": "time",
                }
            )
            doc_metadata.append(
                {
                    "id": "built-in",
                    "name": BuiltInField.source.value,
                    "type": "string",
                }
            )
        return doc_metadata

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f"Vector_index_{normalized_dataset_id}_Node"
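

# A minimal sketch of how gen_collection_name_by_id maps a dataset id to its
# vector collection name; the UUID below is a made-up placeholder, not a real row.
def _example_collection_name() -> str:
    # "1234...-abcd-..." -> "Vector_index_1234..._abcd_..._Node"
    return Dataset.gen_collection_name_by_id("12345678-abcd-ef01-2345-6789abcdef01")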


class DatasetProcessRule(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_process_rules"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
        db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    MODES = ["automatic", "custom", "hierarchical"]
    PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
    AUTOMATIC_RULES: dict[str, Any] = {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": False},
        ],
        "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
    }

    def to_dict(self):
        return {
            "id": self.id,
            "dataset_id": self.dataset_id,
            "mode": self.mode,
            "rules": self.rules_dict,
        }

    @property
    def rules_dict(self):
        try:
            return json.loads(self.rules) if self.rules else None
        except JSONDecodeError:
            return None
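

# A minimal sketch, assuming the caller persists the row itself: building a
# "custom"-mode rule whose JSON payload mirrors the AUTOMATIC_RULES shape, so
# rules_dict can decode it later. All ids passed in are placeholders.
def _example_custom_process_rule(dataset_id: str, account_id: str) -> DatasetProcessRule:
    rules = {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_stopwords", "enabled": False},
        ],
        "segmentation": {"delimiter": "\n\n", "max_tokens": 800, "chunk_overlap": 80},
    }
    # transient instance; adding it to db.session is left to the caller
    return DatasetProcessRule(dataset_id=dataset_id, mode="custom", rules=json.dumps(rules), created_by=account_id)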


class Document(db.Model):  # type: ignore[name-defined]
    __tablename__ = "documents"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_pkey"),
        db.Index("document_dataset_id_idx", "dataset_id"),
        db.Index("document_is_paused_idx", "is_paused"),
        db.Index("document_tenant_idx", "tenant_id"),
        db.Index("document_metadata_idx", "doc_metadata", postgresql_using="gin"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)
    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)
    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)
    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)
    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)
    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)
    # basic fields
    indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(JSONB, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)
    check_status = db.Column(db.Integer, nullable=False)
    check_by = db.Column(db.String(40), nullable=True)
    check_at = db.Column(db.DateTime, nullable=True)
    disable_applicant = db.Column(StringUUID, nullable=True)
    enable_applicant = db.Column(db.String(40), nullable=False)

    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

    @property
    def display_status(self):
        status = None
        if self.indexing_status == "waiting":
            status = "queuing"
        elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused:
            status = "paused"
        elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}:
            status = "indexing"
        elif self.indexing_status == "error":
            status = "error"
        elif self.indexing_status == "completed" and not self.archived and self.enabled:
            status = "available"
        elif self.indexing_status == "completed" and not self.archived and not self.enabled:
            status = "disabled"
        elif self.indexing_status == "completed" and self.archived:
            status = "archived"
        return status

    @property
    def data_source_info_dict(self):
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}
            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        if self.data_source_info:
            if self.data_source_type == "upload_file":
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.id == data_source_info_dict["upload_file_id"])
                    .one_or_none()
                )
                if file_detail:
                    return {
                        "upload_file": {
                            "id": file_detail.id,
                            "name": file_detail.name,
                            "size": file_detail.size,
                            "extension": file_detail.extension,
                            "mime_type": file_detail.mime_type,
                            "created_by": file_detail.created_by,
                            "created_at": file_detail.created_at.timestamp(),
                        }
                    }
            elif self.data_source_type in {"notion_import", "website_crawl"}:
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        if self.dataset_process_rule_id:
            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        # coalesce to 0 so a document with no segments reports 0 instead of None
        return (
            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
            .filter(DocumentSegment.document_id == self.id)
            .scalar()
        )

    @property
    def uploader(self):
        user = db.session.query(Account).filter(Account.id == self.created_by).first()
        return user.name if user else None

    @property
    def upload_date(self):
        return self.created_at

    @property
    def last_update_date(self):
        return self.updated_at

    @property
    def doc_metadata_details(self):
        if self.doc_metadata:
            document_metadatas = (
                db.session.query(DatasetMetadata)
                .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
                .filter(
                    DatasetMetadataBinding.dataset_id == self.dataset_id,
                    DatasetMetadataBinding.document_id == self.id,
                )
                .all()
            )
            metadata_list = []
            for metadata in document_metadatas:
                metadata_dict = {
                    "id": metadata.id,
                    "name": metadata.name,
                    "type": metadata.type,
                    "value": self.doc_metadata.get(metadata.name),
                }
                metadata_list.append(metadata_dict)
            # deal with built-in fields
            metadata_list.extend(self.get_built_in_fields())
            return metadata_list
        return None

    @property
    def process_rule_dict(self):
        if self.dataset_process_rule_id:
            return self.dataset_process_rule.to_dict()
        return None

    def get_built_in_fields(self):
        built_in_fields = []
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.document_name,
                "type": "string",
                "value": self.name,
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.uploader,
                "type": "string",
                "value": self.uploader,
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.upload_date,
                "type": "time",
                "value": self.created_at.timestamp(),
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.last_update_date,
                "type": "time",
                "value": self.updated_at.timestamp(),
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.source,
                "type": "string",
                "value": MetadataDataSource[self.data_source_type].value,
            }
        )
        return built_in_fields

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "dataset_id": self.dataset_id,
            "position": self.position,
            "data_source_type": self.data_source_type,
            "data_source_info": self.data_source_info,
            "dataset_process_rule_id": self.dataset_process_rule_id,
            "batch": self.batch,
            "name": self.name,
            "created_from": self.created_from,
            "created_by": self.created_by,
            "created_api_request_id": self.created_api_request_id,
            "created_at": self.created_at,
            "processing_started_at": self.processing_started_at,
            "file_id": self.file_id,
            "word_count": self.word_count,
            "parsing_completed_at": self.parsing_completed_at,
            "cleaning_completed_at": self.cleaning_completed_at,
            "splitting_completed_at": self.splitting_completed_at,
            "tokens": self.tokens,
            "indexing_latency": self.indexing_latency,
            "completed_at": self.completed_at,
            "is_paused": self.is_paused,
            "paused_by": self.paused_by,
            "paused_at": self.paused_at,
            "error": self.error,
            "stopped_at": self.stopped_at,
            "indexing_status": self.indexing_status,
            "enabled": self.enabled,
            "disabled_at": self.disabled_at,
            "disabled_by": self.disabled_by,
            "archived": self.archived,
            "archived_reason": self.archived_reason,
            "archived_by": self.archived_by,
            "archived_at": self.archived_at,
            "updated_at": self.updated_at,
            "doc_type": self.doc_type,
            "doc_metadata": self.doc_metadata,
            "doc_form": self.doc_form,
            "doc_language": self.doc_language,
            "display_status": self.display_status,
            "data_source_info_dict": self.data_source_info_dict,
            "average_segment_length": self.average_segment_length,
            "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
            "dataset": self.dataset.to_dict() if self.dataset else None,
            "segment_count": self.segment_count,
            "hit_count": self.hit_count,
        }

    @classmethod
    def from_dict(cls, data: dict):
        return cls(
            id=data.get("id"),
            tenant_id=data.get("tenant_id"),
            dataset_id=data.get("dataset_id"),
            position=data.get("position"),
            data_source_type=data.get("data_source_type"),
            data_source_info=data.get("data_source_info"),
            dataset_process_rule_id=data.get("dataset_process_rule_id"),
            batch=data.get("batch"),
            name=data.get("name"),
            created_from=data.get("created_from"),
            created_by=data.get("created_by"),
            created_api_request_id=data.get("created_api_request_id"),
            created_at=data.get("created_at"),
            processing_started_at=data.get("processing_started_at"),
            file_id=data.get("file_id"),
            word_count=data.get("word_count"),
            parsing_completed_at=data.get("parsing_completed_at"),
            cleaning_completed_at=data.get("cleaning_completed_at"),
            splitting_completed_at=data.get("splitting_completed_at"),
            tokens=data.get("tokens"),
            indexing_latency=data.get("indexing_latency"),
            completed_at=data.get("completed_at"),
            is_paused=data.get("is_paused"),
            paused_by=data.get("paused_by"),
            paused_at=data.get("paused_at"),
            error=data.get("error"),
            stopped_at=data.get("stopped_at"),
            indexing_status=data.get("indexing_status"),
            enabled=data.get("enabled"),
            disabled_at=data.get("disabled_at"),
            disabled_by=data.get("disabled_by"),
            archived=data.get("archived"),
            archived_reason=data.get("archived_reason"),
            archived_by=data.get("archived_by"),
            archived_at=data.get("archived_at"),
            updated_at=data.get("updated_at"),
            doc_type=data.get("doc_type"),
            doc_metadata=data.get("doc_metadata"),
            doc_form=data.get("doc_form"),
            doc_language=data.get("doc_language"),
        )
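

# A minimal sketch of from_dict: it only maps persisted columns, so a small
# payload is enough to build a transient (unsaved) Document. The ids and names
# below are placeholders, not real rows.
def _example_document_from_dict() -> Document:
    return Document.from_dict(
        {
            "tenant_id": "tenant-uuid",
            "dataset_id": "dataset-uuid",
            "position": 1,
            "data_source_type": "upload_file",
            "batch": "batch-0001",
            "name": "example.txt",
            "created_from": "web",
            "created_by": "account-uuid",
            "indexing_status": "waiting",
            "doc_form": "text_model",
        }
    )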


class Template(db.Model):  # type: ignore[name-defined]
    __tablename__ = "template"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="template_pkey"),
        db.Index("template_dataset_id_idx", "dataset_id"),
        db.Index("template_is_paused_idx", "is_paused"),
        db.Index("template_tenant_idx", "tenant_id"),
        db.Index("template_metadata_idx", "doc_metadata", postgresql_using="gin"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)
    # parsing
    file_id = db.Column(db.Text, nullable=True)
    file_url = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)
    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)
    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)
    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)
    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)
    # basic fields
    indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(JSONB, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ["upload_file"]

    @property
    def display_status(self):
        status = None
        if self.indexing_status == "waiting":
            status = "queuing"
        elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused:
            status = "paused"
        elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}:
            status = "indexing"
        elif self.indexing_status == "error":
            status = "error"
        elif self.indexing_status == "completed" and not self.archived and self.enabled:
            status = "available"
        elif self.indexing_status == "completed" and not self.archived and not self.enabled:
            status = "disabled"
        elif self.indexing_status == "completed" and self.archived:
            status = "archived"
        return status

    @property
    def data_source_info_dict(self):
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}
            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        if self.data_source_info:
            if self.data_source_type == "upload_file":
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.id == data_source_info_dict["upload_file_id"])
                    .one_or_none()
                )
                if file_detail:
                    return {
                        "upload_file": {
                            "id": file_detail.id,
                            "name": file_detail.name,
                            "size": file_detail.size,
                            "extension": file_detail.extension,
                            "mime_type": file_detail.mime_type,
                            "created_by": file_detail.created_by,
                            "created_at": file_detail.created_at.timestamp(),
                        }
                    }
            elif self.data_source_type in {"notion_import", "website_crawl"}:
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        if self.dataset_process_rule_id:
            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        # coalesce to 0 so a template with no segments reports 0 instead of None
        return (
            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
            .filter(DocumentSegment.document_id == self.id)
            .scalar()
        )

    @property
    def uploader(self):
        user = db.session.query(Account).filter(Account.id == self.created_by).first()
        return user.name if user else None

    @property
    def upload_date(self):
        return self.created_at

    @property
    def last_update_date(self):
        return self.updated_at

    @property
    def doc_metadata_details(self):
        if self.doc_metadata:
            document_metadatas = (
                db.session.query(DatasetMetadata)
                .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
                .filter(
                    DatasetMetadataBinding.dataset_id == self.dataset_id,
                    DatasetMetadataBinding.document_id == self.id,
                )
                .all()
            )
            metadata_list = []
            for metadata in document_metadatas:
                metadata_dict = {
                    "id": metadata.id,
                    "name": metadata.name,
                    "type": metadata.type,
                    "value": self.doc_metadata.get(metadata.name),
                }
                metadata_list.append(metadata_dict)
            # deal with built-in fields
            metadata_list.extend(self.get_built_in_fields())
            return metadata_list
        return None

    @property
    def process_rule_dict(self):
        if self.dataset_process_rule_id:
            return self.dataset_process_rule.to_dict()
        return None

    def get_built_in_fields(self):
        built_in_fields = []
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.document_name,
                "type": "string",
                "value": self.name,
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.uploader,
                "type": "string",
                "value": self.uploader,
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.upload_date,
                "type": "time",
                "value": self.created_at.timestamp(),
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.last_update_date,
                "type": "time",
                "value": self.updated_at.timestamp(),
            }
        )
        built_in_fields.append(
            {
                "id": "built-in",
                "name": BuiltInField.source,
                "type": "string",
                "value": MetadataDataSource[self.data_source_type].value,
            }
        )
        return built_in_fields

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "dataset_id": self.dataset_id,
            "position": self.position,
            "data_source_type": self.data_source_type,
            "data_source_info": self.data_source_info,
            "dataset_process_rule_id": self.dataset_process_rule_id,
            "batch": self.batch,
            "name": self.name,
            "created_from": self.created_from,
            "created_by": self.created_by,
            "created_api_request_id": self.created_api_request_id,
            "created_at": self.created_at,
            "processing_started_at": self.processing_started_at,
            "file_id": self.file_id,
            "word_count": self.word_count,
            "parsing_completed_at": self.parsing_completed_at,
            "cleaning_completed_at": self.cleaning_completed_at,
            "splitting_completed_at": self.splitting_completed_at,
            "tokens": self.tokens,
            "indexing_latency": self.indexing_latency,
            "completed_at": self.completed_at,
            "is_paused": self.is_paused,
            "paused_by": self.paused_by,
            "paused_at": self.paused_at,
            "error": self.error,
            "stopped_at": self.stopped_at,
            "indexing_status": self.indexing_status,
            "enabled": self.enabled,
            "disabled_at": self.disabled_at,
            "disabled_by": self.disabled_by,
            "archived": self.archived,
            "archived_reason": self.archived_reason,
            "archived_by": self.archived_by,
            "archived_at": self.archived_at,
            "updated_at": self.updated_at,
            "doc_type": self.doc_type,
            "doc_metadata": self.doc_metadata,
            "doc_form": self.doc_form,
            "doc_language": self.doc_language,
            "display_status": self.display_status,
            "data_source_info_dict": self.data_source_info_dict,
            "average_segment_length": self.average_segment_length,
            "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
            "dataset": self.dataset.to_dict() if self.dataset else None,
            "segment_count": self.segment_count,
            "hit_count": self.hit_count,
        }

    @classmethod
    def from_dict(cls, data: dict):
        return cls(
            id=data.get("id"),
            tenant_id=data.get("tenant_id"),
            dataset_id=data.get("dataset_id"),
            position=data.get("position"),
            data_source_type=data.get("data_source_type"),
            data_source_info=data.get("data_source_info"),
            dataset_process_rule_id=data.get("dataset_process_rule_id"),
            batch=data.get("batch"),
            name=data.get("name"),
            created_from=data.get("created_from"),
            created_by=data.get("created_by"),
            created_api_request_id=data.get("created_api_request_id"),
            created_at=data.get("created_at"),
            processing_started_at=data.get("processing_started_at"),
            file_id=data.get("file_id"),
            word_count=data.get("word_count"),
            parsing_completed_at=data.get("parsing_completed_at"),
            cleaning_completed_at=data.get("cleaning_completed_at"),
            splitting_completed_at=data.get("splitting_completed_at"),
            tokens=data.get("tokens"),
            indexing_latency=data.get("indexing_latency"),
            completed_at=data.get("completed_at"),
            is_paused=data.get("is_paused"),
            paused_by=data.get("paused_by"),
            paused_at=data.get("paused_at"),
            error=data.get("error"),
            stopped_at=data.get("stopped_at"),
            indexing_status=data.get("indexing_status"),
            enabled=data.get("enabled"),
            disabled_at=data.get("disabled_at"),
            disabled_by=data.get("disabled_by"),
            archived=data.get("archived"),
            archived_reason=data.get("archived_reason"),
            archived_by=data.get("archived_by"),
            archived_at=data.get("archived_at"),
            updated_at=data.get("updated_at"),
            doc_type=data.get("doc_type"),
            doc_metadata=data.get("doc_metadata"),
            doc_form=data.get("doc_form"),
            doc_language=data.get("doc_language"),
        )


class DocumentSegment(db.Model):  # type: ignore[name-defined]
    __tablename__ = "document_segments"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
        db.Index("document_segment_dataset_id_idx", "dataset_id"),
        db.Index("document_segment_document_id_idx", "document_id"),
        db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
        db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
        db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"),
        db.Index("document_segment_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position: Mapped[int]
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)
    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)
    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
            .first()
        )

    @property
    def next_segment(self):
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
            .first()
        )

    @property
    def child_chunks(self):
        process_rule = self.document.dataset_process_rule
        if process_rule.mode == "hierarchical":
            rules = Rule(**process_rule.rules_dict)
            if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
                child_chunks = (
                    db.session.query(ChildChunk)
                    .filter(ChildChunk.segment_id == self.id)
                    .order_by(ChildChunk.position.asc())
                    .all()
                )
                return child_chunks or []
            else:
                return []
        else:
            return []

    def get_child_chunks(self):
        process_rule = self.document.dataset_process_rule
        if process_rule.mode == "hierarchical":
            rules = Rule(**process_rule.rules_dict)
            if rules.parent_mode:
                child_chunks = (
                    db.session.query(ChildChunk)
                    .filter(ChildChunk.segment_id == self.id)
                    .order_by(ChildChunk.position.asc())
                    .all()
                )
                return child_chunks or []
            else:
                return []
        else:
            return []

    @property
    def sign_content(self):
        return self.get_sign_content()

    def get_sign_content(self):
        signed_urls = []
        text = self.content

        # For data before v0.10.0
        pattern = r"/files/([a-f0-9\-]+)/image-preview"
        matches = re.finditer(pattern, text)
        for match in matches:
            upload_file_id = match.group(1)
            nonce = os.urandom(16).hex()
            timestamp = str(int(time.time()))
            data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
            secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
            encoded_sign = base64.urlsafe_b64encode(sign).decode()
            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
            signed_url = f"{match.group(0)}?{params}"
            signed_urls.append((match.start(), match.end(), signed_url))

        # For data after v0.10.0
        pattern = r"/files/([a-f0-9\-]+)/file-preview"
        matches = re.finditer(pattern, text)
        for match in matches:
            upload_file_id = match.group(1)
            nonce = os.urandom(16).hex()
            timestamp = str(int(time.time()))
            data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
            secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
            encoded_sign = base64.urlsafe_b64encode(sign).decode()
            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
            signed_url = f"{match.group(0)}?{params}"
            signed_urls.append((match.start(), match.end(), signed_url))

        # Reconstruct the text with signed URLs, applying replacements in
        # ascending position so the running offset stays valid even when both
        # patterns matched.
        signed_urls.sort(key=lambda x: x[0])
        offset = 0
        for start, end, signed_url in signed_urls:
            text = text[: start + offset] + signed_url + text[end + offset :]
            offset += len(signed_url) - (end - start)
        return text
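

# A minimal sketch of the HMAC-SHA256 signing scheme applied by
# get_sign_content, isolated for a single (hypothetical) file-preview URL.
def _example_sign_file_preview_url(upload_file_id: str) -> str:
    nonce = os.urandom(16).hex()
    timestamp = str(int(time.time()))
    data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
    secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
    sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
    encoded_sign = base64.urlsafe_b64encode(sign).decode()
    return f"/files/{upload_file_id}/file-preview?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"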


class ChildChunk(db.Model):  # type: ignore[name-defined]
    __tablename__ = "child_chunks"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
        db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    segment_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    word_count = db.Column(db.Integer, nullable=False)
    # indexing fields
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)
    type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def segment(self):
        return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()


class AppDatasetJoin(db.Model):  # type: ignore[name-defined]
    __tablename__ = "app_dataset_joins"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
        db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        return db.session.get(App, self.app_id)


class DatasetQuery(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_queries"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
        db.Index("dataset_query_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


class DatasetKeywordTable(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_keyword_tables"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
        db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(
        db.String(255), nullable=False, server_default=db.text("'database'::character varying")
    )

    @property
    def keyword_table_dict(self):
        class SetDecoder(json.JSONDecoder):
            def __init__(self, *args, **kwargs):
                super().__init__(object_hook=self.object_hook, *args, **kwargs)

            def object_hook(self, dct):
                if isinstance(dct, dict):
                    for keyword, node_idxs in dct.items():
                        if isinstance(node_idxs, list):
                            dct[keyword] = set(node_idxs)
                return dct

        # get dataset
        dataset = Dataset.query.filter_by(id=self.dataset_id).first()
        if not dataset:
            return None
        if self.data_source_type == "database":
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt"
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder)
                return None
            except Exception:
                logging.exception("Failed to load keyword table from file: %s", file_key)
                return None
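

# A minimal sketch of what SetDecoder does on load: keyword tables are stored
# as JSON lists but consumed as sets. The data below is made up.
def _example_keyword_table_round_trip() -> dict:
    serialized = json.dumps({"python": ["node-1", "node-2"]})
    decoded = json.loads(serialized)
    # mirrors SetDecoder.object_hook: list values become sets
    return {keyword: set(node_idxs) for keyword, node_idxs in decoded.items()}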


class Embedding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "embeddings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="embedding_pkey"),
        db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
        db.Index("created_at_idx", "created_at"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    model_name = db.Column(
        db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
    )
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        return cast(list[float], pickle.loads(self.embedding))  # noqa: S301
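

# A minimal sketch of the pickle round trip: get_embedding returns exactly
# what set_embedding stored. The instance is transient and never persisted.
def _example_embedding_round_trip() -> list[float]:
    e = Embedding(model_name="text-embedding-ada-002", hash="demo-hash", provider_name="")
    e.set_embedding([0.1, 0.2, 0.3])
    return e.get_embedding()  # -> [0.1, 0.2, 0.3]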


class DatasetCollectionBinding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_collection_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
        db.Index("provider_model_name_idx", "provider_name", "model_name"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    provider_name = db.Column(db.String(255), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class TidbAuthBinding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "tidb_auth_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
        db.Index("tidb_auth_bindings_active_idx", "active"),
        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
        db.Index("tidb_auth_bindings_status_idx", "status"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    cluster_id = db.Column(db.String(255), nullable=False)
    cluster_name = db.Column(db.String(255), nullable=False)
    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    # quote the default so it is a string literal, not a column reference
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
    account = db.Column(db.String(255), nullable=False)
    password = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
  1184. class Whitelist(db.Model): # type: ignore[name-defined]
  1185. __tablename__ = "whitelists"
  1186. __table_args__ = (
  1187. db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
  1188. db.Index("whitelists_tenant_idx", "tenant_id"),
  1189. )
  1190. id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
  1191. tenant_id = db.Column(StringUUID, nullable=True)
  1192. category = db.Column(db.String(255), nullable=False)
  1193. created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetPermission(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_permissions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
        db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
        db.Index("idx_dataset_permissions_account_id", "account_id"),
        db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
    dataset_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
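

# A tenant-registered external knowledge base API; `settings` stores its
# endpoint configuration as a JSON string (see settings_dict below).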
class ExternalKnowledgeApis(db.Model):  # type: ignore[name-defined]
    __tablename__ = "external_knowledge_apis"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
        db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_apis_name_idx", "name"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.String(255), nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    settings = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "name": self.name,
            "description": self.description,
            "settings": self.settings_dict,
            "dataset_bindings": self.dataset_bindings,
            "created_by": self.created_by,
            "created_at": self.created_at.isoformat(),
        }

    @property
    def settings_dict(self):
        # `settings` is a JSON string; fall back to None when it is empty or
        # not valid JSON.
        try:
            return json.loads(self.settings) if self.settings else None
        except JSONDecodeError:
            return None

    @property
    def dataset_bindings(self):
        # Resolve every dataset bound to this external knowledge API and
        # return a lightweight id/name summary for each.
        external_knowledge_bindings = (
            db.session.query(ExternalKnowledgeBindings)
            .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
            .all()
        )
        dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
        datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
        return [{"id": dataset.id, "name": dataset.name} for dataset in datasets]
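

# Join table linking a local dataset to a knowledge base
# (`external_knowledge_id`) served by an ExternalKnowledgeApis record.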
class ExternalKnowledgeBindings(db.Model):  # type: ignore[name-defined]
    __tablename__ = "external_knowledge_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
        db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
        db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
        db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    external_knowledge_api_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    external_knowledge_id = db.Column(db.Text, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
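

# Log of documents that were automatically disabled (presumably by a
# background cleanup job); `notified` marks whether the tenant was informed.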
class DatasetAutoDisableLog(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_auto_disable_logs"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
        db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
        db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
        db.Index("dataset_auto_disable_log_created_atx", "created_at"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
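

# Audit record of a rate-limited operation, keyed by tenant and
# subscription plan.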
class RateLimitLog(db.Model):  # type: ignore[name-defined]
    __tablename__ = "rate_limit_logs"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
        db.Index("rate_limit_log_tenant_idx", "tenant_id"),
        db.Index("rate_limit_log_operation_idx", "operation"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    subscription_plan = db.Column(db.String(255), nullable=False)
    operation = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
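

# A tenant-defined metadata field (name and type) declared on a dataset.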
class DatasetMetadata(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_metadatas"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
        db.Index("dataset_metadata_tenant_idx", "tenant_id"),
        db.Index("dataset_metadata_dataset_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    type = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    created_by = db.Column(StringUUID, nullable=False)
    updated_by = db.Column(StringUUID, nullable=True)
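

# Associates a DatasetMetadata field with a specific document in a dataset.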
class DatasetMetadataBinding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_metadata_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),
        db.Index("dataset_metadata_binding_tenant_idx", "tenant_id"),
        db.Index("dataset_metadata_binding_dataset_idx", "dataset_id"),
        db.Index("dataset_metadata_binding_metadata_idx", "metadata_id"),
        db.Index("dataset_metadata_binding_document_idx", "document_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    metadata_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    created_by = db.Column(StringUUID, nullable=False)