Explorar el Código

feat: add redis scan hmap

Yeuoly hace 1 año
padre
commit
b2c4e5cd5f

+ 3 - 3
internal/cluster/init.go

@@ -9,7 +9,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 )
 
-type PluginLifeTime struct {
+type pluginLifeTime struct {
 	lifetime          entities.PluginRuntimeTimeLifeInterface
 	last_scheduled_at time.Time
 }
@@ -19,7 +19,7 @@ type Cluster struct {
 
 	port uint16
 
-	plugins     map[string]*PluginLifeTime
+	plugins     map[string]*pluginLifeTime
 	plugin_lock sync.Mutex
 
 	stop_chan chan bool
@@ -29,7 +29,7 @@ type Cluster struct {
 func NewCluster(config *app.Config) *Cluster {
 	return &Cluster{
 		port:      uint16(config.ServerPort),
-		plugins:   make(map[string]*PluginLifeTime),
+		plugins:   make(map[string]*pluginLifeTime),
 		stop_chan: make(chan bool),
 		stopped:   new(int32),
 	}

+ 15 - 3
internal/cluster/state.go

@@ -31,7 +31,7 @@ func (c *Cluster) RegisterPlugin(lifetime entities.PluginRuntimeTimeLifeInterfac
 
 	c.plugin_lock.Lock()
 	if !lifetime.Stopped() {
-		c.plugins[identity] = &PluginLifeTime{
+		c.plugins[identity] = &pluginLifeTime{
 			lifetime: lifetime,
 		}
 	} else {
@@ -46,10 +46,22 @@ func (c *Cluster) RegisterPlugin(lifetime entities.PluginRuntimeTimeLifeInterfac
 
 // SchedulePlugin schedules a plugin to the cluster
 func (c *Cluster) schedulePlugins() error {
-	return nil
+	c.plugin_lock.Lock()
+	defer c.plugin_lock.Unlock()
+
+	for i, v := range c.plugins {
+		if v.lifetime.Stopped() {
+			delete(c.plugins, i)
+			continue
+		}
+
+		if err := c.doPluginStateUpdate(v); err != nil {
+
+		}
+	}
 }
 
 // doPluginUpdate updates the plugin state and schedule the plugin
-func (c *Cluster) doPluginStateUpdate(lifetime entities.PluginRuntimeTimeLifeInterface) error {
+func (c *Cluster) doPluginStateUpdate(lifetime *pluginLifeTime) error {
 	return nil
 }

+ 1 - 0
internal/core/plugin_manager/watcher.go

@@ -107,6 +107,7 @@ func loadNewPlugins(root_path string) <-chan entities.PluginRuntime {
 						RelativePath: path.Join(root_path, plugin.Name()),
 						ActiveAt:     nil,
 						Verified:     err == nil,
+						Identity:     configuration.Identity(),
 					},
 				}
 			}

+ 59 - 0
internal/utils/cache/redis.go

@@ -3,6 +3,7 @@ package cache
 import (
 	"context"
 	"errors"
+	"fmt"
 	"strings"
 	"time"
 
@@ -230,6 +231,64 @@ func GetMap[V any](key string, context ...redis.Cmdable) (map[string]V, error) {
 	return result, nil
 }
 
+func ScanMap[V any](key string, prefix string, context ...redis.Cmdable) (map[string]V, error) {
+	if client == nil {
+		return nil, ErrDBNotInit
+	}
+
+	result := make(map[string]V)
+
+	ScanMapAsync[V](key, prefix, func(m map[string]V) error {
+		for k, v := range m {
+			result[k] = v
+		}
+
+		return nil
+	})
+
+	return result, nil
+}
+
+func ScanMapAsync[V any](key string, prefix string, fn func(map[string]V) error, context ...redis.Cmdable) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	cursor := uint64(0)
+
+	for {
+		kvs, new_cursor, err := getCmdable(context...).
+			HScan(ctx, serialKey(key), cursor, fmt.Sprintf("%s*", prefix), 32).
+			Result()
+
+		if err != nil {
+			return err
+		}
+
+		result := make(map[string]V)
+		for i := 0; i < len(kvs); i += 2 {
+			value, err := parser.UnmarshalJson[V](kvs[i+1])
+			if err != nil {
+				continue
+			}
+
+			result[kvs[i]] = value
+		}
+
+		if err := fn(result); err != nil {
+			return err
+		}
+
+		if new_cursor == 0 {
+			break
+		}
+
+		cursor = new_cursor
+	}
+
+	return nil
+}
+
 func SetNX[T any](key string, value T, expire time.Duration, context ...redis.Cmdable) (bool, error) {
 	if client == nil {
 		return false, ErrDBNotInit

+ 80 - 0
internal/utils/cache/redis_test.go

@@ -113,3 +113,83 @@ func TestRedisTransaction(t *testing.T) {
 		return
 	}
 }
+
+func TestRedisScanMap(t *testing.T) {
+	// get redis connection
+	if err := getRedisConnection(t); err != nil {
+		t.Errorf("get redis connection failed: %v", err)
+		return
+	}
+	defer Close()
+
+	type s struct {
+		Field string `json:"field"`
+	}
+
+	err := SetMapOneField(strings.Join([]string{TEST_PREFIX, "map"}, ":"), "key1", s{Field: "value1"})
+	if err != nil {
+		t.Errorf("set map failed: %v", err)
+		return
+	}
+	defer Del(strings.Join([]string{TEST_PREFIX, "map"}, ":"))
+	err = SetMapOneField(strings.Join([]string{TEST_PREFIX, "map"}, ":"), "key2", s{Field: "value2"})
+	if err != nil {
+		t.Errorf("set map failed: %v", err)
+		return
+	}
+	err = SetMapOneField(strings.Join([]string{TEST_PREFIX, "map"}, ":"), "key3", s{Field: "value3"})
+	if err != nil {
+		t.Errorf("set map failed: %v", err)
+		return
+	}
+	err = SetMapOneField(strings.Join([]string{TEST_PREFIX, "map"}, ":"), "4", s{Field: "value4"})
+	if err != nil {
+		t.Errorf("set map failed: %v", err)
+		return
+	}
+
+	data, err := ScanMap[s](strings.Join([]string{TEST_PREFIX, "map"}, ":"), "key")
+	if err != nil {
+		t.Errorf("scan map failed: %v", err)
+		return
+	}
+
+	if len(data) != 3 {
+		t.Errorf("scan map should return 3")
+		return
+	}
+
+	if data["key1"].Field != "value1" {
+		t.Errorf("scan map should return value1")
+		return
+	}
+
+	if data["key2"].Field != "value2" {
+		t.Errorf("scan map should return value2")
+		return
+	}
+
+	if data["key3"].Field != "value3" {
+		t.Errorf("scan map should return value3")
+		return
+	}
+
+	err = ScanMapAsync[s](strings.Join([]string{TEST_PREFIX, "map"}, ":"), "4", func(m map[string]s) error {
+		if len(m) != 1 {
+			t.Errorf("scan map async should return 1")
+			return errors.New("scan map async should return 1")
+		}
+
+		if m["4"].Field != "value4" {
+			t.Errorf("scan map async should return value4")
+			return errors.New("scan map async should return value4")
+		}
+
+		return nil
+	})
+
+	if err != nil {
+		t.Errorf("scan map async failed: %v", err)
+		return
+	}
+}