Browse Source

feat: connection keys

Yeuoly 1 year ago
parent
commit
57f56125b0

+ 7 - 0
.github/workflows/tests.yml

@@ -17,6 +17,13 @@ jobs:
   test:
     runs-on: ubuntu-latest
 
+    services:
+      redis:
+        image: redis
+        ports:
+          - 6379:6379
+        options: --requirepass difyai123456
+
     steps:
       - uses: actions/checkout@v2
 

+ 121 - 0
internal/core/plugin_manager/remote_manager/connection_key.go

@@ -0,0 +1,121 @@
+package remote_manager
+
+import (
+	"strings"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+	"github.com/redis/go-redis/v9"
+)
+
+/*
+ * When connect to dify plugin daemon server, we need identify who is connecting.
+ * Therefore, we need to a key-value pair to connect a random string to a tenant.
+ *
+ * $random_key => $tenant_id, $user_id
+ * $tenant_id => $random_id
+ *
+ * It's a double mapping for each key, therefore a transaction is needed.
+ * */
+
+type ConnectionInfo struct {
+	TenantId string `json:"tenant_id" validate:"required"`
+}
+
+func (c ConnectionInfo) MarshalBinary() ([]byte, error) {
+	return parser.MarshalJsonBytes(c), nil
+}
+
+type Key struct {
+	Key string `json:"key" validate:"required"`
+}
+
+func (k Key) MarshalBinary() ([]byte, error) {
+	return parser.MarshalJsonBytes(k), nil
+}
+
+const (
+	CONNECTION_KEY_MANAGER_KEY2ID_PREFIX = "remote:key:manager:key2id"
+	CONNECTION_KEY_MANAGER_ID2KEY_PREFIX = "remote:key:manager:id2key"
+	CONNECTION_KEY_LOCK                  = "connection_lock"
+	CONNECTION_KEY_EXPIRE_TIME           = time.Minute * 15
+)
+
+// returns a random string, create it if not exists
+func GetConnectionKey(info ConnectionInfo) (string, error) {
+	var key *Key
+	var err error
+
+	key, err = cache.Get[Key](
+		strings.Join([]string{CONNECTION_KEY_MANAGER_ID2KEY_PREFIX, info.TenantId}, ":"),
+	)
+
+	if err == cache.ErrNotFound {
+		err := cache.Transaction(func(p redis.Pipeliner) error {
+			k := uuid.New().String()
+			_, err = cache.SetNX(
+				strings.Join([]string{CONNECTION_KEY_MANAGER_ID2KEY_PREFIX, info.TenantId}, ":"),
+				Key{Key: k},
+				CONNECTION_KEY_EXPIRE_TIME,
+				p,
+			)
+			if err != nil {
+				return err
+			}
+
+			_, err = cache.SetNX(
+				strings.Join([]string{CONNECTION_KEY_MANAGER_KEY2ID_PREFIX, k}, ":"),
+				info,
+				CONNECTION_KEY_EXPIRE_TIME,
+				p,
+			)
+			if err != nil {
+				return err
+			}
+
+			key = &Key{Key: k}
+
+			return nil
+		})
+
+		if err != nil {
+			return "", err
+		}
+	}
+
+	if err != nil {
+		return "", err
+	}
+
+	return key.Key, nil
+}
+
+// get connection info by key
+func GetConnectionInfo(key string) (*ConnectionInfo, error) {
+	info, err := cache.Get[ConnectionInfo](
+		strings.Join([]string{CONNECTION_KEY_MANAGER_KEY2ID_PREFIX, key}, ":"),
+	)
+
+	if err != nil {
+		return nil, err
+	}
+
+	return info, nil
+}
+
+// clear connection key
+func ClearConnectionKey(tenant_id string) error {
+	key, err := cache.Get[Key](
+		strings.Join([]string{CONNECTION_KEY_MANAGER_ID2KEY_PREFIX, tenant_id}, ":"),
+	)
+
+	if err != nil {
+		return err
+	}
+
+	cache.Del(strings.Join([]string{CONNECTION_KEY_MANAGER_KEY2ID_PREFIX, key.Key}, ":"))
+	cache.Del(strings.Join([]string{CONNECTION_KEY_MANAGER_ID2KEY_PREFIX, tenant_id}, ":"))
+	return nil
+}

+ 61 - 0
internal/core/plugin_manager/remote_manager/connection_key_test.go

@@ -0,0 +1,61 @@
+package remote_manager
+
+import (
+	"testing"
+
+	"github.com/google/uuid"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+)
+
+func TestConnectionKey(t *testing.T) {
+	err := cache.InitRedisClient("0.0.0.0:6379", "difyai123456")
+	if err != nil {
+		t.Errorf("init redis client failed: %v", err)
+		return
+	}
+	defer cache.Close()
+
+	// test connection key
+	key, err := GetConnectionKey(ConnectionInfo{
+		TenantId: "abc",
+	})
+
+	if err != nil {
+		t.Errorf("get connection key failed: %v", err)
+		return
+	}
+
+	defer ClearConnectionKey("abc")
+
+	_, err = uuid.Parse(key)
+	if err != nil {
+		t.Errorf("connection key is not a valid uuid: %v", err)
+		return
+	}
+
+	// test connection key with the same tenant id
+	key2, err := GetConnectionKey(ConnectionInfo{
+		TenantId: "abc",
+	})
+
+	if err != nil {
+		t.Errorf("get connection key failed: %v", err)
+		return
+	}
+
+	if key != key2 {
+		t.Errorf("connection key is not the same: %s, %s", key, key2)
+		return
+	}
+
+	connection_info, err := GetConnectionInfo(key)
+	if err != nil {
+		t.Errorf("get connection info failed: %v", err)
+		return
+	}
+
+	if connection_info.TenantId != "abc" {
+		t.Errorf("connection info is not the same: %v", connection_info)
+		return
+	}
+}

+ 0 - 4
internal/types/entities/plugin_entities/plugin_declaration_test.go

@@ -64,10 +64,6 @@ func TestPluginDeclarationFullTest(t *testing.T) {
 		t.Errorf("author not equal")
 		return
 	}
-	if new_declaration.CreatedAt.GoString() != declaration.CreatedAt.GoString() {
-		t.Errorf("created_at not equal")
-		return
-	}
 	if new_declaration.Resource.Memory != declaration.Resource.Memory {
 		t.Errorf("memory not equal")
 		return

+ 114 - 32
internal/utils/cache/redis.go

@@ -15,6 +15,7 @@ var (
 	ctx    = context.Background()
 
 	ErrDBNotInit = errors.New("redis client not init")
+	ErrNotFound  = errors.New("key not found")
 )
 
 func InitRedisClient(addr, password string) error {
@@ -30,6 +31,22 @@ func InitRedisClient(addr, password string) error {
 	return nil
 }
 
+func Close() error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	return client.Close()
+}
+
+func getCmdable(context ...redis.Cmdable) redis.Cmdable {
+	if len(context) > 0 {
+		return context[0]
+	}
+
+	return client
+}
+
 func serialKey(keys ...string) string {
 	return strings.Join(append(
 		[]string{"plugin_daemon"},
@@ -37,91 +54,129 @@ func serialKey(keys ...string) string {
 	), ":")
 }
 
-func Store(key string, value any, time time.Duration) error {
-	if client == nil {
-		return ErrDBNotInit
-	}
-
-	return client.Set(ctx, serialKey(key), value, time).Err()
+func Store(key string, value any, time time.Duration, context ...redis.Cmdable) error {
+	return getCmdable(context...).Set(ctx, serialKey(key), value, time).Err()
 }
 
-func Get[T any](key string) (*T, error) {
+func Get[T any](key string, context ...redis.Cmdable) (*T, error) {
 	if client == nil {
 		return nil, ErrDBNotInit
 	}
 
-	val, err := client.Get(ctx, serialKey(key)).Result()
+	val, err := getCmdable(context...).Get(ctx, serialKey(key)).Result()
 	if err != nil {
+		if err == redis.Nil {
+			return nil, ErrNotFound
+		}
 		return nil, err
 	}
 
+	if val == "" {
+		return nil, ErrNotFound
+	}
+
 	result, err := parser.UnmarshalJson[T](val)
 	return &result, err
 }
 
-func Del(key string) error {
+func GetString(key string, context ...redis.Cmdable) (string, error) {
+	if client == nil {
+		return "", ErrDBNotInit
+	}
+
+	v, err := getCmdable(context...).Get(ctx, serialKey(key)).Result()
+	if err != nil {
+		if err == redis.Nil {
+			return "", ErrNotFound
+		}
+	}
+
+	return v, err
+}
+
+func Del(key string, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
 	}
 
-	return client.Del(ctx, serialKey(key)).Err()
+	_, err := getCmdable(context...).Del(ctx, serialKey(key)).Result()
+	if err != nil {
+		if err == redis.Nil {
+			return ErrNotFound
+		}
+
+		return err
+	}
+
+	return nil
 }
 
-func Exist(key string) (int64, error) {
+func Exist(key string, context ...redis.Cmdable) (int64, error) {
 	if client == nil {
 		return 0, ErrDBNotInit
 	}
 
-	return client.Exists(ctx, serialKey(key)).Result()
+	return getCmdable(context...).Exists(ctx, serialKey(key)).Result()
 }
 
-func Increase(key string) (int64, error) {
+func Increase(key string, context ...redis.Cmdable) (int64, error) {
 	if client == nil {
 		return 0, ErrDBNotInit
 	}
 
-	return client.Incr(ctx, serialKey(key)).Result()
+	num, err := getCmdable(context...).Incr(ctx, serialKey(key)).Result()
+	if err != nil {
+		if err == redis.Nil {
+			return 0, ErrNotFound
+		}
+		return 0, err
+	}
+
+	return num, nil
 }
 
-func Decrease(key string) (int64, error) {
+func Decrease(key string, context ...redis.Cmdable) (int64, error) {
 	if client == nil {
 		return 0, ErrDBNotInit
 	}
 
-	return client.Decr(ctx, serialKey(key)).Result()
+	return getCmdable(context...).Decr(ctx, serialKey(key)).Result()
 }
 
-func SetExpire(key string, time time.Duration) error {
+func SetExpire(key string, time time.Duration, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
 	}
 
-	return client.Expire(ctx, serialKey(key), time).Err()
+	return getCmdable(context...).Expire(ctx, serialKey(key), time).Err()
 }
 
-func SetMapField(key string, v map[string]any) error {
+func SetMapField(key string, v map[string]any, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
 	}
 
-	return client.HMSet(ctx, serialKey(key), v).Err()
+	return getCmdable(context...).HMSet(ctx, serialKey(key), v).Err()
 }
 
-func SetMapOneField(key string, field string, value any) error {
+func SetMapOneField(key string, field string, value any, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
 	}
 
-	return client.HSet(ctx, serialKey(key), field, value).Err()
+	return getCmdable(context...).HSet(ctx, serialKey(key), field, value).Err()
 }
 
-func GetMapField[T any](key string, field string) (*T, error) {
+func GetMapField[T any](key string, field string, context ...redis.Cmdable) (*T, error) {
 	if client == nil {
 		return nil, ErrDBNotInit
 	}
 
-	val, err := client.HGet(ctx, serialKey(key), field).Result()
+	val, err := getCmdable(context...).HGet(ctx, serialKey(key), field).Result()
 	if err != nil {
+		if err == redis.Nil {
+			return nil, ErrNotFound
+		}
 		return nil, err
 	}
 
@@ -129,21 +184,24 @@ func GetMapField[T any](key string, field string) (*T, error) {
 	return &result, err
 }
 
-func DelMapField(key string, field string) error {
+func DelMapField(key string, field string, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
 	}
 
-	return client.HDel(ctx, serialKey(key), field).Err()
+	return getCmdable(context...).HDel(ctx, serialKey(key), field).Err()
 }
 
-func GetMap[V any](key string) (map[string]V, error) {
+func GetMap[V any](key string, context ...redis.Cmdable) (map[string]V, error) {
 	if client == nil {
 		return nil, ErrDBNotInit
 	}
 
-	val, err := client.HGetAll(ctx, serialKey(key)).Result()
+	val, err := getCmdable(context...).HGetAll(ctx, serialKey(key)).Result()
 	if err != nil {
+		if err == redis.Nil {
+			return nil, ErrNotFound
+		}
 		return nil, err
 	}
 
@@ -160,13 +218,21 @@ func GetMap[V any](key string) (map[string]V, error) {
 	return result, nil
 }
 
+func SetNX[T any](key string, value T, expire time.Duration, context ...redis.Cmdable) (bool, error) {
+	if client == nil {
+		return false, ErrDBNotInit
+	}
+
+	return getCmdable(context...).SetNX(ctx, serialKey(key), value, expire).Result()
+}
+
 var (
 	ErrLockTimeout = errors.New("lock timeout")
 )
 
 // Lock key, expire time takes responsibility for expiration time
 // try_lock_timeout takes responsibility for the timeout of trying to lock
-func Lock(key string, expire time.Duration, try_lock_timeout time.Duration) error {
+func Lock(key string, expire time.Duration, try_lock_timeout time.Duration, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
 	}
@@ -177,7 +243,7 @@ func Lock(key string, expire time.Duration, try_lock_timeout time.Duration) erro
 	defer ticker.Stop()
 
 	for range ticker.C {
-		if _, err := client.SetNX(ctx, serialKey(key), "1", expire).Result(); err == nil {
+		if _, err := getCmdable(context...).SetNX(ctx, serialKey(key), "1", expire).Result(); err == nil {
 			return nil
 		}
 
@@ -190,10 +256,26 @@ func Lock(key string, expire time.Duration, try_lock_timeout time.Duration) erro
 	return nil
 }
 
-func Unlock(key string) error {
+func Unlock(key string, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
 	}
 
-	return client.Del(ctx, serialKey(key)).Err()
+	return getCmdable(context...).Del(ctx, serialKey(key)).Err()
+}
+
+func Transaction(fn func(redis.Pipeliner) error) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	return client.Watch(ctx, func(tx *redis.Tx) error {
+		_, err := tx.TxPipelined(ctx, func(p redis.Pipeliner) error {
+			return fn(p)
+		})
+		if err == redis.Nil {
+			return nil
+		}
+		return err
+	})
 }

+ 115 - 0
internal/utils/cache/redis_test.go

@@ -0,0 +1,115 @@
+package cache
+
+import (
+	"errors"
+	"strings"
+	"testing"
+	"time"
+
+	"github.com/redis/go-redis/v9"
+)
+
+const (
+	TEST_PREFIX = "test"
+)
+
+func getRedisConnection(t *testing.T) error {
+	return InitRedisClient("0.0.0.0:6379", "difyai123456")
+}
+
+func TestRedisConnection(t *testing.T) {
+	// get redis connection
+	if err := getRedisConnection(t); err != nil {
+		t.Errorf("get redis connection failed: %v", err)
+		return
+	}
+
+	// close
+	if err := Close(); err != nil {
+		t.Errorf("close redis client failed: %v", err)
+		return
+	}
+}
+
+func TestRedisTransaction(t *testing.T) {
+	// get redis connection
+	if err := getRedisConnection(t); err != nil {
+		t.Errorf("get redis connection failed: %v", err)
+		return
+	}
+	defer Close()
+
+	// test transaction
+	err := Transaction(func(p redis.Pipeliner) error {
+		// set key
+		if err := Store(
+			strings.Join([]string{TEST_PREFIX, "key"}, ":"),
+			"value",
+			time.Second,
+			p,
+		); err != nil {
+			t.Errorf("store key failed: %v", err)
+			return err
+		}
+
+		return errors.New("test transaction error")
+	})
+
+	if err == nil {
+		t.Errorf("transaction should return error")
+		return
+	}
+
+	// get key
+	value, err := GetString(
+		strings.Join([]string{TEST_PREFIX, "key"}, ":"),
+	)
+
+	if err != ErrNotFound {
+		t.Errorf("key should not exist")
+		return
+	}
+
+	if value != "" {
+		t.Errorf("value should be empty")
+		return
+	}
+
+	// test success transaction
+	err = Transaction(func(p redis.Pipeliner) error {
+		// set key
+		if err := Store(
+			strings.Join([]string{TEST_PREFIX, "key"}, ":"),
+			"value",
+			time.Second,
+			p,
+		); err != nil {
+			t.Errorf("store key failed: %v", err)
+			return err
+		}
+
+		return nil
+	})
+
+	if err != nil {
+		t.Errorf("transaction should not return error")
+		return
+	}
+
+	defer Del(strings.Join([]string{TEST_PREFIX, "key"}, ":"))
+
+	// get key
+	value, err = GetString(
+		strings.Join([]string{TEST_PREFIX, "key"}, ":"),
+	)
+
+	if err != nil {
+		t.Errorf("get key failed: %v", err)
+		return
+	}
+
+	if value != "value" {
+		t.Errorf("value should be value")
+		return
+	}
+}