瀏覽代碼

feat: support limitations of persistence storage

Yeuoly 8 月之前
父節點
當前提交
c3ec7d010a

+ 4 - 2
internal/core/persistence/init.go

@@ -22,11 +22,13 @@ func InitPersistence(config *app.Config) {
 		}
 
 		persistence = &Persistence{
-			storage: s3,
+			storage:          s3,
+			max_storage_size: config.PersistenceStorageMaxSize,
 		}
 	} else if config.PersistenceStorageType == "local" {
 		persistence = &Persistence{
-			storage: NewLocalWrapper(config.PersistenceStorageLocalPath),
+			storage:          NewLocalWrapper(config.PersistenceStorageLocalPath),
+			max_storage_size: config.PersistenceStorageMaxSize,
 		}
 	} else {
 		log.Panic("Invalid persistence storage type: %s", config.PersistenceStorageType)

+ 9 - 0
internal/core/persistence/local.go

@@ -44,3 +44,12 @@ func (l *LocalWrapper) Delete(tenant_id string, plugin_checksum string, key stri
 	file_path := l.getFilePath(tenant_id, plugin_checksum, key)
 	return os.Remove(file_path)
 }
+
+func (l *LocalWrapper) StateSize(tenant_id string, plugin_checksum string, key string) (int64, error) {
+	file_path := l.getFilePath(tenant_id, plugin_checksum, key)
+	info, err := os.Stat(file_path)
+	if err != nil {
+		return 0, err
+	}
+	return info.Size(), nil
+}

+ 72 - 2
internal/core/persistence/persistence.go

@@ -5,10 +5,14 @@ import (
 	"fmt"
 	"time"
 
+	"github.com/langgenius/dify-plugin-daemon/internal/db"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
 )
 
 type Persistence struct {
+	max_storage_size int64
+
 	storage PersistenceStorage
 }
 
@@ -20,15 +24,58 @@ func (c *Persistence) getCacheKey(tenant_id string, plugin_id string, key string
 	return fmt.Sprintf("%s:%s:%s:%s", CACHE_KEY_PREFIX, tenant_id, plugin_id, key)
 }
 
-func (c *Persistence) Save(tenant_id string, plugin_id string, key string, data []byte) error {
+func (c *Persistence) Save(tenant_id string, plugin_id string, max_size int64, key string, data []byte) error {
 	if len(key) > 64 {
 		return fmt.Errorf("key length must be less than 64 characters")
 	}
 
+	if max_size == -1 {
+		max_size = c.max_storage_size
+	}
+
 	if err := c.storage.Save(tenant_id, plugin_id, key, data); err != nil {
 		return err
 	}
 
+	allocated_size := int64(len(data))
+
+	storage, err := db.GetOne[models.TenantStorage](
+		db.Equal("tenant_id", tenant_id),
+		db.Equal("plugin_id", plugin_id),
+	)
+	if err != nil {
+		if allocated_size > c.max_storage_size || allocated_size > max_size {
+			return fmt.Errorf("allocated size is greater than max storage size")
+		}
+
+		if err == db.ErrDatabaseNotFound {
+			storage = models.TenantStorage{
+				TenantID: tenant_id,
+				PluginID: plugin_id,
+				Size:     allocated_size,
+			}
+			if err := db.Create(&storage); err != nil {
+				return err
+			}
+		} else {
+			return err
+		}
+	} else {
+		if allocated_size+storage.Size > max_size || allocated_size+storage.Size > c.max_storage_size {
+			return fmt.Errorf("allocated size is greater than max storage size")
+		}
+
+		err = db.Run(
+			db.Model(&models.TenantStorage{}),
+			db.Equal("tenant_id", tenant_id),
+			db.Equal("plugin_id", plugin_id),
+			db.Inc(map[string]int64{"size": allocated_size}),
+		)
+		if err != nil {
+			return err
+		}
+	}
+
 	// delete from cache
 	return cache.Del(c.getCacheKey(tenant_id, plugin_id, key))
 }
@@ -61,5 +108,28 @@ func (c *Persistence) Delete(tenant_id string, plugin_id string, key string) err
 	if err != nil {
 		return err
 	}
-	return c.storage.Delete(tenant_id, plugin_id, key)
+
+	// state size
+	size, err := c.storage.StateSize(tenant_id, plugin_id, key)
+	if err != nil {
+		return nil
+	}
+
+	err = c.storage.Delete(tenant_id, plugin_id, key)
+	if err != nil {
+		return nil
+	}
+
+	// update storage size
+	err = db.Run(
+		db.Model(&models.TenantStorage{}),
+		db.Equal("tenant_id", tenant_id),
+		db.Equal("plugin_id", plugin_id),
+		db.Dec(map[string]int64{"size": size}),
+	)
+	if err != nil {
+		return err
+	}
+
+	return nil
 }

+ 35 - 3
internal/core/persistence/persistence_test.go

@@ -5,6 +5,7 @@ import (
 	"os"
 	"testing"
 
+	"github.com/langgenius/dify-plugin-daemon/internal/db"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/strings"
@@ -17,14 +18,25 @@ func TestPersistenceStoreAndLoad(t *testing.T) {
 	}
 	defer cache.Close()
 
+	db.Init(&app.Config{
+		DBUsername: "postgres",
+		DBPassword: "difyai123456",
+		DBHost:     "localhost",
+		DBPort:     5432,
+		DBDatabase: "dify_plugin_daemon",
+		DBSslMode:  "disable",
+	})
+	defer db.Close()
+
 	InitPersistence(&app.Config{
 		PersistenceStorageType:      "local",
 		PersistenceStorageLocalPath: "./persistence_storage",
+		PersistenceStorageMaxSize:   1024 * 1024 * 1024,
 	})
 
 	key := strings.RandomString(10)
 
-	if err := persistence.Save("tenant_id", "plugin_checksum", key, []byte("data")); err != nil {
+	if err := persistence.Save("tenant_id", "plugin_checksum", -1, key, []byte("data")); err != nil {
 		t.Fatalf("Failed to save data: %v", err)
 	}
 
@@ -64,15 +76,25 @@ func TestPersistenceSaveAndLoadWithLongKey(t *testing.T) {
 		t.Fatalf("Failed to init redis client: %v", err)
 	}
 	defer cache.Close()
+	db.Init(&app.Config{
+		DBUsername: "postgres",
+		DBPassword: "difyai123456",
+		DBHost:     "localhost",
+		DBPort:     5432,
+		DBDatabase: "dify_plugin_daemon",
+		DBSslMode:  "disable",
+	})
+	defer db.Close()
 
 	InitPersistence(&app.Config{
 		PersistenceStorageType:      "local",
 		PersistenceStorageLocalPath: "./persistence_storage",
+		PersistenceStorageMaxSize:   1024 * 1024 * 1024,
 	})
 
 	key := strings.RandomString(65)
 
-	if err := persistence.Save("tenant_id", "plugin_checksum", key, []byte("data")); err == nil {
+	if err := persistence.Save("tenant_id", "plugin_checksum", -1, key, []byte("data")); err == nil {
 		t.Fatalf("Expected error, got nil")
 	}
 }
@@ -83,15 +105,25 @@ func TestPersistenceDelete(t *testing.T) {
 		t.Fatalf("Failed to init redis client: %v", err)
 	}
 	defer cache.Close()
+	db.Init(&app.Config{
+		DBUsername: "postgres",
+		DBPassword: "difyai123456",
+		DBHost:     "localhost",
+		DBPort:     5432,
+		DBDatabase: "dify_plugin_daemon",
+		DBSslMode:  "disable",
+	})
+	defer db.Close()
 
 	InitPersistence(&app.Config{
 		PersistenceStorageType:      "local",
 		PersistenceStorageLocalPath: "./persistence_storage",
+		PersistenceStorageMaxSize:   1024 * 1024 * 1024,
 	})
 
 	key := strings.RandomString(10)
 
-	if err := persistence.Save("tenant_id", "plugin_checksum", key, []byte("data")); err != nil {
+	if err := persistence.Save("tenant_id", "plugin_checksum", -1, key, []byte("data")); err != nil {
 		t.Fatalf("Failed to save data: %v", err)
 	}
 

+ 45 - 2
internal/core/persistence/s3.go

@@ -1,7 +1,10 @@
 package persistence
 
 import (
+	"bytes"
 	"context"
+	"fmt"
+	"io"
 
 	"github.com/aws/aws-sdk-go-v2/aws"
 	"github.com/aws/aws-sdk-go-v2/config"
@@ -47,13 +50,53 @@ func NewS3Wrapper(region string, access_key string, secret_key string, bucket st
 }
 
 func (s *S3Wrapper) Save(tenant_id string, plugin_checksum string, key string, data []byte) error {
+	// save to s3
+	_, err := s.client.PutObject(context.TODO(), &s3.PutObjectInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(key),
+		Body:   bytes.NewReader(data),
+	})
+	if err != nil {
+		return err
+	}
+
 	return nil
 }
 
 func (s *S3Wrapper) Load(tenant_id string, plugin_checksum string, key string) ([]byte, error) {
-	return nil, nil
+	// load from s3
+	resp, err := s.client.GetObject(context.TODO(), &s3.GetObjectInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(key),
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	return io.ReadAll(resp.Body)
 }
 
 func (s *S3Wrapper) Delete(tenant_id string, plugin_checksum string, key string) error {
-	return nil
+	_, err := s.client.DeleteObject(context.TODO(), &s3.DeleteObjectInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(key),
+	})
+	return err
+}
+
+func (s *S3Wrapper) StateSize(tenant_id string, plugin_checksum string, key string) (int64, error) {
+	// get object size
+	resp, err := s.client.HeadObject(context.TODO(), &s3.HeadObjectInput{
+		Bucket: aws.String(s.bucket),
+		Key:    aws.String(key),
+	})
+	if err != nil {
+		return 0, err
+	}
+
+	if resp.ContentLength == nil {
+		return 0, fmt.Errorf("content length not found")
+	}
+
+	return *resp.ContentLength, nil
 }

+ 1 - 0
internal/core/persistence/type.go

@@ -4,4 +4,5 @@ type PersistenceStorage interface {
 	Save(tenant_id string, plugin_checksum string, key string, data []byte) error
 	Load(tenant_id string, plugin_checksum string, key string) ([]byte, error)
 	Delete(tenant_id string, plugin_checksum string, key string) error
+	StateSize(tenant_id string, plugin_checksum string, key string) (int64, error)
 }

+ 26 - 1
internal/core/plugin_daemon/backwards_invocation/task.go

@@ -482,7 +482,32 @@ func executeDifyInvocationStorageTask(
 			return
 		}
 
-		if err := persistence.Save(tenant_id, plugin_id.PluginID(), request.Key, data); err != nil {
+		session := handle.session
+		if session == nil {
+			handle.WriteError(fmt.Errorf("session not found"))
+			return
+		}
+
+		declaration := session.Declaration
+		if declaration == nil {
+			handle.WriteError(fmt.Errorf("declaration not found"))
+			return
+		}
+
+		resource := declaration.Resource.Permission
+		if resource == nil {
+			handle.WriteError(fmt.Errorf("resource not found"))
+			return
+		}
+
+		max_storage_size := int64(-1)
+
+		storage := resource.Storage
+		if storage != nil {
+			max_storage_size = int64(storage.Size)
+		}
+
+		if err := persistence.Save(tenant_id, plugin_id.PluginID(), max_storage_size, request.Key, data); err != nil {
 			handle.WriteError(fmt.Errorf("save data failed: %s", err.Error()))
 			return
 		}

+ 1 - 0
internal/db/init.go

@@ -86,6 +86,7 @@ func autoMigrate() error {
 		models.ToolInstallation{},
 		models.AIModelInstallation{},
 		models.InstallTask{},
+		models.TenantStorage{},
 	)
 }
 

+ 1 - 0
internal/types/app/config.go

@@ -60,6 +60,7 @@ type Config struct {
 	PersistenceStorageS3AccessKey string `envconfig:"PERSISTENCE_STORAGE_S3_ACCESS_KEY"`
 	PersistenceStorageS3SecretKey string `envconfig:"PERSISTENCE_STORAGE_S3_SECRET_KEY"`
 	PersistenceStorageS3Bucket    string `envconfig:"PERSISTENCE_STORAGE_S3_BUCKET"`
+	PersistenceStorageMaxSize     int64  `envconfig:"PERSISTENCE_STORAGE_MAX_SIZE"`
 
 	// force verifying signature for all plugins, not allowing install plugin not signed
 	ForceVerifyingSignature bool `envconfig:"FORCE_VERIFYING_SIGNATURE"`

+ 1 - 0
internal/types/app/default.go

@@ -22,6 +22,7 @@ func (config *Config) SetDefault() {
 	setDefaultString(&config.PluginStoragePath, "./storage/plugin")
 	setDefaultString(&config.PluginMediaCachePath, "./storage/assets")
 	setDefaultString(&config.PersistenceStorageLocalPath, "./storage/persistence")
+	setDefaultInt(&config.PersistenceStorageMaxSize, 100*1024*1024)
 	setDefaultString(&config.ProcessCachingPath, "./storage/subprocesses")
 	setDefaultString(&config.PluginPackageCachePath, "./storage/plugin_packages")
 	setDefaultString(&config.PythonInterpreterPath, "/usr/bin/python3")

+ 8 - 0
internal/types/models/storage.go

@@ -0,0 +1,8 @@
+package models
+
+type TenantStorage struct {
+	Model
+	TenantID string `gorm:"column:tenant_id;type:varchar(255);not null;index"`
+	PluginID string `gorm:"column:plugin_id;type:varchar(255);not null;index"`
+	Size     int64  `gorm:"column:size;type:bigint;not null"`
+}