
suhh - Add template management (upload, query, and delete templates)

'suhuihui' 3 months ago
Parent
Current commit
81aa9fc2a8

+ 1 - 0
api/constants/__init__.py

@@ -12,6 +12,7 @@ VIDEO_EXTENSIONS.extend([ext.upper() for ext in VIDEO_EXTENSIONS])
 AUDIO_EXTENSIONS = ["mp3", "m4a", "wav", "webm", "amr"]
 AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
 
+TEMPLATE_EXTENSIONS = ["doc", "docx", "csv", "xlsx", "xls", "ppt"]
 
 if dify_config.ETL_TYPE == "Unstructured":
     DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]

+ 3 - 2
api/controllers/console/__init__.py

@@ -22,7 +22,7 @@ from .explore.workflow import (
     InstalledAppWorkflowRunApi,
     InstalledAppWorkflowTaskStopApi,
 )
-from .files import FileApi, FilePreviewApi, FileSupportTypeApi
+from .files import FileApi, FilePreviewApi, FileSupportTypeApi, TemplateFileSupportTypeApi
 from .remote_files import RemoteFileInfoApi, RemoteFileUploadApi
 
 bp = Blueprint("console", __name__, url_prefix="/console/api")
@@ -32,7 +32,7 @@ api = ExternalApi(bp)
 api.add_resource(FileApi, "/files/upload")
 api.add_resource(FilePreviewApi, "/files/<uuid:file_id>/preview")
 api.add_resource(FileSupportTypeApi, "/files/support-type")
-
+api.add_resource(TemplateFileSupportTypeApi, "/files/support-type/template")
 # Remote files
 api.add_resource(RemoteFileInfoApi, "/remote-files/<path:url>")
 api.add_resource(RemoteFileUploadApi, "/remote-files/upload")
@@ -79,6 +79,7 @@ from .datasets import (
     datasets,
     datasets_document,
     datasets_segments,
+    datasets_templates,
     external,
     hit_testing,
     metadata,

+ 187 - 0
api/controllers/console/datasets/datasets_templates.py

@@ -0,0 +1,187 @@
+from typing import Literal
+
+from flask import request, send_file
+from flask_login import current_user  # type: ignore
+from flask_restful import Resource, fields, marshal, marshal_with  # type: ignore
+from werkzeug.exceptions import Forbidden, NotFound
+
+import services
+from controllers.common.errors import FilenameNotExistsError
+from controllers.console import api
+from controllers.console.app.error import (
+    ProviderModelCurrentlyNotSupportError,
+    ProviderNotInitializeError,
+    ProviderQuotaExceededError,
+)
+from controllers.console.error import (
+    FileTooLargeError,
+    NoFileUploadedError,
+    TooManyFilesError,
+    UnsupportedFileTypeError,
+)
+from controllers.console.wraps import (
+    account_initialization_required,
+    cloud_edition_billing_resource_check,
+    setup_required,
+)
+from core.errors.error import (
+    ModelCurrentlyNotSupportError,
+    ProviderTokenNotInitError,
+    QuotaExceededError,
+)
+from fields.template_fields import (
+    template_fields,
+)
+from libs.login import login_required
+from models import Template
+from services.dataset_service import DatasetService, TemplateService
+from services.file_service import FileService
+
+
+class DatasetTemplateListApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, dataset_id):
+        dataset_id = str(dataset_id)
+        page = request.args.get("page", default=1, type=int)
+        limit = request.args.get("limit", default=20, type=int)
+        search = request.args.get("keyword", default=None, type=str)
+        dataset = DatasetService.get_dataset(dataset_id)
+
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        query = Template.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
+
+        if search:
+            search = f"%{search}%"
+            query = query.filter(Template.name.like(search))
+
+        paginated_templates = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
+        templates = paginated_templates.items
+        data = marshal(templates, template_fields)
+        response = {
+            "data": data,
+            "has_more": len(templates) == limit,
+            "limit": limit,
+            "total": paginated_templates.total,
+            "page": page,
+        }
+
+        return response
+
+    templates_and_batch_fields = {"templates": fields.List(fields.Nested(template_fields)), "batch": fields.String}
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @marshal_with(templates_and_batch_fields)
+    @cloud_edition_billing_resource_check("vector_space")
+    def post(self, dataset_id):
+        if not current_user.is_admin:
+            raise Forbidden()
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if not dataset:
+            raise NotFound("Dataset not found.")
+        if not current_user.is_dataset_editor:
+            raise Forbidden()
+        try:
+            DatasetService.check_dataset_permission(dataset, current_user)
+        except services.errors.account.NoPermissionError as e:
+            raise Forbidden(str(e))
+
+        if "file" not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+
+        file = request.files["file"]
+        source_str = request.form.get("source")
+        source: Literal["datasets"] | None = "datasets" if source_str == "datasets" else None
+
+        if not file.filename:
+            raise FilenameNotExistsError()
+
+        if source == "datasets" and not current_user.is_dataset_editor:
+            raise Forbidden()
+
+        if source not in ("datasets", None):
+            source = None
+
+        try:
+            upload_file = FileService.upload_file(
+                filename=file.filename,
+                content=file.read(),
+                mimetype=file.mimetype,
+                user=current_user,
+                source=source,
+            )
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        try:
+            templates, batch = TemplateService.save_template_with_dataset_id(upload_file, dataset, current_user)
+        except ProviderTokenNotInitError as ex:
+            raise ProviderNotInitializeError(ex.description)
+        except QuotaExceededError:
+            raise ProviderQuotaExceededError()
+        except ModelCurrentlyNotSupportError:
+            raise ProviderModelCurrentlyNotSupportError()
+
+        return {"templates": templates, "batch": batch}
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, dataset_id):
+        if not current_user.is_admin:
+            raise Forbidden()
+        dataset_id = str(dataset_id)
+        dataset = DatasetService.get_dataset(dataset_id)
+        if dataset is None:
+            raise NotFound("Dataset not found.")
+        DatasetService.check_dataset_model_setting(dataset)
+
+        template_ids = request.args.getlist("template_id")
+        TemplateService.delete_templates(dataset, template_ids)
+        return {"result": "success"}, 204
+
+class DatasetTemplateApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self, template_id):
+        # The role of the current user in the ta table must be admin, owner, or editor
+        if not current_user.is_editor:
+            raise Forbidden()
+        template_id = str(template_id)
+        template = Template.query.filter_by(id=template_id).first()
+        if template is None:
+            raise NotFound("Template not found.")
+
+        # as_attachment=True serves the file as a download attachment
+        return send_file(template.file_url, as_attachment=True)
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def delete(self, template_id):
+        if not current_user.is_admin:
+            raise Forbidden()
+        template_id = str(template_id)
+        template = TemplateService.get_templates(template_id)
+        if template is None:
+            raise NotFound("Dataset not found.")
+        TemplateService.delete_template(template)
+        return {"result": "success"}, 204
+
+api.add_resource(DatasetTemplateListApi, "/datasets/<uuid:dataset_id>/templates")
+api.add_resource(DatasetTemplateApi, "/datasets/template/<uuid:template_id>")
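A rough usage sketch (not part of the commit) of the endpoints registered above, using the requests library; the base URL, ids, and bearer token below are assumptions about a local console deployment:

import requests

BASE_URL = "http://localhost:5001/console/api"          # hypothetical local console API
HEADERS = {"Authorization": "Bearer <console-token>"}    # hypothetical auth token
dataset_id = "11111111-1111-1111-1111-111111111111"      # placeholder dataset id

# Upload a template file to a dataset
with open("contract_template.docx", "rb") as f:
    resp = requests.post(
        f"{BASE_URL}/datasets/{dataset_id}/templates",
        headers=HEADERS,
        files={"file": f},
        data={"source": "datasets"},
    )
template_id = resp.json()["templates"][0]["id"]

# List templates in the dataset (paginated, optional keyword search)
requests.get(
    f"{BASE_URL}/datasets/{dataset_id}/templates",
    headers=HEADERS,
    params={"page": 1, "limit": 20, "keyword": "contract"},
)

# Download a single template, then delete it
requests.get(f"{BASE_URL}/datasets/template/{template_id}", headers=HEADERS)
requests.delete(f"{BASE_URL}/datasets/template/{template_id}", headers=HEADERS)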

+ 4 - 0
api/controllers/console/datasets/error.py

@@ -66,6 +66,10 @@ class DocumentIndexingError(BaseHTTPException):
     description = "The document is being processed and cannot be edited."
     code = 400
 
+class TemplateIndexingError(BaseHTTPException):
+    error_code = "template_indexing"
+    description = "The template is being processed and cannot be edited."
+    code = 400
 
 class InvalidMetadataError(BaseHTTPException):
     error_code = "invalid_metadata"

+ 8 - 1
api/controllers/console/files.py

@@ -7,7 +7,7 @@ from werkzeug.exceptions import Forbidden
 
 import services
 from configs import dify_config
-from constants import DOCUMENT_EXTENSIONS
+from constants import DOCUMENT_EXTENSIONS, TEMPLATE_EXTENSIONS
 from controllers.common.errors import FilenameNotExistsError
 from controllers.console.wraps import (
     account_initialization_required,
@@ -100,3 +100,10 @@ class FileSupportTypeApi(Resource):
     @account_initialization_required
     def get(self):
         return {"allowed_extensions": DOCUMENT_EXTENSIONS}
+
+class TemplateFileSupportTypeApi(Resource):
+    @setup_required
+    @login_required
+    @account_initialization_required
+    def get(self):
+        return {"allowed_extensions": TEMPLATE_EXTENSIONS}

+ 6 - 0
api/core/rag/index_processor/constant/built_in_field.py

@@ -8,6 +8,12 @@ class BuiltInField(str, Enum):
     last_update_date = "last_update_date"
     source = "source"
 
+class BuiltInTemplateField(str, Enum):
+    template_name = "template_name"
+    uploader = "uploader"
+    upload_date = "upload_date"
+    last_update_date = "last_update_date"
+    source = "source"
 
 class MetadataDataSource(Enum):
     upload_file = "file_upload"

+ 4 - 0
api/events/template_event.py

@@ -0,0 +1,4 @@
+from blinker import signal
+
+# sender: template
+template_was_deleted = signal("template-was-deleted")

+ 48 - 0
api/fields/template_fields.py

@@ -0,0 +1,48 @@
+from flask_restful import fields  # type: ignore
+
+from fields.dataset_fields import dataset_fields
+from libs.helper import TimestampField
+
+template_fields = {
+    "id": fields.String,
+    "tenant_id":fields.String,
+    "dataset_id":fields.String,
+    "position": fields.Integer,
+    "data_source_type": fields.String,
+    "data_source_info": fields.Raw(attribute="data_source_info_dict"),
+    "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
+    "dataset_process_rule_id": fields.String,
+    "name": fields.String,
+    "created_from": fields.String,
+    "created_by": fields.String,
+    "created_at": TimestampField,
+    "tokens": fields.Integer,
+    "error": fields.String,
+    "enabled": fields.Boolean,
+    "disabled_at": TimestampField,
+    "disabled_by": fields.String,
+    "archived": fields.Boolean,
+    "display_status": fields.String,
+    "word_count": fields.Integer,
+    "hit_count": fields.Integer,
+    "doc_form": fields.String
+}
+dataset_and_template_fields = {
+    "dataset": fields.Nested(dataset_fields),
+    "documents": fields.List(fields.Nested(template_fields)),
+    "batch": fields.String,
+}
+template_status_fields = {
+    "id": fields.String,
+    "indexing_status": fields.String,
+    "processing_started_at": TimestampField,
+    "parsing_completed_at": TimestampField,
+    "cleaning_completed_at": TimestampField,
+    "splitting_completed_at": TimestampField,
+    "completed_at": TimestampField,
+    "paused_at": TimestampField,
+    "error": fields.String,
+    "stopped_at": TimestampField,
+    "completed_segments": fields.Integer,
+    "total_segments": fields.Integer,
+}

+ 2 - 0
api/models/__init__.py

@@ -23,6 +23,7 @@ from .dataset import (
     Embedding,
     ExternalKnowledgeApis,
     ExternalKnowledgeBindings,
+    Template,
     TidbAuthBinding,
     Whitelist,
 )
@@ -153,6 +154,7 @@ __all__ = [
     "Site",
     "Tag",
     "TagBinding",
+    "Template",
     "Tenant",
     "TenantAccountJoin",
     "TenantAccountRole",

+ 338 - 0
api/models/dataset.py

@@ -634,7 +634,345 @@ class Document(db.Model):  # type: ignore[name-defined]
             doc_language=data.get("doc_language"),
         )
 
+class Template(db.Model):  # type: ignore[name-defined]
+    __tablename__ = "template"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="template_pkey"),
+        db.Index("template_dataset_id_idx", "dataset_id"),
+        db.Index("template_is_paused_idx", "is_paused"),
+        db.Index("template_tenant_idx", "tenant_id"),
+        db.Index("template_metadata_idx", "doc_metadata", postgresql_using="gin"),
+    )
+
+    # initial fields
+    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+    tenant_id = db.Column(StringUUID, nullable=False)
+    dataset_id = db.Column(StringUUID, nullable=False)
+    position = db.Column(db.Integer, nullable=False)
+    data_source_type = db.Column(db.String(255), nullable=False)
+    data_source_info = db.Column(db.Text, nullable=True)
+    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
+    batch = db.Column(db.String(255), nullable=False)
+    name = db.Column(db.String(255), nullable=False)
+    created_from = db.Column(db.String(255), nullable=False)
+    created_by = db.Column(StringUUID, nullable=False)
+    created_api_request_id = db.Column(StringUUID, nullable=True)
+    created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+
+    # start processing
+    processing_started_at = db.Column(db.DateTime, nullable=True)
+
+    # parsing
+    file_id = db.Column(db.Text, nullable=True)
+    file_url = db.Column(db.Text, nullable=True)
+    word_count = db.Column(db.Integer, nullable=True)
+    parsing_completed_at = db.Column(db.DateTime, nullable=True)
 
+    # cleaning
+    cleaning_completed_at = db.Column(db.DateTime, nullable=True)
+
+    # split
+    splitting_completed_at = db.Column(db.DateTime, nullable=True)
+
+    # indexing
+    tokens = db.Column(db.Integer, nullable=True)
+    indexing_latency = db.Column(db.Float, nullable=True)
+    completed_at = db.Column(db.DateTime, nullable=True)
+
+    # pause
+    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
+    paused_by = db.Column(StringUUID, nullable=True)
+    paused_at = db.Column(db.DateTime, nullable=True)
+
+    # error
+    error = db.Column(db.Text, nullable=True)
+    stopped_at = db.Column(db.DateTime, nullable=True)
+
+    # basic fields
+    indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
+    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
+    disabled_at = db.Column(db.DateTime, nullable=True)
+    disabled_by = db.Column(StringUUID, nullable=True)
+    archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
+    archived_reason = db.Column(db.String(255), nullable=True)
+    archived_by = db.Column(StringUUID, nullable=True)
+    archived_at = db.Column(db.DateTime, nullable=True)
+    updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+    doc_type = db.Column(db.String(40), nullable=True)
+    doc_metadata = db.Column(JSONB, nullable=True)
+    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
+    doc_language = db.Column(db.String(255), nullable=True)
+
+    DATA_SOURCES = ["upload_file"]
+
+    @property
+    def display_status(self):
+        status = None
+        if self.indexing_status == "waiting":
+            status = "queuing"
+        elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused:
+            status = "paused"
+        elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}:
+            status = "indexing"
+        elif self.indexing_status == "error":
+            status = "error"
+        elif self.indexing_status == "completed" and not self.archived and self.enabled:
+            status = "available"
+        elif self.indexing_status == "completed" and not self.archived and not self.enabled:
+            status = "disabled"
+        elif self.indexing_status == "completed" and self.archived:
+            status = "archived"
+        return status
+
+    @property
+    def data_source_info_dict(self):
+        if self.data_source_info:
+            try:
+                data_source_info_dict = json.loads(self.data_source_info)
+            except JSONDecodeError:
+                data_source_info_dict = {}
+
+            return data_source_info_dict
+        return None
+
+    @property
+    def data_source_detail_dict(self):
+        if self.data_source_info:
+            if self.data_source_type == "upload_file":
+                data_source_info_dict = json.loads(self.data_source_info)
+                file_detail = (
+                    db.session.query(UploadFile)
+                    .filter(UploadFile.id == data_source_info_dict["upload_file_id"])
+                    .one_or_none()
+                )
+                if file_detail:
+                    return {
+                        "upload_file": {
+                            "id": file_detail.id,
+                            "name": file_detail.name,
+                            "size": file_detail.size,
+                            "extension": file_detail.extension,
+                            "mime_type": file_detail.mime_type,
+                            "created_by": file_detail.created_by,
+                            "created_at": file_detail.created_at.timestamp(),
+                        }
+                    }
+            elif self.data_source_type in {"notion_import", "website_crawl"}:
+                return json.loads(self.data_source_info)
+        return {}
+
+    @property
+    def average_segment_length(self):
+        if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0:
+            return self.word_count // self.segment_count
+        return 0
+
+    @property
+    def dataset_process_rule(self):
+        if self.dataset_process_rule_id:
+            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
+        return None
+
+    @property
+    def dataset(self):
+        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()
+
+    @property
+    def segment_count(self):
+        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()
+
+    @property
+    def hit_count(self):
+        return (
+            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
+            .filter(DocumentSegment.document_id == self.id)
+            .scalar()
+        )
+
+    @property
+    def uploader(self):
+        user = db.session.query(Account).filter(Account.id == self.created_by).first()
+        return user.name if user else None
+
+    @property
+    def upload_date(self):
+        return self.created_at
+
+    @property
+    def last_update_date(self):
+        return self.updated_at
+
+    @property
+    def doc_metadata_details(self):
+        if self.doc_metadata:
+            document_metadatas = (
+                db.session.query(DatasetMetadata)
+                .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
+                .filter(
+                    DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
+                )
+                .all()
+            )
+            metadata_list = []
+            for metadata in document_metadatas:
+                metadata_dict = {
+                    "id": metadata.id,
+                    "name": metadata.name,
+                    "type": metadata.type,
+                    "value": self.doc_metadata.get(metadata.name),
+                }
+                metadata_list.append(metadata_dict)
+            # deal built-in fields
+            metadata_list.extend(self.get_built_in_fields())
+
+            return metadata_list
+        return None
+
+    @property
+    def process_rule_dict(self):
+        if self.dataset_process_rule_id:
+            return self.dataset_process_rule.to_dict()
+        return None
+
+    def get_built_in_fields(self):
+        built_in_fields = []
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.document_name,
+                "type": "string",
+                "value": self.name,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.uploader,
+                "type": "string",
+                "value": self.uploader,
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.upload_date,
+                "type": "time",
+                "value": self.created_at.timestamp(),
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.last_update_date,
+                "type": "time",
+                "value": self.updated_at.timestamp(),
+            }
+        )
+        built_in_fields.append(
+            {
+                "id": "built-in",
+                "name": BuiltInField.source,
+                "type": "string",
+                "value": MetadataDataSource[self.data_source_type].value,
+            }
+        )
+        return built_in_fields
+
+    def to_dict(self):
+        return {
+            "id": self.id,
+            "tenant_id": self.tenant_id,
+            "dataset_id": self.dataset_id,
+            "position": self.position,
+            "data_source_type": self.data_source_type,
+            "data_source_info": self.data_source_info,
+            "dataset_process_rule_id": self.dataset_process_rule_id,
+            "batch": self.batch,
+            "name": self.name,
+            "created_from": self.created_from,
+            "created_by": self.created_by,
+            "created_api_request_id": self.created_api_request_id,
+            "created_at": self.created_at,
+            "processing_started_at": self.processing_started_at,
+            "file_id": self.file_id,
+            "word_count": self.word_count,
+            "parsing_completed_at": self.parsing_completed_at,
+            "cleaning_completed_at": self.cleaning_completed_at,
+            "splitting_completed_at": self.splitting_completed_at,
+            "tokens": self.tokens,
+            "indexing_latency": self.indexing_latency,
+            "completed_at": self.completed_at,
+            "is_paused": self.is_paused,
+            "paused_by": self.paused_by,
+            "paused_at": self.paused_at,
+            "error": self.error,
+            "stopped_at": self.stopped_at,
+            "indexing_status": self.indexing_status,
+            "enabled": self.enabled,
+            "disabled_at": self.disabled_at,
+            "disabled_by": self.disabled_by,
+            "archived": self.archived,
+            "archived_reason": self.archived_reason,
+            "archived_by": self.archived_by,
+            "archived_at": self.archived_at,
+            "updated_at": self.updated_at,
+            "doc_type": self.doc_type,
+            "doc_metadata": self.doc_metadata,
+            "doc_form": self.doc_form,
+            "doc_language": self.doc_language,
+            "display_status": self.display_status,
+            "data_source_info_dict": self.data_source_info_dict,
+            "average_segment_length": self.average_segment_length,
+            "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
+            "dataset": self.dataset.to_dict() if self.dataset else None,
+            "segment_count": self.segment_count,
+            "hit_count": self.hit_count,
+        }
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(
+            id=data.get("id"),
+            tenant_id=data.get("tenant_id"),
+            dataset_id=data.get("dataset_id"),
+            position=data.get("position"),
+            data_source_type=data.get("data_source_type"),
+            data_source_info=data.get("data_source_info"),
+            dataset_process_rule_id=data.get("dataset_process_rule_id"),
+            batch=data.get("batch"),
+            name=data.get("name"),
+            created_from=data.get("created_from"),
+            created_by=data.get("created_by"),
+            created_api_request_id=data.get("created_api_request_id"),
+            created_at=data.get("created_at"),
+            processing_started_at=data.get("processing_started_at"),
+            file_id=data.get("file_id"),
+            word_count=data.get("word_count"),
+            parsing_completed_at=data.get("parsing_completed_at"),
+            cleaning_completed_at=data.get("cleaning_completed_at"),
+            splitting_completed_at=data.get("splitting_completed_at"),
+            tokens=data.get("tokens"),
+            indexing_latency=data.get("indexing_latency"),
+            completed_at=data.get("completed_at"),
+            is_paused=data.get("is_paused"),
+            paused_by=data.get("paused_by"),
+            paused_at=data.get("paused_at"),
+            error=data.get("error"),
+            stopped_at=data.get("stopped_at"),
+            indexing_status=data.get("indexing_status"),
+            enabled=data.get("enabled"),
+            disabled_at=data.get("disabled_at"),
+            disabled_by=data.get("disabled_by"),
+            archived=data.get("archived"),
+            archived_reason=data.get("archived_reason"),
+            archived_by=data.get("archived_by"),
+            archived_at=data.get("archived_at"),
+            updated_at=data.get("updated_at"),
+            doc_type=data.get("doc_type"),
+            doc_metadata=data.get("doc_metadata"),
+            doc_form=data.get("doc_form"),
+            doc_language=data.get("doc_language"),
+        )
 class DocumentSegment(db.Model):  # type: ignore[name-defined]
     __tablename__ = "document_segments"
     __table_args__ = (

+ 1 - 1
api/poetry.lock

@@ -7961,7 +7961,7 @@ files = [
 ]
 
 [package.extras]
-check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
+#check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""]
 core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"]
 cover = ["pytest-cov"]
 doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"]

+ 141 - 0
api/services/dataset_service.py

@@ -23,6 +23,7 @@ from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from events.dataset_event import dataset_was_deleted
 from events.document_event import document_was_deleted
+from events.template_event import template_was_deleted
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from libs import helper
@@ -40,6 +41,7 @@ from models.dataset import (
     Document,
     DocumentSegment,
     ExternalKnowledgeBindings,
+    Template,
 )
 from models.model import UploadFile
 from models.source import DataSourceOauthBinding
@@ -60,6 +62,7 @@ from services.feature_service import FeatureModel, FeatureService
 from services.tag_service import TagService
 from services.vector_service import VectorService
 from tasks.batch_clean_document_task import batch_clean_document_task
+from tasks.batch_clean_template_task import batch_clean_template_task
 from tasks.clean_notion_document_task import clean_notion_document_task
 from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
 from tasks.delete_segment_from_index_task import delete_segment_from_index_task
@@ -518,8 +521,138 @@ class DatasetService:
             "document_ids": [],
             "count": 0,
         }
+class TemplateService:
+    DEFAULT_RULES: dict[str, Any] = {
+        "mode": "custom",
+        "rules": {
+            "pre_processing_rules": [
+                {"id": "remove_extra_spaces", "enabled": True},
+                {"id": "remove_urls_emails", "enabled": False},
+            ],
+            "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
+        },
+        "limits": {
+            "indexing_max_segmentation_tokens_length": dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH,
+        },
+    }
+
+    # Batch-delete templates
+    @staticmethod
+    def delete_templates(dataset: Dataset, template_ids: list[str]):
+        templates = db.session.query(Template).filter(Template.id.in_(template_ids)).all()
+        file_ids = [
+            template.data_source_info_dict["upload_file_id"]
+            for template in templates
+            if template.data_source_type == "upload_file"
+        ]
+        batch_clean_template_task.delay(template_ids, dataset.id, dataset.doc_form, file_ids)
+        for template in templates:
+            db.session.delete(template)
+        db.session.commit()
+
+    @staticmethod
+    def delete_template(template):
+        file_id = None
+        if template.data_source_type == "upload_file":
+            if template.data_source_info:
+                data_source_info = template.data_source_info_dict
+                if data_source_info and "upload_file_id" in data_source_info:
+                    file_id = data_source_info["upload_file_id"]
+        template_was_deleted.send(template.id, file_id=file_id)
+        db.session.delete(template)
+        db.session.commit()
 
 
+    @staticmethod
+    def save_template_with_dataset_id(
+        upload_file: UploadFile,
+        dataset: Dataset,
+        account: Account | Any,
+        dataset_process_rule: Optional[DatasetProcessRule] = None,
+        created_from: str = "web",
+    ):
+        if not upload_file:
+            raise FileNotExistsError()
+
+        batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999))
+        lock_name = "add_template_lock_dataset_id_{}".format(dataset.id)
+        with redis_client.lock(lock_name, timeout=600):
+            position = TemplateService.get_templates_position(dataset.id)
+
+        file_name = upload_file.name
+        file_id = upload_file.id
+        data_source_info = {"upload_file_id": upload_file.id}
+        template = TemplateService.build_template(
+            dataset,
+            # dataset_process_rule.id,  # type: ignore
+            # data_source_type,
+            data_source_info,
+            created_from,
+            position,
+            account,
+            file_id,
+            file_name,
+            batch,
+        )
+        db.session.add(template)
+        db.session.flush()
+        db.session.commit()
+
+        templates = [template]
+        return templates, batch
+
+    @staticmethod
+    def get_templates_position(dataset_id):
+        template = Template.query.filter_by(dataset_id=dataset_id).order_by(Template.position.desc()).first()
+        if template:
+            return template.position + 1
+        else:
+            return 1
+
+    @staticmethod
+    def build_template(
+            dataset: Dataset,
+            #process_rule_id: str,
+            #data_source_type: 'upload_file',
+            data_source_info: dict,
+            created_from: str,
+            position: int,
+            account: Account,
+            file_id: str,
+            file_name: str,
+            batch: str,
+    ):
+        template = Template(
+            tenant_id=dataset.tenant_id,
+            dataset_id=dataset.id,
+            position=position,
+            data_source_type="upload_file",
+            data_source_info=json.dumps(data_source_info),
+            #dataset_process_rule_id=process_rule_id,
+            batch=batch,
+            name=file_name,
+            created_from=created_from,
+            created_by=account.id,
+            file_id=file_id
+        )
+        return template
+
+    @staticmethod
+    def get_templates(template_id) -> Optional[Template]:
+        if template_id:
+            template: Optional[Template] = Template.query.filter_by(id=template_id).first()
+            return template
+        return None
+
 class DocumentService:
     DEFAULT_RULES: dict[str, Any] = {
         "mode": "custom",
@@ -2329,3 +2462,11 @@ class DatasetPermissionService:
         except Exception as e:
             db.session.rollback()
             raise e
+    @classmethod
+    def save_template_with_dataset_id(
+        cls,
+        dataset: Dataset,
+        knowledge_config: KnowledgeConfig,
+        account: Account | Any,
+        dataset_process_rule: Optional[DatasetProcessRule] = None,
+        created_from: str = "web",
+    ):
+        # placeholder; not implemented yet
+        return 1

+ 52 - 0
api/tasks/batch_clean_template_task.py

@@ -0,0 +1,52 @@
+import logging
+import time
+
+import click
+from celery import shared_task  # type: ignore
+
+from extensions.ext_database import db
+from extensions.ext_storage import storage
+from models.dataset import Dataset
+from models.model import UploadFile
+
+
+@shared_task(queue="dataset")
+def batch_clean_template_task(template_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
+    """
+    Clean template when template deleted.
+    :param template_ids: template ids
+    :param dataset_id: dataset id
+    :param doc_form: doc_form
+    :param file_ids: file ids
+
+    Usage: batch_clean_template_task.delay(template_ids, dataset_id, doc_form, file_ids)
+    """
+    logging.info(click.style("Start batch clean templates when templates deleted", fg="green"))
+    start_at = time.perf_counter()
+
+    try:
+        dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+
+        if not dataset:
+            raise Exception("template has no dataset")
+
+
+        if file_ids:
+            files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all()
+            for file in files:
+                try:
+                    storage.delete(file.key)
+                except Exception:
+                    logging.exception("Delete file failed when template deleted, file_id: {}".format(file.id))
+                db.session.delete(file)
+            db.session.commit()
+
+        end_at = time.perf_counter()
+        logging.info(
+            click.style(
+                "Cleaned templates when templates deleted latency: {}".format(end_at - start_at),
+                fg="green",
+            )
+        )
+    except Exception:
+        logging.exception("Cleaned templates when templates deleted failed")