Procházet zdrojové kódy

Feat/new saas billing (#14996)

Jyong před 1 měsícem
rodič
revize
9b2a9260ef

+ 9 - 1
api/controllers/console/datasets/datasets.py

@@ -10,7 +10,12 @@ from controllers.console import api
 from controllers.console.apikey import api_key_fields, api_key_list
 from controllers.console.app.error import ProviderNotInitializeError
 from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
-from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
+from controllers.console.wraps import (
+    account_initialization_required,
+    cloud_edition_billing_rate_limit_check,
+    enterprise_license_required,
+    setup_required,
+)
 from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
 from core.indexing_runner import IndexingRunner
 from core.model_runtime.entities.model_entities import ModelType
@@ -96,6 +101,7 @@ class DatasetListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
         parser = reqparse.RequestParser()
         parser.add_argument(
@@ -210,6 +216,7 @@ class DatasetApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id):
         dataset_id_str = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id_str)
@@ -313,6 +320,7 @@ class DatasetApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id):
         dataset_id_str = str(dataset_id)
 

+ 10 - 0
api/controllers/console/datasets/datasets_document.py

@@ -26,6 +26,7 @@ from controllers.console.datasets.error import (
 )
 from controllers.console.wraps import (
     account_initialization_required,
+    cloud_edition_billing_rate_limit_check,
     cloud_edition_billing_resource_check,
     setup_required,
 )
@@ -242,6 +243,7 @@ class DatasetDocumentListApi(Resource):
     @account_initialization_required
     @marshal_with(documents_and_batch_fields)
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id):
         dataset_id = str(dataset_id)
 
@@ -297,6 +299,7 @@ class DatasetDocumentListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -320,6 +323,7 @@ class DatasetInitApi(Resource):
     @account_initialization_required
     @marshal_with(dataset_and_document_fields)
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self):
         # The role of the current user in the ta table must be admin, owner, or editor
         if not current_user.is_editor:
@@ -694,6 +698,7 @@ class DocumentProcessingApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
@@ -730,6 +735,7 @@ class DocumentDeleteApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id):
         dataset_id = str(dataset_id)
         document_id = str(document_id)
@@ -798,6 +804,7 @@ class DocumentStatusApi(DocumentResource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -893,6 +900,7 @@ class DocumentPauseApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id):
         """pause document."""
         dataset_id = str(dataset_id)
@@ -925,6 +933,7 @@ class DocumentRecoverApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id):
         """recover document."""
         dataset_id = str(dataset_id)
@@ -954,6 +963,7 @@ class DocumentRetryApi(DocumentResource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id):
         """retry document."""
 

+ 11 - 0
api/controllers/console/datasets/datasets_segments.py

@@ -19,6 +19,7 @@ from controllers.console.datasets.error import (
 from controllers.console.wraps import (
     account_initialization_required,
     cloud_edition_billing_knowledge_limit_check,
+    cloud_edition_billing_rate_limit_check,
     cloud_edition_billing_resource_check,
     setup_required,
 )
@@ -106,6 +107,7 @@ class DatasetDocumentSegmentListApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -137,6 +139,7 @@ class DatasetDocumentSegmentApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, action):
         dataset_id = str(dataset_id)
         dataset = DatasetService.get_dataset(dataset_id)
@@ -191,6 +194,7 @@ class DatasetDocumentSegmentAddApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -240,6 +244,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -299,6 +304,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -336,6 +342,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id, document_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -402,6 +409,7 @@ class ChildChunkAddApi(Resource):
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
     @cloud_edition_billing_knowledge_limit_check("add_segment")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -499,6 +507,7 @@ class ChildChunkAddApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, segment_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -542,6 +551,7 @@ class ChildChunkUpdateApi(Resource):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
         # check dataset
         dataset_id = str(dataset_id)
@@ -586,6 +596,7 @@ class ChildChunkUpdateApi(Resource):
     @login_required
     @account_initialization_required
     @cloud_edition_billing_resource_check("vector_space")
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
         # check dataset
         dataset_id = str(dataset_id)

+ 6 - 1
api/controllers/console/datasets/hit_testing.py

@@ -2,7 +2,11 @@ from flask_restful import Resource  # type: ignore
 
 from controllers.console import api
 from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
-from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.console.wraps import (
+    account_initialization_required,
+    cloud_edition_billing_rate_limit_check,
+    setup_required,
+)
 from libs.login import login_required
 
 
@@ -10,6 +14,7 @@ class HitTestingApi(Resource, DatasetsHitTestingBase):
     @setup_required
     @login_required
     @account_initialization_required
+    @cloud_edition_billing_rate_limit_check("knowledge")
     def post(self, dataset_id):
         dataset_id_str = str(dataset_id)
 

+ 41 - 1
api/controllers/console/wraps.py

@@ -1,5 +1,6 @@
 import json
 import os
+import time
 from functools import wraps
 
 from flask import abort, request
@@ -8,6 +9,8 @@ from flask_login import current_user  # type: ignore
 from configs import dify_config
 from controllers.console.workspace.error import AccountNotInitializedError
 from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.dataset import RateLimitLog
 from models.model import DifySetup
 from services.feature_service import FeatureService, LicenseStatus
 from services.operation_service import OperationService
@@ -67,7 +70,9 @@ def cloud_edition_billing_resource_check(resource: str):
                 elif resource == "apps" and 0 < apps.limit <= apps.size:
                     abort(403, "The number of apps has reached the limit of your subscription.")
                 elif resource == "vector_space" and 0 < vector_space.limit <= vector_space.size:
-                    abort(403, "The capacity of the vector space has reached the limit of your subscription.")
+                    abort(
+                        403, "The capacity of the knowledge storage space has reached the limit of your subscription."
+                    )
                 elif resource == "documents" and 0 < documents_upload_quota.limit <= documents_upload_quota.size:
                     # The api of file upload is used in the multiple places,
                     # so we need to check the source of the request from datasets
@@ -112,6 +117,41 @@ def cloud_edition_billing_knowledge_limit_check(resource: str):
     return interceptor
 
 
+def cloud_edition_billing_rate_limit_check(resource: str):
+    """Build a console view decorator that enforces the tenant's billing rate limit.
+
+    Only ``resource == "knowledge"`` is enforced; any other value lets the request
+    go straight through to the view. When the limit returned by
+    ``FeatureService.get_knowledge_rate_limit`` is enabled, requests are counted per
+    tenant in a Redis sorted set over a trailing 60-second window. Once the count
+    exceeds the limit, a ``RateLimitLog`` row is committed and the request is
+    aborted with HTTP 403.
+
+    :param resource: name of the rate-limited resource (currently only "knowledge")
+    :return: a decorator to apply to a view function
+    """
+    import uuid  # local import: only needed to build unique sorted-set members
+
+    def interceptor(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            if resource == "knowledge":
+                knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(current_user.current_tenant_id)
+                if knowledge_rate_limit.enabled:
+                    current_time = int(time.time() * 1000)
+                    key = f"rate_limit_{current_user.current_tenant_id}"
+
+                    # The member must be unique per request: with the bare timestamp as
+                    # member, two requests in the same millisecond would overwrite each
+                    # other and be counted once. The score stays the timestamp so the
+                    # window pruning below is unaffected.
+                    redis_client.zadd(key, {f"{current_time}:{uuid.uuid4().hex}": current_time})
+
+                    # Drop every entry older than the 60-second window.
+                    redis_client.zremrangebyscore(key, 0, current_time - 60000)
+
+                    # Let Redis discard the whole set once the tenant has been idle for
+                    # longer than the window, so stale keys do not pile up.
+                    redis_client.expire(key, 120)
+
+                    request_count = redis_client.zcard(key)
+
+                    if request_count > knowledge_rate_limit.limit:
+                        # add ratelimit record
+                        rate_limit_log = RateLimitLog(
+                            tenant_id=current_user.current_tenant_id,
+                            subscription_plan=knowledge_rate_limit.subscription_plan,
+                            operation="knowledge",
+                        )
+                        db.session.add(rate_limit_log)
+                        db.session.commit()
+                        abort(
+                            403, "Sorry, you have reached the knowledge base request rate limit of your subscription."
+                        )
+            return view(*args, **kwargs)
+
+        return decorated
+
+    return interceptor
+
+
 def cloud_utm_record(view):
     @wraps(view)
     def decorated(*args, **kwargs):

+ 40 - 0
api/controllers/service_api/wraps.py

@@ -1,3 +1,4 @@
+import time
 from collections.abc import Callable
 from datetime import UTC, datetime, timedelta
 from enum import Enum
@@ -13,8 +14,10 @@ from sqlalchemy.orm import Session
 from werkzeug.exceptions import Forbidden, Unauthorized
 
 from extensions.ext_database import db
+from extensions.ext_redis import redis_client
 from libs.login import _get_user
 from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
+from models.dataset import RateLimitLog
 from models.model import ApiToken, App, EndUser
 from services.feature_service import FeatureService
 
@@ -139,6 +142,43 @@ def cloud_edition_billing_knowledge_limit_check(resource: str, api_token_type: s
     return interceptor
 
 
+def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str):
+    """Build a service-API view decorator that enforces the tenant's billing rate limit.
+
+    The tenant is taken from the API token resolved by
+    ``validate_and_get_api_token(api_token_type)``. Only ``resource == "knowledge"``
+    is enforced; any other value lets the request go straight through to the view.
+    When the limit returned by ``FeatureService.get_knowledge_rate_limit`` is
+    enabled, requests are counted per tenant in a Redis sorted set over a trailing
+    60-second window. Once the count exceeds the limit, a ``RateLimitLog`` row is
+    committed and ``Forbidden`` is raised.
+
+    :param resource: name of the rate-limited resource (currently only "knowledge")
+    :param api_token_type: token type passed to ``validate_and_get_api_token``
+    :return: a decorator to apply to a view function
+    :raises Forbidden: from the decorated view when the tenant exceeds its limit
+    """
+    import uuid  # local import: only needed to build unique sorted-set members
+
+    def interceptor(view):
+        @wraps(view)
+        def decorated(*args, **kwargs):
+            api_token = validate_and_get_api_token(api_token_type)
+
+            if resource == "knowledge":
+                knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(api_token.tenant_id)
+                if knowledge_rate_limit.enabled:
+                    current_time = int(time.time() * 1000)
+                    key = f"rate_limit_{api_token.tenant_id}"
+
+                    # The member must be unique per request: with the bare timestamp as
+                    # member, two requests in the same millisecond would overwrite each
+                    # other and be counted once. The score stays the timestamp so the
+                    # window pruning below is unaffected.
+                    redis_client.zadd(key, {f"{current_time}:{uuid.uuid4().hex}": current_time})
+
+                    # Drop every entry older than the 60-second window.
+                    redis_client.zremrangebyscore(key, 0, current_time - 60000)
+
+                    # Let Redis discard the whole set once the tenant has been idle for
+                    # longer than the window, so stale keys do not pile up.
+                    redis_client.expire(key, 120)
+
+                    request_count = redis_client.zcard(key)
+
+                    if request_count > knowledge_rate_limit.limit:
+                        # add ratelimit record
+                        rate_limit_log = RateLimitLog(
+                            tenant_id=api_token.tenant_id,
+                            subscription_plan=knowledge_rate_limit.subscription_plan,
+                            operation="knowledge",
+                        )
+                        db.session.add(rate_limit_log)
+                        db.session.commit()
+                        raise Forbidden(
+                            "Sorry, you have reached the knowledge base request rate limit of your subscription."
+                        )
+            return view(*args, **kwargs)
+
+        return decorated
+
+    return interceptor
+
+
 def validate_dataset_token(view=None):
     def decorator(view):
         @wraps(view)

+ 29 - 1
api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py

@@ -1,4 +1,5 @@
 import logging
+import time
 from collections.abc import Mapping, Sequence
 from typing import Any, cast
 
@@ -19,8 +20,10 @@ from core.workflow.entities.node_entities import NodeRunResult
 from core.workflow.nodes.base import BaseNode
 from core.workflow.nodes.enums import NodeType
 from extensions.ext_database import db
-from models.dataset import Dataset, Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset, Document, RateLimitLog
 from models.workflow import WorkflowNodeExecutionStatus
+from services.feature_service import FeatureService
 
 from .entities import KnowledgeRetrievalNodeData
 from .exc import (
@@ -61,6 +64,31 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
             return NodeRunResult(
                 status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required."
             )
+        # check rate limit
+        if self.tenant_id:
+            knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
+            if knowledge_rate_limit.enabled:
+                current_time = int(time.time() * 1000)
+                key = f"rate_limit_{self.tenant_id}"
+                redis_client.zadd(key, {current_time: current_time})
+                redis_client.zremrangebyscore(key, 0, current_time - 60000)
+                request_count = redis_client.zcard(key)
+                if request_count > knowledge_rate_limit.limit:
+                    # add ratelimit record
+                    rate_limit_log = RateLimitLog(
+                        tenant_id=self.tenant_id,
+                        subscription_plan=knowledge_rate_limit.subscription_plan,
+                        operation="knowledge",
+                    )
+                    db.session.add(rate_limit_log)
+                    db.session.commit()
+                    return NodeRunResult(
+                        status=WorkflowNodeExecutionStatus.FAILED,
+                        inputs=variables,
+                        error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
+                        error_type="RateLimitExceeded",
+                    )
+
         # retrieve knowledge
         try:
             results = self._fetch_dataset_retriever(node_data=self.node_data, query=query)

+ 43 - 0
api/migrations/versions/2025_01_14_0617-f051706725cc_add_rate_limit_logs.py

@@ -0,0 +1,43 @@
+"""add_rate_limit_logs
+
+Revision ID: f051706725cc
+Revises: d20049ed0af6
+Create Date: 2025-01-14 06:17:35.536388
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = 'f051706725cc'
+down_revision = 'd20049ed0af6'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    """Create the ``rate_limit_logs`` table and its two non-unique lookup indexes."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.create_table('rate_limit_logs',
+    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+    sa.Column('subscription_plan', sa.String(length=255), nullable=False),
+    sa.Column('operation', sa.String(length=255), nullable=False),
+    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+    sa.PrimaryKeyConstraint('id', name='rate_limit_log_pkey')
+    )
+    # Indexes are added after the table exists; one per column used for lookups.
+    with op.batch_alter_table('rate_limit_logs', schema=None) as batch_op:
+        batch_op.create_index('rate_limit_log_operation_idx', ['operation'], unique=False)
+        batch_op.create_index('rate_limit_log_tenant_idx', ['tenant_id'], unique=False)
+    # ### end Alembic commands ###
+
+
+def downgrade():
+    """Drop the ``rate_limit_logs`` indexes, then the table itself (reverse of ``upgrade``)."""
+    # ### commands auto generated by Alembic - please adjust! ###
+    with op.batch_alter_table('rate_limit_logs', schema=None) as batch_op:
+        batch_op.drop_index('rate_limit_log_tenant_idx')
+        batch_op.drop_index('rate_limit_log_operation_idx')
+
+    op.drop_table('rate_limit_logs')
+    # ### end Alembic commands ###

+ 15 - 0
api/models/dataset.py

@@ -930,3 +930,18 @@ class DatasetAutoDisableLog(db.Model):  # type: ignore[name-defined]
     document_id = db.Column(StringUUID, nullable=False)
     notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
     created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+
+
+class RateLimitLog(db.Model):  # type: ignore[name-defined]
+    """Record of a request that was rejected by a billing rate-limit check.
+
+    Rows are written by the rate-limit checks right before they reject a request;
+    nothing in this model itself enforces a limit.
+    """
+
+    __tablename__ = "rate_limit_logs"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
+        db.Index("rate_limit_log_tenant_idx", "tenant_id"),
+        db.Index("rate_limit_log_operation_idx", "operation"),
+    )
+
+    # Primary key, generated by the database via uuid_generate_v4().
+    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    # Tenant whose request exceeded the limit.
+    tenant_id = db.Column(StringUUID, nullable=False)
+    # Subscription plan reported for the tenant when the limit was hit.
+    subscription_plan = db.Column(db.String(255), nullable=False)
+    # Name of the rate-limited operation (the visible writers use "knowledge").
+    operation = db.Column(db.String(255), nullable=False)
+    # Set by the database to the insertion time, at second precision.
+    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

+ 11 - 0
api/services/billing_service.py

@@ -23,6 +23,17 @@ class BillingService:
         return billing_info
 
     @classmethod
+    def get_knowledge_rate_limit(cls, tenant_id: str):
+        """Fetch the knowledge-base rate limit of a tenant from the billing API.
+
+        :param tenant_id: id of the tenant to look up
+        :return: dict with ``limit`` (defaults to 10 when the response omits it) and
+            ``subscription_plan`` (defaults to "sandbox" when the response omits it)
+        """
+        response = cls._send_request(
+            "GET",
+            "/subscription/knowledge-rate-limit",
+            params={"tenant_id": tenant_id},
+        )
+
+        # Fall back to the defaults for any field the billing response leaves out.
+        limit = response.get("limit", 10)
+        subscription_plan = response.get("subscription_plan", "sandbox")
+        return {"limit": limit, "subscription_plan": subscription_plan}
+
+    @classmethod
     def get_subscription(cls, plan: str, interval: str, prefilled_email: str = "", tenant_id: str = ""):
         params = {"plan": plan, "interval": interval, "prefilled_email": prefilled_email, "tenant_id": tenant_id}
         return cls._send_request("GET", "/subscription/payment-link", params=params)

+ 20 - 0
api/services/feature_service.py

@@ -41,6 +41,7 @@ class FeatureModel(BaseModel):
     members: LimitationModel = LimitationModel(size=0, limit=1)
     apps: LimitationModel = LimitationModel(size=0, limit=10)
     vector_space: LimitationModel = LimitationModel(size=0, limit=5)
+    knowledge_rate_limit: int = 10
     annotation_quota_limit: LimitationModel = LimitationModel(size=0, limit=10)
     documents_upload_quota: LimitationModel = LimitationModel(size=0, limit=50)
     docs_processing: str = "standard"
@@ -52,6 +53,12 @@ class FeatureModel(BaseModel):
     model_config = ConfigDict(protected_namespaces=())
 
 
+# Knowledge-base rate-limit settings for one tenant, as returned by
+# FeatureService.get_knowledge_rate_limit.
+class KnowledgeRateLimitModel(BaseModel):
+    # False unless billing is enabled and a tenant id was supplied; callers skip
+    # the rate-limit check entirely while this is False.
+    enabled: bool = False
+    # Request-count threshold that the rate-limit checks compare against.
+    limit: int = 10
+    # Plan name reported by billing; copied into RateLimitLog rows.
+    subscription_plan: str = ""
+
+
 class SystemFeatureModel(BaseModel):
     sso_enforced_for_signin: bool = False
     sso_enforced_for_signin_protocol: str = ""
@@ -82,6 +89,16 @@ class FeatureService:
         return features
 
     @classmethod
+    def get_knowledge_rate_limit(cls, tenant_id: str):
+        """Return the knowledge-base rate-limit settings for a tenant.
+
+        When billing is disabled or ``tenant_id`` is empty, the model's defaults are
+        returned unchanged (``enabled`` stays False). Otherwise the limit and plan
+        are filled in from ``BillingService.get_knowledge_rate_limit``.
+
+        :param tenant_id: id of the tenant to look up
+        :return: a populated ``KnowledgeRateLimitModel``
+        """
+        rate_limit = KnowledgeRateLimitModel()
+        if not (dify_config.BILLING_ENABLED and tenant_id):
+            return rate_limit
+
+        rate_limit.enabled = True
+        billing_limits = BillingService.get_knowledge_rate_limit(tenant_id)
+        rate_limit.limit = billing_limits.get("limit", 10)
+        rate_limit.subscription_plan = billing_limits.get("subscription_plan", "sandbox")
+        return rate_limit
+
+    @classmethod
     def get_system_features(cls) -> SystemFeatureModel:
         system_features = SystemFeatureModel()
 
@@ -149,6 +166,9 @@ class FeatureService:
         if "model_load_balancing_enabled" in billing_info:
             features.model_load_balancing_enabled = billing_info["model_load_balancing_enabled"]
 
+        if "knowledge_rate_limit" in billing_info:
+            features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"]
+
     @classmethod
     def _fulfill_params_from_enterprise(cls, features):
         enterprise_info = EnterpriseService.get_info()