Explorar el Código

feat: cluster lifetime management

Yeuoly hace 1 año
padre
commit
69c402fcea

+ 14 - 2
internal/cluster/init.go

@@ -5,8 +5,10 @@ import (
 	"sync/atomic"
 	"time"
 
+	"github.com/langgenius/dify-plugin-daemon/internal/cluster/cluster_id"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/mapping"
 )
 
 type pluginLifeTime struct {
@@ -15,21 +17,31 @@ type pluginLifeTime struct {
 }
 
 type Cluster struct {
+	// id is the unique id of the cluster
+	id string
+
+	// i_am_master is the flag to indicate whether the current node is the master node
 	i_am_master bool
 
+	// port is the health check port of the cluster
 	port uint16
 
-	plugins     map[string]*pluginLifeTime
+	// plugins stores all the plugin life time of the cluster
+	plugins     mapping.Map[string, *pluginLifeTime]
 	plugin_lock sync.Mutex
 
+	// nodes stores all the nodes of the cluster
+	nodes mapping.Map[string, node]
+
+	// signals for waiting for the cluster to stop
 	stop_chan chan bool
 	stopped   *int32
 }
 
 func NewCluster(config *app.Config) *Cluster {
 	return &Cluster{
+		id:        cluster_id.GetInstanceID(),
 		port:      uint16(config.ServerPort),
-		plugins:   make(map[string]*pluginLifeTime),
 		stop_chan: make(chan bool),
 		stopped:   new(int32),
 	}

+ 5 - 6
internal/cluster/preemptive.go

@@ -5,7 +5,6 @@ import (
 	"net"
 	"time"
 
-	"github.com/langgenius/dify-plugin-daemon/internal/cluster/cluster_id"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/network"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
@@ -44,7 +43,7 @@ func (c *Cluster) lockMaster() (bool, error) {
 	var final_error error
 
 	for i := 0; i < 3; i++ {
-		if success, err := cache.SetNX(PREEMPTION_LOCK_KEY, cluster_id.GetInstanceID(), MASTER_LOCK_EXPIRED_TIME); err != nil {
+		if success, err := cache.SetNX(PREEMPTION_LOCK_KEY, c.id, MASTER_LOCK_EXPIRED_TIME); err != nil {
 			// try again
 			if final_error == nil {
 				final_error = err
@@ -73,13 +72,13 @@ func (c *Cluster) updateMaster() error {
 
 // update the status of the node
 func (c *Cluster) updateNodeStatus() error {
-	if err := c.LockNodeStatus(cluster_id.GetInstanceID()); err != nil {
+	if err := c.LockNodeStatus(c.id); err != nil {
 		return err
 	}
-	defer c.UnlockNodeStatus(cluster_id.GetInstanceID())
+	defer c.UnlockNodeStatus(c.id)
 
 	// update the status of the node
-	node_status, err := cache.GetMapField[node](CLUSTER_STATUS_HASH_MAP_KEY, cluster_id.GetInstanceID())
+	node_status, err := cache.GetMapField[node](CLUSTER_STATUS_HASH_MAP_KEY, c.id)
 	if err != nil {
 		if err == cache.ErrNotFound {
 			// try to get ips configs
@@ -125,7 +124,7 @@ func (c *Cluster) updateNodeStatus() error {
 	node_status.LastPingAt = time.Now().Unix()
 
 	// update the status of the node
-	if err := cache.SetMapOneField(CLUSTER_STATUS_HASH_MAP_KEY, cluster_id.GetInstanceID(), node_status); err != nil {
+	if err := cache.SetMapOneField(CLUSTER_STATUS_HASH_MAP_KEY, c.id, node_status); err != nil {
 		return err
 	}
 

+ 100 - 3
internal/cluster/state.go

@@ -2,8 +2,10 @@ package cluster
 
 import (
 	"sync/atomic"
+	"time"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 )
 
@@ -22,17 +24,29 @@ func (c *Cluster) RegisterPlugin(lifetime entities.PluginRuntimeTimeLifeInterfac
 		}
 	}
 
+	l := &pluginLifeTime{
+		lifetime: lifetime,
+	}
+
 	lifetime.OnStop(func() {
 		c.plugin_lock.Lock()
-		delete(c.plugins, identity)
+		c.plugins.Delete(identity)
+		// remove plugin state
+		c.doPluginStateUpdate(l)
 		c.plugin_lock.Unlock()
 		close()
 	})
 
 	c.plugin_lock.Lock()
 	if !lifetime.Stopped() {
-		c.plugins[identity] = &pluginLifeTime{
-			lifetime: lifetime,
+		c.plugins.Store(identity, l)
+
+		// do plugin state update immediately
+		err = c.doPluginStateUpdate(l)
+		if err != nil {
+			close()
+			c.plugin_lock.Unlock()
+			return err
 		}
 	} else {
 		close()
@@ -44,12 +58,95 @@ func (c *Cluster) RegisterPlugin(lifetime entities.PluginRuntimeTimeLifeInterfac
 	return nil
 }
 
+const (
+	PLUGIN_STATE_MAP_KEY = "plugin_state"
+)
+
+func (c *Cluster) getPluginStateKey(node_id string, plugin_id string) string {
+	return node_id + ":" + plugin_id
+}
+
+func (c *Cluster) getScanPluginsByNodeKey(node_id string) string {
+	return node_id + ":*"
+}
+
+func (c *Cluster) getScanPluginsByIdKey(plugin_id string) string {
+	return "*:" + plugin_id
+}
+
+func (c *Cluster) FetchPluginAvailableNodes(hashed_plugin_id string) ([]string, error) {
+	states, err := cache.ScanMap[entities.PluginRuntimeState](PLUGIN_STATE_MAP_KEY, c.getScanPluginsByIdKey(hashed_plugin_id))
+	if err != nil {
+		return nil, err
+	}
+
+	nodes := make([]string, 0)
+	for key := range states {
+		// split key into node_id and plugin_id
+		if len(key) < len(hashed_plugin_id)+1 {
+			log.Error("unexpected plugin state key: %s", key)
+			continue
+		}
+		node_id := key[:len(key)-len(hashed_plugin_id)-1]
+		nodes = append(nodes, node_id)
+	}
+
+	return nodes, nil
+}
+
 // SchedulePlugin schedules a plugin to the cluster
+// it will walk through the plugin state map and update all the states
+// as for the plugin has exited, normally, it will be removed automatically
+// but once a plugin is not removed, it will be gc by the master node
 func (c *Cluster) schedulePlugins() error {
+	c.plugins.Range(func(key string, value *pluginLifeTime) bool {
+		// do plugin state update
+		err := c.doPluginStateUpdate(value)
+		if err != nil {
+			log.Error("failed to update plugin state: %s", err.Error())
+		}
+
+		return true
+	})
+
 	return nil
 }
 
 // doPluginUpdate updates the plugin state and schedule the plugin
 func (c *Cluster) doPluginStateUpdate(lifetime *pluginLifeTime) error {
+	state := lifetime.lifetime.RuntimeState()
+	hash_identity, err := lifetime.lifetime.HashedIdentity()
+	if err != nil {
+		return err
+	}
+
+	identity, err := lifetime.lifetime.Identity()
+	if err != nil {
+		return err
+	}
+
+	state_key := c.getPluginStateKey(c.id, hash_identity)
+
+	// check if the plugin has been removed
+	if !c.plugins.Exits(identity) {
+		// remove state
+		err = c.removePluginState(hash_identity)
+		if err != nil {
+			return err
+		}
+	} else {
+		// update plugin state
+		state.ScheduledAt = &[]time.Time{time.Now()}[0]
+		lifetime.lifetime.UpdateState(state)
+		err = cache.SetMapOneField(PLUGIN_STATE_MAP_KEY, state_key, state)
+		if err != nil {
+			return err
+		}
+	}
+
 	return nil
 }
+
+func (c *Cluster) removePluginState(hashed_identity string) error {
+	return cache.DelMapField(PLUGIN_STATE_MAP_KEY, c.getPluginStateKey(c.id, hashed_identity))
+}

+ 4 - 5
internal/cluster/vote.go

@@ -8,7 +8,6 @@ import (
 	"sort"
 	"time"
 
-	"github.com/langgenius/dify-plugin-daemon/internal/cluster/cluster_id"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/http_requests"
 )
@@ -32,7 +31,7 @@ func (c *Cluster) voteIps() error {
 	}
 
 	for node_id, node_status := range nodes {
-		if node_id == cluster_id.GetInstanceID() {
+		if node_id == c.id {
 			continue
 		}
 
@@ -41,7 +40,7 @@ func (c *Cluster) voteIps() error {
 		for _, ip := range node_status.Ips {
 			// skip ips which have already been voted by current node in the last 5 minutes
 			for _, vote := range ip.Votes {
-				if vote.NodeID == cluster_id.GetInstanceID() {
+				if vote.NodeID == c.id {
 					if time.Since(time.Unix(vote.VotedAt, 0)) < time.Minute*5 && !vote.Failed {
 						continue
 					} else if time.Since(time.Unix(vote.VotedAt, 0)) < time.Minute*30 && vote.Failed {
@@ -75,7 +74,7 @@ func (c *Cluster) voteIps() error {
 				// check if the ip has already voted
 				already_voted := false
 				for j, vote := range ip.Votes {
-					if vote.NodeID == cluster_id.GetInstanceID() {
+					if vote.NodeID == c.id {
 						node_status.Ips[i].Votes[j].VotedAt = time.Now().Unix()
 						node_status.Ips[i].Votes[j].Failed = !success
 						already_voted = true
@@ -85,7 +84,7 @@ func (c *Cluster) voteIps() error {
 				// add a new vote
 				if !already_voted {
 					node_status.Ips[i].Votes = append(node_status.Ips[i].Votes, vote{
-						NodeID:  cluster_id.GetInstanceID(),
+						NodeID:  c.id,
 						VotedAt: time.Now().Unix(),
 						Failed:  !success,
 					})

+ 3 - 1
internal/core/plugin_manager/lifetime.go

@@ -70,6 +70,8 @@ func (p *PluginManager) lifetime(config *app.Config, r entities.PluginRuntimeInt
 		time.Sleep(5 * time.Second)
 
 		// add restart times
-		r.RuntimeState().Restarts++
+		state := r.RuntimeState()
+		state.Restarts++
+		r.UpdateState(state)
 	}
 }

+ 118 - 0
internal/db/cache.go

@@ -0,0 +1,118 @@
+package db
+
+import (
+	"fmt"
+	"reflect"
+	"time"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+)
+
+const (
+	CACHE_PREFIX      = "cache"
+	CACHE_EXPIRE_TIME = time.Minute * 5
+)
+
+type KeyValuePair struct {
+	Key string
+	Val any
+}
+
+type GetCachePayload[T any] struct {
+	Getter   func() (*T, error)
+	CacheKey []KeyValuePair
+}
+
+func joinCacheKey(typename string, pairs []KeyValuePair) string {
+	cache_key := CACHE_PREFIX
+	for _, kv := range pairs {
+		cache_key += ":" + kv.Key + ":"
+		// convert value to string
+		cache_key += fmt.Sprintf("%v", kv.Val)
+	}
+	return cache_key
+}
+
+func GetCache[T any](p *GetCachePayload[T]) (*T, error) {
+	var t T
+	typename := reflect.TypeOf(t).String()
+
+	// join cache key
+	cache_key := joinCacheKey(typename, p.CacheKey)
+
+	// get cache
+	val, err := cache.Get[T](cache_key)
+	if err == nil {
+		return val, nil
+	}
+
+	if err == cache.ErrNotFound {
+		// get from getter
+		val, err := p.Getter()
+		if err != nil {
+			return nil, err
+		}
+
+		// set cache
+		err = cache.Store(cache_key, val, CACHE_EXPIRE_TIME)
+		if err != nil {
+			return nil, err
+		}
+
+		return val, nil
+	} else {
+		return nil, err
+	}
+}
+
+type DeleteCachePayload[T any] struct {
+	Delete   func() error
+	CacheKey []KeyValuePair
+}
+
+func DeleteCache[T any](p *DeleteCachePayload[T]) error {
+	var t T
+	typename := reflect.TypeOf(t).String()
+
+	// join cache key
+	cache_key := joinCacheKey(typename, p.CacheKey)
+
+	// delete cache
+	err := cache.Del(cache_key)
+	if err != nil {
+		return err
+	}
+
+	err = p.Delete()
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+type UpdateCachePayload[T any] struct {
+	Update   func() error
+	CacheKey []KeyValuePair
+}
+
+func UpdateCache[T any](p *UpdateCachePayload[T]) error {
+	var t T
+	typename := reflect.TypeOf(t).String()
+
+	// join cache key
+	cache_key := joinCacheKey(typename, p.CacheKey)
+
+	err := p.Update()
+	if err != nil {
+		return err
+	}
+
+	// delete cache
+	err = cache.Del(cache_key)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}

+ 5 - 1
internal/db/init.go

@@ -5,6 +5,7 @@ import (
 	"time"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"gorm.io/driver/postgres"
 	"gorm.io/gorm"
@@ -76,7 +77,10 @@ func initDifyPluginDB(host string, port int, db_name string, user string, pass s
 }
 
 func autoMigrate() error {
-	return DifyPluginDB.AutoMigrate()
+	return DifyPluginDB.AutoMigrate(
+		models.Plugin{},
+		models.PluginInstallation{},
+	)
 }
 
 func Init(config *app.Config) {

+ 35 - 3
internal/types/entities/runtime.go

@@ -26,17 +26,33 @@ type (
 	}
 
 	PluginRuntimeTimeLifeInterface interface {
+		// returns the plugin configuration
 		Configuration() *plugin_entities.PluginDeclaration
+		// unique identity of the plugin
 		Identity() (string, error)
+		// hashed identity of the plugin
+		HashedIdentity() (string, error)
+		// before the plugin starts, it will call this method to initialize the environment
 		InitEnvironment() error
+		// start the plugin, returns errors if the plugin fails to start and hangs until the plugin stops
 		StartPlugin() error
+		// returns true if the plugin is stopped
 		Stopped() bool
+		// stop the plugin
 		Stop()
+		// add a function to be called when the plugin stops
 		OnStop(func())
+		// trigger the stop event
 		TriggerStop()
-		RuntimeState() *PluginRuntimeState
+		// returns the runtime state of the plugin
+		RuntimeState() PluginRuntimeState
+		// Update the runtime state of the plugin
+		UpdateState(state PluginRuntimeState)
+		// returns the checksum of the plugin
 		Checksum() string
+		// wait for the plugin to stop
 		Wait() (<-chan bool, error)
+		// returns the runtime type of the plugin
 		Type() PluginRuntimeType
 	}
 
@@ -62,8 +78,22 @@ func (r *PluginRuntime) Identity() (string, error) {
 	return r.Config.Identity(), nil
 }
 
-func (r *PluginRuntime) RuntimeState() *PluginRuntimeState {
-	return &r.State
+func HashedIdentity(identity string) string {
+	hash := sha256.New()
+	hash.Write([]byte(identity))
+	return hex.EncodeToString(hash.Sum(nil))
+}
+
+func (r *PluginRuntime) HashedIdentity() (string, error) {
+	return HashedIdentity(r.Config.Identity()), nil
+}
+
+func (r *PluginRuntime) RuntimeState() PluginRuntimeState {
+	return r.State
+}
+
+func (r *PluginRuntime) UpdateState(state PluginRuntimeState) {
+	r.State = state
 }
 
 func (r *PluginRuntime) Checksum() string {
@@ -100,6 +130,8 @@ type PluginRuntimeState struct {
 	ActiveAt     *time.Time `json:"active_at"`
 	StoppedAt    *time.Time `json:"stopped_at"`
 	Verified     bool       `json:"verified"`
+	ScheduledAt  *time.Time `json:"scheduled_at"`
+	Logs         []string   `json:"logs"`
 }
 
 func (s *PluginRuntimeState) Hash() (uint64, error) {

+ 0 - 3
internal/types/models/base.go

@@ -2,13 +2,10 @@ package models
 
 import (
 	"time"
-
-	"gorm.io/gorm"
 )
 
 type Model struct {
 	ID        string `gorm:"column:id;primaryKey;type:uuid;default:uuid_generate_v4()"`
 	CreatedAt time.Time
 	UpdatedAt time.Time
-	DeletedAt gorm.DeletedAt `gorm:"index"`
 }

+ 161 - 0
internal/types/models/curd/atomic.go

@@ -0,0 +1,161 @@
+package curd
+
+import (
+	"errors"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/db"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
+	"gorm.io/gorm"
+)
+
+// Create plugin for a tenant, create plugin if it has never been created before
+// and install it to the tenant, return the plugin and the installation
+// if the plugin has been created before, return the plugin which has been created before
+func CreatePlugin(tenant_id string, user_id string, plugin *models.Plugin) (*models.Plugin, *models.PluginInstallation, error) {
+	var plugin_to_be_returns *models.Plugin
+	var installation_to_be_returns *models.PluginInstallation
+
+	_, err := db.GetOne[models.PluginInstallation](
+		db.Equal("plugin_id", plugin_to_be_returns.PluginID),
+		db.Equal("tenant_id", tenant_id),
+	)
+
+	if err != nil && err != db.ErrDatabaseNotFound {
+		return nil, nil, err
+	} else if err != nil {
+		return nil, nil, errors.New("plugin has been installed already")
+	}
+
+	err = db.WithTransaction(func(tx *gorm.DB) error {
+		p, err := db.GetOne[models.Plugin](
+			db.WithTransactionContext(tx),
+			db.Equal("plugin_id", plugin.PluginID),
+			db.WLock(),
+		)
+
+		if err == db.ErrDatabaseNotFound {
+			plugin.Refers = 1
+			err := db.Create(plugin, tx)
+			if err != nil {
+				return err
+			}
+
+			plugin_to_be_returns = plugin
+		} else if err != nil {
+			return err
+		} else {
+			p.Refers++
+			err := db.Update(&p, tx)
+			if err != nil {
+				return err
+			}
+			plugin_to_be_returns = &p
+		}
+
+		installation := &models.PluginInstallation{
+			PluginID: plugin_to_be_returns.PluginID,
+			TenantID: tenant_id,
+			UserID:   user_id,
+		}
+
+		err = db.Create(installation, tx)
+		if err != nil {
+			return err
+		}
+
+		installation_to_be_returns = installation
+
+		return nil
+	})
+
+	if err != nil {
+		return nil, nil, err
+	}
+
+	return plugin_to_be_returns, installation_to_be_returns, nil
+}
+
+type DeletePluginResponse struct {
+	Plugin          *models.Plugin
+	Installation    *models.PluginInstallation
+	IsPluginDeleted bool
+}
+
+// Delete plugin for a tenant, delete the plugin if it has never been created before
+// and uninstall it from the tenant, return the plugin and the installation
+// if the plugin has been created before, return the plugin which has been created before
+func DeletePlugin(tenant_id string, plugin_id string) (*DeletePluginResponse, error) {
+	var plugin_to_be_returns *models.Plugin
+	var installation_to_be_returns *models.PluginInstallation
+
+	_, err := db.GetOne[models.PluginInstallation](
+		db.Equal("plugin_id", plugin_to_be_returns.PluginID),
+		db.Equal("tenant_id", tenant_id),
+	)
+
+	if err != nil {
+		if err == db.ErrDatabaseNotFound {
+			return nil, errors.New("plugin has not been installed")
+		} else {
+			return nil, err
+		}
+	}
+
+	err = db.WithTransaction(func(tx *gorm.DB) error {
+		p, err := db.GetOne[models.Plugin](
+			db.WithTransactionContext(tx),
+			db.Equal("plugin_id", plugin_to_be_returns.PluginID),
+			db.WLock(),
+		)
+
+		if err == db.ErrDatabaseNotFound {
+			return errors.New("plugin has not been installed")
+		} else if err != nil {
+			return err
+		} else {
+			p.Refers--
+			err := db.Update(&p, tx)
+			if err != nil {
+				return err
+			}
+			plugin_to_be_returns = &p
+		}
+
+		installation, err := db.GetOne[models.PluginInstallation](
+			db.WithTransactionContext(tx),
+			db.Equal("plugin_id", plugin_id),
+			db.Equal("tenant_id", tenant_id),
+		)
+
+		if err == db.ErrDatabaseNotFound {
+			return errors.New("plugin has not been installed")
+		} else if err != nil {
+			return err
+		} else {
+			err := db.Delete(&installation, tx)
+			if err != nil {
+				return err
+			}
+			installation_to_be_returns = &installation
+		}
+
+		if plugin_to_be_returns.Refers == 0 {
+			err := db.Delete(&plugin_to_be_returns, tx)
+			if err != nil {
+				return err
+			}
+		}
+
+		return nil
+	})
+
+	if err != nil {
+		return nil, err
+	}
+
+	return &DeletePluginResponse{
+		Plugin:          plugin_to_be_returns,
+		Installation:    installation_to_be_returns,
+		IsPluginDeleted: plugin_to_be_returns.Refers == 0,
+	}, nil
+}

+ 5 - 0
internal/types/models/installation.go

@@ -1,5 +1,10 @@
 package models
 
+type PluginInstallationStatus string
+
 type PluginInstallation struct {
 	Model
+	TenantID string `json:"tenant_id" orm:"index;type:uuid;"`
+	UserID   string `json:"user_id" orm:"index;type:uuid;"`
+	PluginID string `json:"plugin_id" orm:"index;size:127"`
 }

+ 16 - 0
internal/types/models/plugin.go

@@ -0,0 +1,16 @@
+package models
+
+import (
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+)
+
+type Plugin struct {
+	Model
+	PluginID          string                           `json:"id" orm:"index;size:127"`
+	ConfigurationText string                           `json:"configuration_text" orm:"type:text"`
+	Refers            int                              `json:"refers" orm:"default:0"`
+	Checksum          string                           `json:"checksum" orm:"size:127"`
+	InstallType       entities.PluginRuntimeType       `json:"install_type" orm:"size:127"`
+	ManifestType      plugin_entities.DifyManifestType `json:"manifest_type" orm:"size:127"`
+}

+ 21 - 5
internal/utils/cache/redis.go

@@ -3,7 +3,6 @@ package cache
 import (
 	"context"
 	"errors"
-	"fmt"
 	"strings"
 	"time"
 
@@ -32,6 +31,7 @@ func InitRedisClient(addr, password string) error {
 	return nil
 }
 
+// Close the redis client
 func Close() error {
 	if client == nil {
 		return ErrDBNotInit
@@ -55,6 +55,7 @@ func serialKey(keys ...string) string {
 	), ":")
 }
 
+// Store the key-value pair
 func Store(key string, value any, time time.Duration, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
@@ -67,6 +68,7 @@ func Store(key string, value any, time time.Duration, context ...redis.Cmdable)
 	return getCmdable(context...).Set(ctx, serialKey(key), value, time).Err()
 }
 
+// Get the value with key
 func Get[T any](key string, context ...redis.Cmdable) (*T, error) {
 	if client == nil {
 		return nil, ErrDBNotInit
@@ -88,6 +90,7 @@ func Get[T any](key string, context ...redis.Cmdable) (*T, error) {
 	return &result, err
 }
 
+// GetString get the string with key
 func GetString(key string, context ...redis.Cmdable) (string, error) {
 	if client == nil {
 		return "", ErrDBNotInit
@@ -103,6 +106,7 @@ func GetString(key string, context ...redis.Cmdable) (string, error) {
 	return v, err
 }
 
+// Del the key
 func Del(key string, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
@@ -120,6 +124,7 @@ func Del(key string, context ...redis.Cmdable) error {
 	return nil
 }
 
+// Exist check the key exist or not
 func Exist(key string, context ...redis.Cmdable) (int64, error) {
 	if client == nil {
 		return 0, ErrDBNotInit
@@ -128,6 +133,7 @@ func Exist(key string, context ...redis.Cmdable) (int64, error) {
 	return getCmdable(context...).Exists(ctx, serialKey(key)).Result()
 }
 
+// Increase the key
 func Increase(key string, context ...redis.Cmdable) (int64, error) {
 	if client == nil {
 		return 0, ErrDBNotInit
@@ -144,6 +150,7 @@ func Increase(key string, context ...redis.Cmdable) (int64, error) {
 	return num, nil
 }
 
+// Decrease the key
 func Decrease(key string, context ...redis.Cmdable) (int64, error) {
 	if client == nil {
 		return 0, ErrDBNotInit
@@ -152,6 +159,7 @@ func Decrease(key string, context ...redis.Cmdable) (int64, error) {
 	return getCmdable(context...).Decr(ctx, serialKey(key)).Result()
 }
 
+// SetExpire set the expire time for the key
 func SetExpire(key string, time time.Duration, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
@@ -160,6 +168,7 @@ func SetExpire(key string, time time.Duration, context ...redis.Cmdable) error {
 	return getCmdable(context...).Expire(ctx, serialKey(key), time).Err()
 }
 
+// SetMapField set the map field with key
 func SetMapField(key string, v map[string]any, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
@@ -168,6 +177,7 @@ func SetMapField(key string, v map[string]any, context ...redis.Cmdable) error {
 	return getCmdable(context...).HMSet(ctx, serialKey(key), v).Err()
 }
 
+// SetMapOneField set the map field with key
 func SetMapOneField(key string, field string, value any, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
@@ -180,6 +190,7 @@ func SetMapOneField(key string, field string, value any, context ...redis.Cmdabl
 	return getCmdable(context...).HSet(ctx, serialKey(key), field, value).Err()
 }
 
+// GetMapField get the map field with key
 func GetMapField[T any](key string, field string, context ...redis.Cmdable) (*T, error) {
 	if client == nil {
 		return nil, ErrDBNotInit
@@ -197,6 +208,7 @@ func GetMapField[T any](key string, field string, context ...redis.Cmdable) (*T,
 	return &result, err
 }
 
+// DelMapField delete the map field with key
 func DelMapField(key string, field string, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
@@ -205,6 +217,7 @@ func DelMapField(key string, field string, context ...redis.Cmdable) error {
 	return getCmdable(context...).HDel(ctx, serialKey(key), field).Err()
 }
 
+// GetMap get the map with key
 func GetMap[V any](key string, context ...redis.Cmdable) (map[string]V, error) {
 	if client == nil {
 		return nil, ErrDBNotInit
@@ -231,14 +244,15 @@ func GetMap[V any](key string, context ...redis.Cmdable) (map[string]V, error) {
 	return result, nil
 }
 
-func ScanMap[V any](key string, prefix string, context ...redis.Cmdable) (map[string]V, error) {
+// ScanMap scan the map with match pattern, format like "key*"
+func ScanMap[V any](key string, match string, context ...redis.Cmdable) (map[string]V, error) {
 	if client == nil {
 		return nil, ErrDBNotInit
 	}
 
 	result := make(map[string]V)
 
-	ScanMapAsync[V](key, prefix, func(m map[string]V) error {
+	ScanMapAsync[V](key, match, func(m map[string]V) error {
 		for k, v := range m {
 			result[k] = v
 		}
@@ -249,7 +263,8 @@ func ScanMap[V any](key string, prefix string, context ...redis.Cmdable) (map[st
 	return result, nil
 }
 
-func ScanMapAsync[V any](key string, prefix string, fn func(map[string]V) error, context ...redis.Cmdable) error {
+// ScanMapAsync scan the map with match pattern, format like "key*"
+func ScanMapAsync[V any](key string, match string, fn func(map[string]V) error, context ...redis.Cmdable) error {
 	if client == nil {
 		return ErrDBNotInit
 	}
@@ -258,7 +273,7 @@ func ScanMapAsync[V any](key string, prefix string, fn func(map[string]V) error,
 
 	for {
 		kvs, new_cursor, err := getCmdable(context...).
-			HScan(ctx, serialKey(key), cursor, fmt.Sprintf("%s*", prefix), 32).
+			HScan(ctx, serialKey(key), cursor, match, 32).
 			Result()
 
 		if err != nil {
@@ -289,6 +304,7 @@ func ScanMapAsync[V any](key string, prefix string, fn func(map[string]V) error,
 	return nil
 }
 
+// SetNX set the key-value pair with expire time
 func SetNX[T any](key string, value T, expire time.Duration, context ...redis.Cmdable) (bool, error) {
 	if client == nil {
 		return false, ErrDBNotInit

+ 1 - 1
internal/utils/cache/redis_test.go

@@ -148,7 +148,7 @@ func TestRedisScanMap(t *testing.T) {
 		return
 	}
 
-	data, err := ScanMap[s](strings.Join([]string{TEST_PREFIX, "map"}, ":"), "key")
+	data, err := ScanMap[s](strings.Join([]string{TEST_PREFIX, "map"}, ":"), "key*")
 	if err != nil {
 		t.Errorf("scan map failed: %v", err)
 		return

+ 78 - 0
internal/utils/mapping/sync.go

@@ -0,0 +1,78 @@
+package mapping
+
+import (
+	"sync"
+	"sync/atomic"
+)
+
+type Map[K comparable, V any] struct {
+	len   int32
+	store sync.Map
+}
+
+func (m *Map[K, V]) Load(key K) (value V, ok bool) {
+	v, ok := m.store.Load(key)
+	if !ok {
+		return
+	}
+
+	value, ok = v.(V)
+	return
+}
+
+func (m *Map[K, V]) Store(key K, value V) {
+	atomic.AddInt32(&m.len, 1)
+	m.store.Store(key, value)
+}
+
+func (m *Map[K, V]) Delete(key K) {
+	atomic.AddInt32(&m.len, -1)
+	m.store.Delete(key)
+}
+
+func (m *Map[K, V]) Range(f func(key K, value V) bool) {
+	m.store.Range(func(key, value interface{}) bool {
+		return f(key.(K), value.(V))
+	})
+}
+
+func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
+	v, loaded := m.store.LoadOrStore(key, value)
+	actual = v.(V)
+	if !loaded {
+		atomic.AddInt32(&m.len, 1)
+	}
+	return
+}
+
+func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
+	v, loaded := m.store.LoadAndDelete(key)
+	value = v.(V)
+	if loaded {
+		atomic.AddInt32(&m.len, -1)
+	}
+	return
+}
+
+func (m *Map[K, V]) Swap(key K, value V) (actual V, swapped bool) {
+	v, swapped := m.store.Swap(key, value)
+	actual = v.(V)
+	return
+}
+
+func (m *Map[K, V]) Clear() {
+	m.store.Range(func(key, value interface{}) bool {
+		m.store.Delete(key)
+		return true
+	})
+	atomic.StoreInt32(&m.len, 0)
+}
+
+func (m *Map[K, V]) Len() int {
+	return int(atomic.LoadInt32(&m.len))
+}
+
+func (m *Map[K, V]) Exits(key K) bool {
+	_, ok := m.Load(key)
+	return ok
+}