Browse Source

feat: implement persistence

Yeuoly 11 months ago
parent
commit
873549528e

+ 0 - 21
internal/core/persistence/cache.go

@@ -1,21 +0,0 @@
-package persistence
-
-type Persistence struct {
-	storage PersistenceStorage
-}
-
-func (c *Persistence) Save(tenant_id string, key string, data []byte) error {
-	return nil
-}
-
-func (c *Persistence) Load(tenant_id string, key string) ([]byte, error) {
-	return nil, nil
-}
-
-func (c *Persistence) Delete(tenant_id string, key string) error {
-	return nil
-}
-
-func (c *Persistence) Scan(tenant_id string, prefix string, cursor int64) ([]string, error) {
-	return nil, nil
-}

+ 1 - 1
internal/core/persistence/init.go

@@ -22,7 +22,7 @@ func InitPersistence(config *app.Config) *Persistence {
 		}
 	} else if config.PersistenceStorageType == "local" {
 		return &Persistence{
-			storage: NewLocalWrapper(),
+			storage: NewLocalWrapper(config.PersistenceStorageLocalPath),
 		}
 	} else {
 		log.Panic("Invalid persistence storage type: %s", config.PersistenceStorageType)

+ 25 - 9
internal/core/persistence/local.go

@@ -1,19 +1,35 @@
 package persistence
 
-type LocalWrapper struct{}
+import (
+	"os"
+	"path"
+)
 
-func NewLocalWrapper() *LocalWrapper {
-	return &LocalWrapper{}
+type LocalWrapper struct {
+	path string
 }
 
-func (l *LocalWrapper) Save(tenant_id string, key string, data []byte) error {
-	return nil
+func NewLocalWrapper(path string) *LocalWrapper {
+	return &LocalWrapper{
+		path: path,
+	}
 }
 
-func (l *LocalWrapper) Load(tenant_id string, key string) ([]byte, error) {
-	return nil, nil
+func (l *LocalWrapper) getFilePath(tenant_id string, plugin_checksum string, key string) string {
+	return path.Join(l.path, tenant_id, plugin_checksum, key)
 }
 
-func (l *LocalWrapper) Delete(tenant_id string, key string) error {
-	return nil
+func (l *LocalWrapper) Save(tenant_id string, plugin_checksum string, key string, data []byte) error {
+	file_path := l.getFilePath(tenant_id, plugin_checksum, key)
+	return os.WriteFile(file_path, data, 0644)
+}
+
+func (l *LocalWrapper) Load(tenant_id string, plugin_checksum string, key string) ([]byte, error) {
+	file_path := l.getFilePath(tenant_id, plugin_checksum, key)
+	return os.ReadFile(file_path)
+}
+
+func (l *LocalWrapper) Delete(tenant_id string, plugin_checksum string, key string) error {
+	file_path := l.getFilePath(tenant_id, plugin_checksum, key)
+	return os.Remove(file_path)
 }

+ 49 - 0
internal/core/persistence/persistence.go

@@ -0,0 +1,49 @@
+package persistence
+
+import (
+	"encoding/hex"
+	"fmt"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+)
+
+type Persistence struct {
+	storage PersistenceStorage
+}
+
+const (
+	CACHE_KEY_PREFIX = "persistence:cache"
+)
+
+func (c *Persistence) getCacheKey(tenant_id string, plugin_checksum string) string {
+	return fmt.Sprintf("%s:%s:%s", CACHE_KEY_PREFIX, tenant_id, plugin_checksum)
+}
+
+func (c *Persistence) Save(tenant_id string, plugin_checksum string, key string, data []byte) error {
+	// add to cache
+	h := hex.EncodeToString(data)
+	return cache.SetMapOneField(c.getCacheKey(tenant_id, plugin_checksum), key, h)
+}
+
+func (c *Persistence) Load(tenant_id string, plugin_checksum string, key string) ([]byte, error) {
+	// check if the key exists in cache
+	h, err := cache.GetMapFieldString(c.getCacheKey(tenant_id, plugin_checksum), key)
+	if err != nil && err != cache.ErrNotFound {
+		return nil, err
+	}
+	if err == nil {
+		return hex.DecodeString(h)
+	}
+
+	// load from storage
+	return c.storage.Load(tenant_id, plugin_checksum, key)
+}
+
+func (c *Persistence) Delete(tenant_id string, plugin_checksum string, key string) error {
+	// delete from cache and storage
+	err := cache.DelMapField(c.getCacheKey(tenant_id, plugin_checksum), key)
+	if err != nil {
+		return err
+	}
+	return c.storage.Delete(tenant_id, plugin_checksum, key)
+}

+ 3 - 3
internal/core/persistence/s3.go

@@ -46,14 +46,14 @@ func NewS3Wrapper(region string, access_key string, secret_key string, bucket st
 	}, nil
 }
 
-func (s *S3Wrapper) Save(tenant_id string, key string, data []byte) error {
+func (s *S3Wrapper) Save(tenant_id string, plugin_checksum string, key string, data []byte) error {
 	return nil
 }
 
-func (s *S3Wrapper) Load(tenant_id string, key string) ([]byte, error) {
+func (s *S3Wrapper) Load(tenant_id string, plugin_checksum string, key string) ([]byte, error) {
 	return nil, nil
 }
 
-func (s *S3Wrapper) Delete(tenant_id string, key string) error {
+func (s *S3Wrapper) Delete(tenant_id string, plugin_checksum string, key string) error {
 	return nil
 }

+ 1 - 0
internal/core/persistence/sync.go

@@ -0,0 +1 @@
+package persistence

+ 3 - 3
internal/core/persistence/type.go

@@ -1,7 +1,7 @@
 package persistence
 
 type PersistenceStorage interface {
-	Save(tenant_id string, key string, data []byte) error
-	Load(tenant_id string, key string) ([]byte, error)
-	Delete(tenant_id string, key string) error
+	Save(tenant_id string, plugin_checksum string, key string, data []byte) error
+	Load(tenant_id string, plugin_checksum string, key string) ([]byte, error)
+	Delete(tenant_id string, plugin_checksum string, key string) error
 }

+ 17 - 0
internal/utils/cache/redis.go

@@ -208,6 +208,23 @@ func GetMapField[T any](key string, field string, context ...redis.Cmdable) (*T,
 	return &result, err
 }
 
+// GetMapFieldString get the string
+func GetMapFieldString(key string, field string, context ...redis.Cmdable) (string, error) {
+	if client == nil {
+		return "", ErrDBNotInit
+	}
+
+	val, err := getCmdable(context...).HGet(ctx, serialKey(key), field).Result()
+	if err != nil {
+		if err == redis.Nil {
+			return "", ErrNotFound
+		}
+		return "", err
+	}
+
+	return val, nil
+}
+
 // DelMapField delete the map field with key
 func DelMapField(key string, field string, context ...redis.Cmdable) error {
 	if client == nil {