import base64
import enum
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError
from typing import Any, cast

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped

from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule

from .account import Account
from .engine import db
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID


class DatasetPermissionEnum(enum.StrEnum):
    ONLY_ME = "only_me"
    ALL_TEAM = "all_team_members"
    PARTIAL_TEAM = "partial_members"
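
# Note: as a StrEnum, these members compare equal to their plain-string values,
# e.g. DatasetPermissionEnum.ONLY_ME == "only_me", so they can be stored in and
# compared against the String `permission` column directly.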


class Dataset(db.Model):  # type: ignore[name-defined]
    __tablename__ = "datasets"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_pkey"),
        db.Index("dataset_tenant_idx", "tenant_id"),
        db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
    )

    INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
    PROVIDER_LIST = ["vendor", "external", None]

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)

    @property
    def dataset_keyword_table(self):
        dataset_keyword_table = (
            db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first()
        )
        if dataset_keyword_table:
            return dataset_keyword_table
        return None

    @property
    def index_struct_dict(self):
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def external_retrieval_model(self):
        default_retrieval_model = {
            "top_k": 2,
            "score_threshold": 0.0,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def created_by_account(self):
        return db.session.get(Account, self.created_by)

    @property
    def latest_process_rule(self):
        return (
            DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
            .order_by(DatasetProcessRule.created_at.desc())
            .first()
        )

    @property
    def app_count(self):
        return (
            db.session.query(func.count(AppDatasetJoin.id))
            .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
            .scalar()
        )

    @property
    def document_count(self):
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        return (
            db.session.query(func.count(Document.id))
            .filter(
                Document.dataset_id == self.id,
                Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
            )
            .scalar()
        )

    @property
    def available_segment_count(self):
        return (
            db.session.query(func.count(DocumentSegment.id))
            .filter(
                DocumentSegment.dataset_id == self.id,
                DocumentSegment.status == "completed",
                DocumentSegment.enabled == True,
            )
            .scalar()
        )

    @property
    def word_count(self):
        return (
            Document.query.with_entities(func.coalesce(func.sum(Document.word_count)))
            .filter(Document.dataset_id == self.id)
            .scalar()
        )

    @property
    def doc_form(self):
        document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
        return None

    @property
    def retrieval_model_dict(self):
        default_retrieval_model = {
            "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
            "reranking_enable": False,
            "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
            "top_k": 2,
            "score_threshold_enabled": False,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def tags(self):
        tags = (
            db.session.query(Tag)
            .join(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(
                TagBinding.target_id == self.id,
                TagBinding.tenant_id == self.tenant_id,
                Tag.tenant_id == self.tenant_id,
                Tag.type == "knowledge",
            )
            .all()
        )
        return tags or []

    @property
    def external_knowledge_info(self):
        if self.provider != "external":
            return None
        external_knowledge_binding = (
            db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
        )
        if not external_knowledge_binding:
            return None
        external_knowledge_api = (
            db.session.query(ExternalKnowledgeApis)
            .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
            .first()
        )
        if not external_knowledge_api:
            return None
        return {
            "external_knowledge_id": external_knowledge_binding.external_knowledge_id,
            "external_knowledge_api_id": external_knowledge_api.id,
            "external_knowledge_api_name": external_knowledge_api.name,
            # `settings` is nullable; fall back to an empty endpoint when it is
            # missing or not valid JSON instead of letting json.loads(None) raise.
            "external_knowledge_api_endpoint": (external_knowledge_api.settings_dict or {}).get("endpoint", ""),
        }

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f"Vector_index_{normalized_dataset_id}_Node"
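
# Illustrative only (the UUID below is made up): gen_collection_name_by_id maps a
# dataset id to a stable vector-collection name, e.g.
#   Dataset.gen_collection_name_by_id("2f0f9a2e-6b1d-4c3a-9d6e-0a1b2c3d4e5f")
#   -> "Vector_index_2f0f9a2e_6b1d_4c3a_9d6e_0a1b2c3d4e5f_Node"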


class DatasetProcessRule(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_process_rules"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
        db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    MODES = ["automatic", "custom", "hierarchical"]
    PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
    AUTOMATIC_RULES: dict[str, Any] = {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": False},
        ],
        "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
    }

    def to_dict(self):
        return {
            "id": self.id,
            "dataset_id": self.dataset_id,
            "mode": self.mode,
            "rules": self.rules_dict,
        }

    @property
    def rules_dict(self):
        try:
            return json.loads(self.rules) if self.rules else None
        except JSONDecodeError:
            return None
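
# Illustrative only: for a row whose `rules` column holds
# json.dumps(DatasetProcessRule.AUTOMATIC_RULES), rules_dict returns the same structure
# as a plain dict, e.g. rules_dict["segmentation"]["max_tokens"] == 500; a NULL or
# malformed `rules` value yields None instead of raising.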


class Document(db.Model):  # type: ignore[name-defined]
    __tablename__ = "documents"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_pkey"),
        db.Index("document_dataset_id_idx", "dataset_id"),
        db.Index("document_is_paused_idx", "is_paused"),
        db.Index("document_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)

    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)

    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)

    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    # basic fields
    indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

    @property
    def display_status(self):
        status = None
        if self.indexing_status == "waiting":
            status = "queuing"
        elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused:
            status = "paused"
        elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}:
            status = "indexing"
        elif self.indexing_status == "error":
            status = "error"
        elif self.indexing_status == "completed" and not self.archived and self.enabled:
            status = "available"
        elif self.indexing_status == "completed" and not self.archived and not self.enabled:
            status = "disabled"
        elif self.indexing_status == "completed" and self.archived:
            status = "archived"
        return status

    @property
    def data_source_info_dict(self):
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}
            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        if self.data_source_info:
            if self.data_source_type == "upload_file":
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.id == data_source_info_dict["upload_file_id"])
                    .one_or_none()
                )
                if file_detail:
                    return {
                        "upload_file": {
                            "id": file_detail.id,
                            "name": file_detail.name,
                            "size": file_detail.size,
                            "extension": file_detail.extension,
                            "mime_type": file_detail.mime_type,
                            "created_by": file_detail.created_by,
                            "created_at": file_detail.created_at.timestamp(),
                        }
                    }
            elif self.data_source_type in {"notion_import", "website_crawl"}:
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        if self.dataset_process_rule_id:
            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        return (
            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
            .filter(DocumentSegment.document_id == self.id)
            .scalar()
        )
    @property
    def process_rule_dict(self):
        # Guard against a dangling dataset_process_rule_id whose rule row no longer exists.
        if self.dataset_process_rule_id and self.dataset_process_rule:
            return self.dataset_process_rule.to_dict()
        return None
    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "dataset_id": self.dataset_id,
            "position": self.position,
            "data_source_type": self.data_source_type,
            "data_source_info": self.data_source_info,
            "dataset_process_rule_id": self.dataset_process_rule_id,
            "batch": self.batch,
            "name": self.name,
            "created_from": self.created_from,
            "created_by": self.created_by,
            "created_api_request_id": self.created_api_request_id,
            "created_at": self.created_at,
            "processing_started_at": self.processing_started_at,
            "file_id": self.file_id,
            "word_count": self.word_count,
            "parsing_completed_at": self.parsing_completed_at,
            "cleaning_completed_at": self.cleaning_completed_at,
            "splitting_completed_at": self.splitting_completed_at,
            "tokens": self.tokens,
            "indexing_latency": self.indexing_latency,
            "completed_at": self.completed_at,
            "is_paused": self.is_paused,
            "paused_by": self.paused_by,
            "paused_at": self.paused_at,
            "error": self.error,
            "stopped_at": self.stopped_at,
            "indexing_status": self.indexing_status,
            "enabled": self.enabled,
            "disabled_at": self.disabled_at,
            "disabled_by": self.disabled_by,
            "archived": self.archived,
            "archived_reason": self.archived_reason,
            "archived_by": self.archived_by,
            "archived_at": self.archived_at,
            "updated_at": self.updated_at,
            "doc_type": self.doc_type,
            "doc_metadata": self.doc_metadata,
            "doc_form": self.doc_form,
            "doc_language": self.doc_language,
            "display_status": self.display_status,
            "data_source_info_dict": self.data_source_info_dict,
            "average_segment_length": self.average_segment_length,
            "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
            "dataset": self.dataset.to_dict() if self.dataset else None,
            "segment_count": self.segment_count,
            "hit_count": self.hit_count,
        }

    @classmethod
    def from_dict(cls, data: dict):
        return cls(
            id=data.get("id"),
            tenant_id=data.get("tenant_id"),
            dataset_id=data.get("dataset_id"),
            position=data.get("position"),
            data_source_type=data.get("data_source_type"),
            data_source_info=data.get("data_source_info"),
            dataset_process_rule_id=data.get("dataset_process_rule_id"),
            batch=data.get("batch"),
            name=data.get("name"),
            created_from=data.get("created_from"),
            created_by=data.get("created_by"),
            created_api_request_id=data.get("created_api_request_id"),
            created_at=data.get("created_at"),
            processing_started_at=data.get("processing_started_at"),
            file_id=data.get("file_id"),
            word_count=data.get("word_count"),
            parsing_completed_at=data.get("parsing_completed_at"),
            cleaning_completed_at=data.get("cleaning_completed_at"),
            splitting_completed_at=data.get("splitting_completed_at"),
            tokens=data.get("tokens"),
            indexing_latency=data.get("indexing_latency"),
            completed_at=data.get("completed_at"),
            is_paused=data.get("is_paused"),
            paused_by=data.get("paused_by"),
            paused_at=data.get("paused_at"),
            error=data.get("error"),
            stopped_at=data.get("stopped_at"),
            indexing_status=data.get("indexing_status"),
            enabled=data.get("enabled"),
            disabled_at=data.get("disabled_at"),
            disabled_by=data.get("disabled_by"),
            archived=data.get("archived"),
            archived_reason=data.get("archived_reason"),
            archived_by=data.get("archived_by"),
            archived_at=data.get("archived_at"),
            updated_at=data.get("updated_at"),
            doc_type=data.get("doc_type"),
            doc_metadata=data.get("doc_metadata"),
            doc_form=data.get("doc_form"),
            doc_language=data.get("doc_language"),
        )
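
# Illustrative only: Document.from_dict(doc.to_dict()) rebuilds an unsaved copy from the
# dict form; derived keys that to_dict() adds (display_status, segment_count, hit_count,
# ...) are simply ignored, since from_dict() reads only the column fields.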


class DocumentSegment(db.Model):  # type: ignore[name-defined]
    __tablename__ = "document_segments"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
        db.Index("document_segment_dataset_id_idx", "dataset_id"),
        db.Index("document_segment_document_id_idx", "document_id"),
        db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
        db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
        db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"),
        db.Index("document_segment_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position: Mapped[int]
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)

    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)

    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
            .first()
        )

    @property
    def next_segment(self):
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
            .first()
        )
    @property
    def child_chunks(self):
        # Child chunks only exist for hierarchical documents whose parent mode is
        # paragraph-based; guard the nullable lookups before dereferencing them.
        process_rule = self.document.dataset_process_rule if self.document else None
        if not process_rule or process_rule.mode != "hierarchical":
            return []
        rules = Rule(**process_rule.rules_dict) if process_rule.rules_dict else None
        if rules and rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
            child_chunks = (
                db.session.query(ChildChunk)
                .filter(ChildChunk.segment_id == self.id)
                .order_by(ChildChunk.position.asc())
                .all()
            )
            return child_chunks or []
        return []
    @property
    def sign_content(self):
        return self.get_sign_content()

    def get_sign_content(self):
        signed_urls = []
        text = self.content

        # Sign both URL styles: "image-preview" for data before v0.10.0 and
        # "file-preview" for data from v0.10.0 onwards.
        for route in ("image-preview", "file-preview"):
            pattern = rf"/files/([a-f0-9\-]+)/{route}"
            for match in re.finditer(pattern, text):
                upload_file_id = match.group(1)
                nonce = os.urandom(16).hex()
                timestamp = str(int(time.time()))
                data_to_sign = f"{route}|{upload_file_id}|{timestamp}|{nonce}"
                secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
                sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
                encoded_sign = base64.urlsafe_b64encode(sign).decode()
                params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
                signed_urls.append((match.start(), match.end(), f"{match.group(0)}?{params}"))

        # Reconstruct the text with signed URLs, applying replacements in positional
        # order so the running offset stays valid when both URL styles interleave.
        offset = 0
        for start, end, signed_url in sorted(signed_urls, key=lambda s: s[0]):
            text = text[: start + offset] + signed_url + text[end + offset :]
            offset += len(signed_url) - (end - start)

        return text
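
# Illustrative sketch (not used by this module) of how a handler could verify a URL
# signed by get_sign_content() above; `verify_file_preview_sign` and its parameters
# are hypothetical names, but the HMAC recomputation mirrors the signing code:
#
#     def verify_file_preview_sign(upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
#         data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
#         secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
#         recomputed = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
#         return hmac.compare_digest(base64.urlsafe_b64encode(recomputed).decode(), sign)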


class ChildChunk(db.Model):  # type: ignore[name-defined]
    __tablename__ = "child_chunks"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
        db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    segment_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    word_count = db.Column(db.Integer, nullable=False)

    # indexing fields
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)
    type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)

    @property
    def dataset(self):
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def segment(self):
        return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()


class AppDatasetJoin(db.Model):  # type: ignore[name-defined]
    __tablename__ = "app_dataset_joins"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
        db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        return db.session.get(App, self.app_id)


class DatasetQuery(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_queries"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
        db.Index("dataset_query_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


class DatasetKeywordTable(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_keyword_tables"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
        db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(
        db.String(255), nullable=False, server_default=db.text("'database'::character varying")
    )

    @property
    def keyword_table_dict(self):
        class SetDecoder(json.JSONDecoder):
            def __init__(self, *args, **kwargs):
                super().__init__(object_hook=self.object_hook, *args, **kwargs)

            def object_hook(self, dct):
                if isinstance(dct, dict):
                    for keyword, node_idxs in dct.items():
                        if isinstance(node_idxs, list):
                            dct[keyword] = set(node_idxs)
                return dct

        # get dataset
        dataset = Dataset.query.filter_by(id=self.dataset_id).first()
        if not dataset:
            return None
        if self.data_source_type == "database":
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt"
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder)
                return None
            except Exception:
                logging.exception("Failed to load keyword table from file: %s", file_key)
                return None
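
# Illustrative only: SetDecoder rehydrates JSON lists into Python sets, so a stored
# value like {"python": ["node-1", "node-2"]} decodes to {"python": {"node-1", "node-2"}},
# restoring the set semantics the keyword index expects (JSON itself has no set type).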


class Embedding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "embeddings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="embedding_pkey"),
        db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
        db.Index("created_at_idx", "created_at"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    model_name = db.Column(
        db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
    )
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        return cast(list[float], pickle.loads(self.embedding))
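
# Illustrative only (made-up vector): set_embedding([0.1, 0.2]) pickles the list into
# the LargeBinary column and get_embedding() unpickles it back to [0.1, 0.2]. Note that
# pickle.loads must only ever see trusted bytes; within this model the column is written
# solely via set_embedding().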


class DatasetCollectionBinding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_collection_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
        db.Index("provider_model_name_idx", "provider_name", "model_name"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    provider_name = db.Column(db.String(255), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class TidbAuthBinding(db.Model):  # type: ignore[name-defined]
    __tablename__ = "tidb_auth_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
        db.Index("tidb_auth_bindings_active_idx", "active"),
        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
        db.Index("tidb_auth_bindings_status_idx", "status"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    cluster_id = db.Column(db.String(255), nullable=False)
    cluster_name = db.Column(db.String(255), nullable=False)
    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    # The default must be a quoted SQL string literal; a bare CREATING would be read
    # by PostgreSQL as an identifier rather than the value 'CREATING'.
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
    account = db.Column(db.String(255), nullable=False)
    password = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class Whitelist(db.Model):  # type: ignore[name-defined]
    __tablename__ = "whitelists"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
        db.Index("whitelists_tenant_idx", "tenant_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    category = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class DatasetPermission(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_permissions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
        db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
        db.Index("idx_dataset_permissions_account_id", "account_id"),
        db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
    dataset_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class ExternalKnowledgeApis(db.Model):  # type: ignore[name-defined]
    __tablename__ = "external_knowledge_apis"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
        db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_apis_name_idx", "name"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.String(255), nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    settings = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

    def to_dict(self):
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "name": self.name,
            "description": self.description,
            "settings": self.settings_dict,
            "dataset_bindings": self.dataset_bindings,
            "created_by": self.created_by,
            "created_at": self.created_at.isoformat(),
        }

    @property
    def settings_dict(self):
        try:
            return json.loads(self.settings) if self.settings else None
        except JSONDecodeError:
            return None

    @property
    def dataset_bindings(self):
        external_knowledge_bindings = (
            db.session.query(ExternalKnowledgeBindings)
            .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
            .all()
        )
        dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
        datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
        dataset_bindings = []
        for dataset in datasets:
            dataset_bindings.append({"id": dataset.id, "name": dataset.name})
        return dataset_bindings


class ExternalKnowledgeBindings(db.Model):  # type: ignore[name-defined]
    __tablename__ = "external_knowledge_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
        db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
        db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
        db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    external_knowledge_api_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    external_knowledge_id = db.Column(db.Text, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())


class DatasetAutoDisableLog(db.Model):  # type: ignore[name-defined]
    __tablename__ = "dataset_auto_disable_logs"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
        db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
        db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
        db.Index("dataset_auto_disable_log_created_atx", "created_at"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class RateLimitLog(db.Model):  # type: ignore[name-defined]
    __tablename__ = "rate_limit_logs"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
        db.Index("rate_limit_log_tenant_idx", "tenant_id"),
        db.Index("rate_limit_log_operation_idx", "operation"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    subscription_plan = db.Column(db.String(255), nullable=False)
    operation = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))