| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928 | import base64import enumimport hashlibimport hmacimport jsonimport loggingimport osimport pickleimport reimport timefrom json import JSONDecodeErrorfrom typing import Any, castfrom sqlalchemy import funcfrom sqlalchemy.dialects.postgresql import JSONBfrom configs import dify_configfrom core.rag.retrieval.retrieval_methods import RetrievalMethodfrom extensions.ext_storage import storagefrom services.entities.knowledge_entities.knowledge_entities import ParentMode, Rulefrom .account import Accountfrom .engine import dbfrom .model import App, Tag, TagBinding, UploadFilefrom .types import StringUUIDclass DatasetPermissionEnum(enum.StrEnum):    ONLY_ME = "only_me"    ALL_TEAM = "all_team_members"    PARTIAL_TEAM = "partial_members"class Dataset(db.Model):  # type: ignore[name-defined]    __tablename__ = "datasets"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="dataset_pkey"),        db.Index("dataset_tenant_idx", "tenant_id"),        db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),    )    INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]    PROVIDER_LIST = ["vendor", "external", None]    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))    tenant_id = db.Column(StringUUID, nullable=False)    name = db.Column(db.String(255), nullable=False)    description = db.Column(db.Text, nullable=True)    provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying"))    permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying"))    data_source_type = db.Column(db.String(255))    indexing_technique = db.Column(db.String(255), nullable=True)    index_struct = db.Column(db.Text, nullable=True)    created_by = db.Column(StringUUID, nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    updated_by = db.Column(StringUUID, nullable=True)    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    embedding_model = db.Column(db.String(255), nullable=True)    embedding_model_provider = db.Column(db.String(255), nullable=True)    collection_binding_id = db.Column(StringUUID, nullable=True)    retrieval_model = db.Column(JSONB, nullable=True)    @property    def dataset_keyword_table(self):        dataset_keyword_table = (            db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first()        )        if dataset_keyword_table:            return dataset_keyword_table        return None    @property    def index_struct_dict(self):        return json.loads(self.index_struct) if self.index_struct else None    @property    def external_retrieval_model(self):        default_retrieval_model = {            "top_k": 2,            "score_threshold": 0.0,        }        return self.retrieval_model or default_retrieval_model    @property    def created_by_account(self):        return db.session.get(Account, self.created_by)    @property    def latest_process_rule(self):        return (            DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)            .order_by(DatasetProcessRule.created_at.desc())            .first()        )    @property    def app_count(self):        return (            db.session.query(func.count(AppDatasetJoin.id))            .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)            .scalar()        )    @property    def document_count(self):        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()    @property    def available_document_count(self):        return (            db.session.query(func.count(Document.id))            .filter(                Document.dataset_id == self.id,                Document.indexing_status == "completed",                Document.enabled == True,                Document.archived == False,            )            .scalar()        )    @property    def available_segment_count(self):        return (            db.session.query(func.count(DocumentSegment.id))            .filter(                DocumentSegment.dataset_id == self.id,                DocumentSegment.status == "completed",                DocumentSegment.enabled == True,            )            .scalar()        )    @property    def word_count(self):        return (            Document.query.with_entities(func.coalesce(func.sum(Document.word_count)))            .filter(Document.dataset_id == self.id)            .scalar()        )    @property    def doc_form(self):        document = db.session.query(Document).filter(Document.dataset_id == self.id).first()        if document:            return document.doc_form        return None    @property    def retrieval_model_dict(self):        default_retrieval_model = {            "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,            "reranking_enable": False,            "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},            "top_k": 2,            "score_threshold_enabled": False,        }        return self.retrieval_model or default_retrieval_model    @property    def tags(self):        tags = (            db.session.query(Tag)            .join(TagBinding, Tag.id == TagBinding.tag_id)            .filter(                TagBinding.target_id == self.id,                TagBinding.tenant_id == self.tenant_id,                Tag.tenant_id == self.tenant_id,                Tag.type == "knowledge",            )            .all()        )        return tags or []    @property    def external_knowledge_info(self):        if self.provider != "external":            return None        external_knowledge_binding = (            db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()        )        if not external_knowledge_binding:            return None        external_knowledge_api = (            db.session.query(ExternalKnowledgeApis)            .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)            .first()        )        if not external_knowledge_api:            return None        return {            "external_knowledge_id": external_knowledge_binding.external_knowledge_id,            "external_knowledge_api_id": external_knowledge_api.id,            "external_knowledge_api_name": external_knowledge_api.name,            "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),        }    @staticmethod    def gen_collection_name_by_id(dataset_id: str) -> str:        normalized_dataset_id = dataset_id.replace("-", "_")        return f"Vector_index_{normalized_dataset_id}_Node"class DatasetProcessRule(db.Model):  # type: ignore[name-defined]    __tablename__ = "dataset_process_rules"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),        db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),    )    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))    dataset_id = db.Column(StringUUID, nullable=False)    mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))    rules = db.Column(db.Text, nullable=True)    created_by = db.Column(StringUUID, nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    MODES = ["automatic", "custom", "hierarchical"]    PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]    AUTOMATIC_RULES: dict[str, Any] = {        "pre_processing_rules": [            {"id": "remove_extra_spaces", "enabled": True},            {"id": "remove_urls_emails", "enabled": False},        ],        "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},    }    def to_dict(self):        return {            "id": self.id,            "dataset_id": self.dataset_id,            "mode": self.mode,            "rules": self.rules_dict,        }    @property    def rules_dict(self):        try:            return json.loads(self.rules) if self.rules else None        except JSONDecodeError:            return Noneclass Document(db.Model):  # type: ignore[name-defined]    __tablename__ = "documents"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="document_pkey"),        db.Index("document_dataset_id_idx", "dataset_id"),        db.Index("document_is_paused_idx", "is_paused"),        db.Index("document_tenant_idx", "tenant_id"),    )    # initial fields    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))    tenant_id = db.Column(StringUUID, nullable=False)    dataset_id = db.Column(StringUUID, nullable=False)    position = db.Column(db.Integer, nullable=False)    data_source_type = db.Column(db.String(255), nullable=False)    data_source_info = db.Column(db.Text, nullable=True)    dataset_process_rule_id = db.Column(StringUUID, nullable=True)    batch = db.Column(db.String(255), nullable=False)    name = db.Column(db.String(255), nullable=False)    created_from = db.Column(db.String(255), nullable=False)    created_by = db.Column(StringUUID, nullable=False)    created_api_request_id = db.Column(StringUUID, nullable=True)    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    # start processing    processing_started_at = db.Column(db.DateTime, nullable=True)    # parsing    file_id = db.Column(db.Text, nullable=True)    word_count = db.Column(db.Integer, nullable=True)    parsing_completed_at = db.Column(db.DateTime, nullable=True)    # cleaning    cleaning_completed_at = db.Column(db.DateTime, nullable=True)    # split    splitting_completed_at = db.Column(db.DateTime, nullable=True)    # indexing    tokens = db.Column(db.Integer, nullable=True)    indexing_latency = db.Column(db.Float, nullable=True)    completed_at = db.Column(db.DateTime, nullable=True)    # pause    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))    paused_by = db.Column(StringUUID, nullable=True)    paused_at = db.Column(db.DateTime, nullable=True)    # error    error = db.Column(db.Text, nullable=True)    stopped_at = db.Column(db.DateTime, nullable=True)    # basic fields    indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))    disabled_at = db.Column(db.DateTime, nullable=True)    disabled_by = db.Column(StringUUID, nullable=True)    archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))    archived_reason = db.Column(db.String(255), nullable=True)    archived_by = db.Column(StringUUID, nullable=True)    archived_at = db.Column(db.DateTime, nullable=True)    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    doc_type = db.Column(db.String(40), nullable=True)    doc_metadata = db.Column(db.JSON, nullable=True)    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))    doc_language = db.Column(db.String(255), nullable=True)    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]    @property    def display_status(self):        status = None        if self.indexing_status == "waiting":            status = "queuing"        elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused:            status = "paused"        elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}:            status = "indexing"        elif self.indexing_status == "error":            status = "error"        elif self.indexing_status == "completed" and not self.archived and self.enabled:            status = "available"        elif self.indexing_status == "completed" and not self.archived and not self.enabled:            status = "disabled"        elif self.indexing_status == "completed" and self.archived:            status = "archived"        return status    @property    def data_source_info_dict(self):        if self.data_source_info:            try:                data_source_info_dict = json.loads(self.data_source_info)            except JSONDecodeError:                data_source_info_dict = {}            return data_source_info_dict        return None    @property    def data_source_detail_dict(self):        if self.data_source_info:            if self.data_source_type == "upload_file":                data_source_info_dict = json.loads(self.data_source_info)                file_detail = (                    db.session.query(UploadFile)                    .filter(UploadFile.id == data_source_info_dict["upload_file_id"])                    .one_or_none()                )                if file_detail:                    return {                        "upload_file": {                            "id": file_detail.id,                            "name": file_detail.name,                            "size": file_detail.size,                            "extension": file_detail.extension,                            "mime_type": file_detail.mime_type,                            "created_by": file_detail.created_by,                            "created_at": file_detail.created_at.timestamp(),                        }                    }            elif self.data_source_type in {"notion_import", "website_crawl"}:                return json.loads(self.data_source_info)        return {}    @property    def average_segment_length(self):        if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0:            return self.word_count // self.segment_count        return 0    @property    def dataset_process_rule(self):        if self.dataset_process_rule_id:            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)        return None    @property    def dataset(self):        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()    @property    def segment_count(self):        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()    @property    def hit_count(self):        return (            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))            .filter(DocumentSegment.document_id == self.id)            .scalar()        )    @property    def process_rule_dict(self):        if self.dataset_process_rule_id:            return self.dataset_process_rule.to_dict()        return None    def to_dict(self):        return {            "id": self.id,            "tenant_id": self.tenant_id,            "dataset_id": self.dataset_id,            "position": self.position,            "data_source_type": self.data_source_type,            "data_source_info": self.data_source_info,            "dataset_process_rule_id": self.dataset_process_rule_id,            "batch": self.batch,            "name": self.name,            "created_from": self.created_from,            "created_by": self.created_by,            "created_api_request_id": self.created_api_request_id,            "created_at": self.created_at,            "processing_started_at": self.processing_started_at,            "file_id": self.file_id,            "word_count": self.word_count,            "parsing_completed_at": self.parsing_completed_at,            "cleaning_completed_at": self.cleaning_completed_at,            "splitting_completed_at": self.splitting_completed_at,            "tokens": self.tokens,            "indexing_latency": self.indexing_latency,            "completed_at": self.completed_at,            "is_paused": self.is_paused,            "paused_by": self.paused_by,            "paused_at": self.paused_at,            "error": self.error,            "stopped_at": self.stopped_at,            "indexing_status": self.indexing_status,            "enabled": self.enabled,            "disabled_at": self.disabled_at,            "disabled_by": self.disabled_by,            "archived": self.archived,            "archived_reason": self.archived_reason,            "archived_by": self.archived_by,            "archived_at": self.archived_at,            "updated_at": self.updated_at,            "doc_type": self.doc_type,            "doc_metadata": self.doc_metadata,            "doc_form": self.doc_form,            "doc_language": self.doc_language,            "display_status": self.display_status,            "data_source_info_dict": self.data_source_info_dict,            "average_segment_length": self.average_segment_length,            "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,            "dataset": self.dataset.to_dict() if self.dataset else None,            "segment_count": self.segment_count,            "hit_count": self.hit_count,        }    @classmethod    def from_dict(cls, data: dict):        return cls(            id=data.get("id"),            tenant_id=data.get("tenant_id"),            dataset_id=data.get("dataset_id"),            position=data.get("position"),            data_source_type=data.get("data_source_type"),            data_source_info=data.get("data_source_info"),            dataset_process_rule_id=data.get("dataset_process_rule_id"),            batch=data.get("batch"),            name=data.get("name"),            created_from=data.get("created_from"),            created_by=data.get("created_by"),            created_api_request_id=data.get("created_api_request_id"),            created_at=data.get("created_at"),            processing_started_at=data.get("processing_started_at"),            file_id=data.get("file_id"),            word_count=data.get("word_count"),            parsing_completed_at=data.get("parsing_completed_at"),            cleaning_completed_at=data.get("cleaning_completed_at"),            splitting_completed_at=data.get("splitting_completed_at"),            tokens=data.get("tokens"),            indexing_latency=data.get("indexing_latency"),            completed_at=data.get("completed_at"),            is_paused=data.get("is_paused"),            paused_by=data.get("paused_by"),            paused_at=data.get("paused_at"),            error=data.get("error"),            stopped_at=data.get("stopped_at"),            indexing_status=data.get("indexing_status"),            enabled=data.get("enabled"),            disabled_at=data.get("disabled_at"),            disabled_by=data.get("disabled_by"),            archived=data.get("archived"),            archived_reason=data.get("archived_reason"),            archived_by=data.get("archived_by"),            archived_at=data.get("archived_at"),            updated_at=data.get("updated_at"),            doc_type=data.get("doc_type"),            doc_metadata=data.get("doc_metadata"),            doc_form=data.get("doc_form"),            doc_language=data.get("doc_language"),        )class DocumentSegment(db.Model):  # type: ignore[name-defined]    __tablename__ = "document_segments"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="document_segment_pkey"),        db.Index("document_segment_dataset_id_idx", "dataset_id"),        db.Index("document_segment_document_id_idx", "document_id"),        db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),        db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),        db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"),        db.Index("document_segment_tenant_idx", "tenant_id"),    )    # initial fields    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))    tenant_id = db.Column(StringUUID, nullable=False)    dataset_id = db.Column(StringUUID, nullable=False)    document_id = db.Column(StringUUID, nullable=False)    position = db.Column(db.Integer, nullable=False)    content = db.Column(db.Text, nullable=False)    answer = db.Column(db.Text, nullable=True)    word_count = db.Column(db.Integer, nullable=False)    tokens = db.Column(db.Integer, nullable=False)    # indexing fields    keywords = db.Column(db.JSON, nullable=True)    index_node_id = db.Column(db.String(255), nullable=True)    index_node_hash = db.Column(db.String(255), nullable=True)    # basic fields    hit_count = db.Column(db.Integer, nullable=False, default=0)    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))    disabled_at = db.Column(db.DateTime, nullable=True)    disabled_by = db.Column(StringUUID, nullable=True)    status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))    created_by = db.Column(StringUUID, nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    updated_by = db.Column(StringUUID, nullable=True)    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    indexing_at = db.Column(db.DateTime, nullable=True)    completed_at = db.Column(db.DateTime, nullable=True)    error = db.Column(db.Text, nullable=True)    stopped_at = db.Column(db.DateTime, nullable=True)    @property    def dataset(self):        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()    @property    def document(self):        return db.session.query(Document).filter(Document.id == self.document_id).first()    @property    def previous_segment(self):        return (            db.session.query(DocumentSegment)            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)            .first()        )    @property    def next_segment(self):        return (            db.session.query(DocumentSegment)            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)            .first()        )    @property    def child_chunks(self):        process_rule = self.document.dataset_process_rule        if process_rule.mode == "hierarchical":            rules = Rule(**process_rule.rules_dict)            if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:                child_chunks = (                    db.session.query(ChildChunk)                    .filter(ChildChunk.segment_id == self.id)                    .order_by(ChildChunk.position.asc())                    .all()                )                return child_chunks or []            else:                return []        else:            return []    def get_sign_content(self):        signed_urls = []        text = self.content        # For data before v0.10.0        pattern = r"/files/([a-f0-9\-]+)/image-preview"        matches = re.finditer(pattern, text)        for match in matches:            upload_file_id = match.group(1)            nonce = os.urandom(16).hex()            timestamp = str(int(time.time()))            data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"            secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()            encoded_sign = base64.urlsafe_b64encode(sign).decode()            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"            signed_url = f"{match.group(0)}?{params}"            signed_urls.append((match.start(), match.end(), signed_url))        # For data after v0.10.0        pattern = r"/files/([a-f0-9\-]+)/file-preview"        matches = re.finditer(pattern, text)        for match in matches:            upload_file_id = match.group(1)            nonce = os.urandom(16).hex()            timestamp = str(int(time.time()))            data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"            secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()            encoded_sign = base64.urlsafe_b64encode(sign).decode()            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"            signed_url = f"{match.group(0)}?{params}"            signed_urls.append((match.start(), match.end(), signed_url))        # Reconstruct the text with signed URLs        offset = 0        for start, end, signed_url in signed_urls:            text = text[: start + offset] + signed_url + text[end + offset :]            offset += len(signed_url) - (end - start)        return textclass ChildChunk(db.Model):  # type: ignore[name-defined]    __tablename__ = "child_chunks"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),        db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),    )    # initial fields    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))    tenant_id = db.Column(StringUUID, nullable=False)    dataset_id = db.Column(StringUUID, nullable=False)    document_id = db.Column(StringUUID, nullable=False)    segment_id = db.Column(StringUUID, nullable=False)    position = db.Column(db.Integer, nullable=False)    content = db.Column(db.Text, nullable=False)    word_count = db.Column(db.Integer, nullable=False)    # indexing fields    index_node_id = db.Column(db.String(255), nullable=True)    index_node_hash = db.Column(db.String(255), nullable=True)    type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))    created_by = db.Column(StringUUID, nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))    updated_by = db.Column(StringUUID, nullable=True)    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))    indexing_at = db.Column(db.DateTime, nullable=True)    completed_at = db.Column(db.DateTime, nullable=True)    error = db.Column(db.Text, nullable=True)    @property    def dataset(self):        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()    @property    def document(self):        return db.session.query(Document).filter(Document.id == self.document_id).first()    @property    def segment(self):        return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()class AppDatasetJoin(db.Model):  # type: ignore[name-defined]    __tablename__ = "app_dataset_joins"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),        db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),    )    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))    app_id = db.Column(StringUUID, nullable=False)    dataset_id = db.Column(StringUUID, nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())    @property    def app(self):        return db.session.get(App, self.app_id)class DatasetQuery(db.Model):  # type: ignore[name-defined]    __tablename__ = "dataset_queries"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),        db.Index("dataset_query_dataset_id_idx", "dataset_id"),    )    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))    dataset_id = db.Column(StringUUID, nullable=False)    content = db.Column(db.Text, nullable=False)    source = db.Column(db.String(255), nullable=False)    source_app_id = db.Column(StringUUID, nullable=True)    created_by_role = db.Column(db.String, nullable=False)    created_by = db.Column(StringUUID, nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())class DatasetKeywordTable(db.Model):  # type: ignore[name-defined]    __tablename__ = "dataset_keyword_tables"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),        db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),    )    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))    dataset_id = db.Column(StringUUID, nullable=False, unique=True)    keyword_table = db.Column(db.Text, nullable=False)    data_source_type = db.Column(        db.String(255), nullable=False, server_default=db.text("'database'::character varying")    )    @property    def keyword_table_dict(self):        class SetDecoder(json.JSONDecoder):            def __init__(self, *args, **kwargs):                super().__init__(object_hook=self.object_hook, *args, **kwargs)            def object_hook(self, dct):                if isinstance(dct, dict):                    for keyword, node_idxs in dct.items():                        if isinstance(node_idxs, list):                            dct[keyword] = set(node_idxs)                return dct        # get dataset        dataset = Dataset.query.filter_by(id=self.dataset_id).first()        if not dataset:            return None        if self.data_source_type == "database":            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None        else:            file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt"            try:                keyword_table_text = storage.load_once(file_key)                if keyword_table_text:                    return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder)                return None            except Exception as e:                logging.exception(f"Failed to load keyword table from file: {file_key}")                return Noneclass Embedding(db.Model):  # type: ignore[name-defined]    __tablename__ = "embeddings"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="embedding_pkey"),        db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),        db.Index("created_at_idx", "created_at"),    )    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))    model_name = db.Column(        db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")    )    hash = db.Column(db.String(64), nullable=False)    embedding = db.Column(db.LargeBinary, nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))    def set_embedding(self, embedding_data: list[float]):        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)    def get_embedding(self) -> list[float]:        return cast(list[float], pickle.loads(self.embedding))class DatasetCollectionBinding(db.Model):  # type: ignore[name-defined]    __tablename__ = "dataset_collection_bindings"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),        db.Index("provider_model_name_idx", "provider_name", "model_name"),    )    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))    provider_name = db.Column(db.String(40), nullable=False)    model_name = db.Column(db.String(255), nullable=False)    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)    collection_name = db.Column(db.String(64), nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())class TidbAuthBinding(db.Model):  # type: ignore[name-defined]    __tablename__ = "tidb_auth_bindings"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),        db.Index("tidb_auth_bindings_active_idx", "active"),        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),        db.Index("tidb_auth_bindings_status_idx", "status"),    )    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))    tenant_id = db.Column(StringUUID, nullable=True)    cluster_id = db.Column(db.String(255), nullable=False)    cluster_name = db.Column(db.String(255), nullable=False)    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))    status = db.Column(db.String(255), nullable=False, server_default=db.text("CREATING"))    account = db.Column(db.String(255), nullable=False)    password = db.Column(db.String(255), nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())class Whitelist(db.Model):  # type: ignore[name-defined]    __tablename__ = "whitelists"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="whitelists_pkey"),        db.Index("whitelists_tenant_idx", "tenant_id"),    )    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))    tenant_id = db.Column(StringUUID, nullable=True)    category = db.Column(db.String(255), nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())class DatasetPermission(db.Model):  # type: ignore[name-defined]    __tablename__ = "dataset_permissions"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),        db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),        db.Index("idx_dataset_permissions_account_id", "account_id"),        db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),    )    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)    dataset_id = db.Column(StringUUID, nullable=False)    account_id = db.Column(StringUUID, nullable=False)    tenant_id = db.Column(StringUUID, nullable=False)    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())class ExternalKnowledgeApis(db.Model):  # type: ignore[name-defined]    __tablename__ = "external_knowledge_apis"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),        db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),        db.Index("external_knowledge_apis_name_idx", "name"),    )    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))    name = db.Column(db.String(255), nullable=False)    description = db.Column(db.String(255), nullable=False)    tenant_id = db.Column(StringUUID, nullable=False)    settings = db.Column(db.Text, nullable=True)    created_by = db.Column(StringUUID, nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    updated_by = db.Column(StringUUID, nullable=True)    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    def to_dict(self):        return {            "id": self.id,            "tenant_id": self.tenant_id,            "name": self.name,            "description": self.description,            "settings": self.settings_dict,            "dataset_bindings": self.dataset_bindings,            "created_by": self.created_by,            "created_at": self.created_at.isoformat(),        }    @property    def settings_dict(self):        try:            return json.loads(self.settings) if self.settings else None        except JSONDecodeError:            return None    @property    def dataset_bindings(self):        external_knowledge_bindings = (            db.session.query(ExternalKnowledgeBindings)            .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)            .all()        )        dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]        datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()        dataset_bindings = []        for dataset in datasets:            dataset_bindings.append({"id": dataset.id, "name": dataset.name})        return dataset_bindingsclass ExternalKnowledgeBindings(db.Model):  # type: ignore[name-defined]    __tablename__ = "external_knowledge_bindings"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),        db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),        db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),        db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),        db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),    )    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))    tenant_id = db.Column(StringUUID, nullable=False)    external_knowledge_api_id = db.Column(StringUUID, nullable=False)    dataset_id = db.Column(StringUUID, nullable=False)    external_knowledge_id = db.Column(db.Text, nullable=False)    created_by = db.Column(StringUUID, nullable=False)    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())    updated_by = db.Column(StringUUID, nullable=True)    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())class DatasetAutoDisableLog(db.Model):  # type: ignore[name-defined]    __tablename__ = "dataset_auto_disable_logs"    __table_args__ = (        db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),        db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),        db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),        db.Index("dataset_auto_disable_log_created_atx", "created_at"),    )    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))    tenant_id = db.Column(StringUUID, nullable=False)    dataset_id = db.Column(StringUUID, nullable=False)    document_id = db.Column(StringUUID, nullable=False)    notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
 |