
feat: lifetime manager

Yeuoly 1 year ago
commit 1af5a2ec11

+ 8 - 0
.env.example

@@ -7,6 +7,14 @@ DIFY_CALLING_PORT=5002
 PLUGIN_HOST=127.0.0.1
 PLUGIN_PORT=5003
 
+ROUTINE_POOL_SIZE=1024
+
+REDIS_HOST=127.0.0.1
+REDIS_PORT=6379
+REDIS_PASS=difyai123456
+LIFETIME_COLLECTION_HEARTBEAT_INTERVAL=5
+LIFETIME_COLLECTION_GC_INTERVAL=60
+LIFETIME_STATE_GC_INTERVAL=300
 STORAGE_PATH=examples
 
 PLATFORM=local

+ 2 - 1
.gitignore

@@ -1,4 +1,5 @@
 release/
 logs/
 .vscode/
-.env
+.env
+cmd/**/__debug_bin*

+ 3 - 0
cmd/server/main.go

@@ -30,6 +30,9 @@ func main() {
 func setDefault(config *app.Config) {
 	setDefaultInt(&config.RoutinePoolSize, 1000)
 	setDefaultInt(&config.DifyCallingPort, 5002)
+	setDefaultInt(&config.LifetimeCollectionGCInterval, 60)
+	setDefaultInt(&config.LifetimeCollectionHeartbeatInterval, 5)
+	setDefaultInt(&config.LifetimeStateGCInterval, 300)
 }
 
 func setDefaultInt[T constraints.Integer](value *T, defaultValue T) {
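
For context, this hunk ends at the helper's signature; a minimal sketch of a body consistent with how it is called above (an assumption, since the body is not shown in this diff) would only override fields left at their zero value when the corresponding env var is unset:

// hypothetical completion of the helper shown above
func setDefaultInt[T constraints.Integer](value *T, defaultValue T) {
	// envconfig leaves the field at its zero value when the variable is missing
	if *value == 0 {
		*value = defaultValue
	}
}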

+ 52 - 0
cmd/tests/main.go

@@ -0,0 +1,52 @@
+package main
+
+import (
+	"fmt"
+	"math/rand"
+	"sync/atomic"
+	"time"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+)
+
+func main() {
+	response := entities.NewInvocationResponse[string](1024)
+
+	random_string := func() string {
+		return fmt.Sprintf("%d", rand.Intn(100000))
+	}
+
+	traffic := new(int64)
+
+	go func() {
+		for {
+			response.Write(random_string())
+		}
+	}()
+
+	go func() {
+		for {
+			response.Write(random_string())
+		}
+	}()
+
+	go func() {
+		for response.Next() {
+			atomic.AddInt64(traffic, 1)
+			_, err := response.Read()
+			if err != nil {
+				fmt.Println(err)
+				break
+			}
+		}
+	}()
+
+	go func() {
+		for range time.NewTicker(time.Second).C {
+			fmt.Printf("Traffic: %d, Pending: %d\n", atomic.LoadInt64(traffic), response.Size())
+			atomic.StoreInt64(traffic, 0)
+		}
+	}()
+
+	select {}
+}
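
The stress test above can be exercised directly (assuming the repository's standard Go module layout); it prints the per-second write traffic and the backlog reported by response.Size():

	go run ./cmd/tests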

+ 6 - 0
go.mod

@@ -5,6 +5,12 @@ go 1.20
 require github.com/google/uuid v1.6.0
 
 require (
+	github.com/cespare/xxhash/v2 v2.2.0 // indirect
+	github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
+	github.com/redis/go-redis/v9 v9.5.3 // indirect
+)
+
+require (
 	github.com/bytedance/sonic v1.11.9 // indirect
 	github.com/bytedance/sonic/loader v0.1.1 // indirect
 	github.com/cloudwego/base64x v0.1.4 // indirect

+ 6 - 0
go.sum

@@ -2,12 +2,16 @@ github.com/bytedance/sonic v1.11.9 h1:LFHENlIY/SLzDWverzdOvgMztTxcfcF+cqNsz9pK5z
 github.com/bytedance/sonic v1.11.9/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
 github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
 github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
+github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44=
+github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
 github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
 github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
 github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
 github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
 github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
 github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78=
+github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc=
 github.com/gabriel-vasile/mimetype v1.4.4 h1:QjV6pZ7/XZ7ryI2KuyeEDE8wnh7fHP9YnQy+R0LnH8I=
 github.com/gabriel-vasile/mimetype v1.4.4/go.mod h1:JwLei5XPtWdGiMFB5Pjle1oEeoSeEuJfJE+TtfvdB/s=
 github.com/gammazero/deque v0.2.1 h1:qSdsbG6pgp6nL7A0+K/B7s12mcCY/5l5SIUpMOl+dC0=
@@ -53,6 +57,8 @@ github.com/panjf2000/gnet/v2 v2.5.5/go.mod h1:ppopMJ8VrDbJu8kDsqFQTgNmpMS8Le5CmP
 github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
 github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/redis/go-redis/v9 v9.5.3 h1:fOAp1/uJG+ZtcITgZOfYFmTKPE7n4Vclj1wZFgRciUU=
+github.com/redis/go-redis/v9 v9.5.3/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=

+ 15 - 3
internal/core/plugin_manager/lifetime.go

@@ -3,12 +3,24 @@ package plugin_manager
 import (
 	"time"
 
+	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 )
 
-func lifetime(r entities.PluginRuntimeInterface) {
+func lifetime(config *app.Config, r entities.PluginRuntimeInterface) {
 	start_failed_times := 0
+	configuration := r.Configuration()
+
+	addLifetimeState(r)
+	defer func() {
+		// remove the lifetime state if the plugin has been stopped for $LIFETIME_STATE_GC_INTERVAL seconds and not started again
+		time.AfterFunc(time.Duration(config.LifetimeStateGCInterval)*time.Second, func() {
+			if r.Stopped() {
+				deleteLifetimeState(r)
+			}
+		})
+	}()
 
 	for !r.Stopped() {
 		if err := r.InitEnvironment(); err != nil {
@@ -17,7 +29,7 @@ func lifetime(r entities.PluginRuntimeInterface) {
 			if start_failed_times == 3 {
 				log.Error(
 					"init environment failed 3 times, plugin %s has been stopped",
-					r.Configuration().Identity(),
+					configuration.Identity(),
 				)
 				r.Stop()
 			}
@@ -33,7 +45,7 @@ func lifetime(r entities.PluginRuntimeInterface) {
 			if start_failed_times == 3 {
 				log.Error(
 					"start plugin failed 3 times, plugin %s has been stopped",
-					r.Configuration().Identity(),
+					configuration.Identity(),
 				)
 				r.Stop()
 			}

+ 145 - 0
internal/core/plugin_manager/lifetime_manager.go

@@ -0,0 +1,145 @@
+package plugin_manager
+
+import (
+	"sync"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+)
+
+const (
+	KEY_PLUGIN_LIFETIME_STATE             = "lifetime_state"
+	KEY_PLUGIN_LIFETIME_STATE_MODIFY_LOCK = "lifetime_state_modify_lock"
+)
+
+type PluginLifeTime struct {
+	Identity string                       `json:"identity"`
+	Restarts int                          `json:"restarts"`
+	Status   string                       `json:"status"`
+	Config   entities.PluginConfiguration `json:"configuration"`
+}
+
+type pluginLifeCollection struct {
+	Collection  map[string]PluginLifeTime `json:"state"`
+	ID          string                    `json:"id"`
+	LastCheckAt time.Time                 `json:"last_check_at"`
+}
+
+func (p pluginLifeCollection) MarshalBinary() ([]byte, error) {
+	return parser.MarshalJsonBytes(p), nil
+}
+
+var (
+	instanceId = uuid.New().String()
+
+	pluginLifetimeStateLock  = sync.RWMutex{}
+	pluginLifetimeCollection = pluginLifeCollection{
+		Collection: map[string]PluginLifeTime{},
+		ID:         instanceId,
+	}
+)
+
+func startLifeTimeManager(config *app.Config) {
+	go func() {
+		// do check immediately
+		doClusterLifetimeCheck(config.LifetimeCollectionGCInterval)
+
+		duration := time.Duration(config.LifetimeCollectionHeartbeatInterval) * time.Second
+		for range time.NewTicker(duration).C {
+			doClusterLifetimeCheck(config.LifetimeCollectionGCInterval)
+		}
+	}()
+}
+
+func doClusterLifetimeCheck(gc_interval int) {
+	// check and update self lifetime state
+	if err := updateCurrentInstanceLifetimeCollection(); err != nil {
+		log.Error("update current instance lifetime state failed: %s", err.Error())
+		return
+	}
+
+	// lock cluster and do cluster lifetime check
+	if err := cache.Lock(KEY_PLUGIN_LIFETIME_STATE_MODIFY_LOCK, time.Second*10, time.Second*10); err != nil {
+		log.Error("update lifetime state failed: acquire lock failed: %s", err.Error())
+		return
+	}
+	defer cache.Unlock(KEY_PLUGIN_LIFETIME_STATE_MODIFY_LOCK)
+
+	cluster_lifetime_collections, err := fetchClusterPluginLifetimeCollections()
+	if err != nil {
+		log.Error("fetch cluster plugin lifetime state failed: %s", err.Error())
+		return
+	}
+
+	for cluster_id, state := range cluster_lifetime_collections {
+		if state.ID == instanceId {
+			continue
+		}
+
+		// skip if the last heartbeat was within $LIFETIME_COLLECTION_GC_INTERVAL
+		gc_duration := time.Duration(gc_interval) * time.Second
+		if time.Since(state.LastCheckAt) < gc_duration {
+			continue
+		}
+
+		// if no heartbeat has been seen for $LIFETIME_COLLECTION_GC_INTERVAL * 2, delete the stale collection
+		if time.Since(state.LastCheckAt) > gc_duration*2 {
+			if err := cache.DelMapField(KEY_PLUGIN_LIFETIME_STATE, cluster_id); err != nil {
+				log.Error("delete cluster plugin lifetime state failed: %s", err.Error())
+			} else {
+				log.Info("deleted cluster plugin lifetime state for inactive instance: %s", cluster_id)
+			}
+		}
+	}
+}
+
+func newLifetimeFromRuntimeState(state entities.PluginRuntimeInterface) PluginLifeTime {
+	s := state.RuntimeState()
+	c := state.Configuration()
+
+	return PluginLifeTime{
+		Identity: c.Identity(),
+		Restarts: s.Restarts,
+		Status:   s.Status,
+		Config:   *c,
+	}
+}
+
+func addLifetimeState(state entities.PluginRuntimeInterface) {
+	pluginLifetimeStateLock.Lock()
+	defer pluginLifetimeStateLock.Unlock()
+
+	pluginLifetimeCollection.Collection[state.Configuration().Identity()] = newLifetimeFromRuntimeState(state)
+}
+
+func deleteLifetimeState(state entities.PluginRuntimeInterface) {
+	pluginLifetimeStateLock.Lock()
+	defer pluginLifetimeStateLock.Unlock()
+
+	delete(pluginLifetimeCollection.Collection, state.Configuration().Identity())
+}
+
+func updateCurrentInstanceLifetimeCollection() error {
+	pluginLifetimeStateLock.Lock()
+	defer pluginLifetimeStateLock.Unlock()
+
+	pluginLifetimeCollection.LastCheckAt = time.Now()
+
+	m.Range(func(key, value interface{}) bool {
+		if v, ok := value.(entities.PluginRuntimeInterface); ok {
+			pluginLifetimeCollection.Collection[v.Configuration().Identity()] = newLifetimeFromRuntimeState(v)
+		}
+		return true
+	})
+
+	return cache.SetMapOneField(KEY_PLUGIN_LIFETIME_STATE, instanceId, pluginLifetimeCollection)
+}
+
+func fetchClusterPluginLifetimeCollections() (map[string]pluginLifeCollection, error) {
+	return cache.GetMap[pluginLifeCollection](KEY_PLUGIN_LIFETIME_STATE)
+}

+ 16 - 1
internal/core/plugin_manager/manager.go

@@ -1,8 +1,11 @@
 package plugin_manager
 
 import (
+	"fmt"
+
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 )
 
@@ -38,5 +41,17 @@ func Init(configuration *app.Config) {
 	// TODO: init plugin manager
 	log.Info("start plugin manager daemon...")
 
-	startWatcher(configuration.StoragePath, configuration.Platform)
+	// init redis client
+	if err := cache.InitRedisClient(
+		fmt.Sprintf("%s:%d", configuration.RedisHost, configuration.RedisPort),
+		configuration.RedisPass,
+	); err != nil {
+		log.Panic("init redis client failed: %s", err.Error())
+	}
+
+	// start plugin watcher
+	startWatcher(configuration)
+
+	// start plugin lifetime manager
+	startLifeTimeManager(configuration)
 }

+ 10 - 10
internal/core/plugin_manager/watcher.go

@@ -14,31 +14,31 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
 
-func startWatcher(path string, platform string) {
+func startWatcher(config *app.Config) {
 	go func() {
-		log.Info("start to handle new plugins in path: %s", path)
-		handleNewPlugins(path, platform)
+		log.Info("start to handle new plugins in path: %s", config.StoragePath)
+		handleNewPlugins(config)
 		for range time.NewTicker(time.Second * 30).C {
-			handleNewPlugins(path, platform)
+			handleNewPlugins(config)
 		}
 	}()
 }
 
-func handleNewPlugins(path string, platform string) {
+func handleNewPlugins(config *app.Config) {
 	// load local plugins firstly
-	for plugin := range loadNewPlugins(path) {
+	for plugin := range loadNewPlugins(config.StoragePath) {
 		var plugin_interface entities.PluginRuntimeInterface
 
-		if platform == app.PLATFORM_AWS_LAMBDA {
+		if config.Platform == app.PLATFORM_AWS_LAMBDA {
 			plugin_interface = &aws_manager.AWSPluginRuntime{
 				PluginRuntime: plugin,
 			}
-		} else if platform == app.PLATFORM_LOCAL {
+		} else if config.Platform == app.PLATFORM_LOCAL {
 			plugin_interface = &local_manager.LocalPluginRuntime{
 				PluginRuntime: plugin,
 			}
 		} else {
-			log.Error("unsupported platform: %s for plugin: %s", platform, plugin.Config.Name)
+			log.Error("unsupported platform: %s for plugin: %s", config.Platform, plugin.Config.Name)
 			continue
 		}
 
@@ -47,7 +47,7 @@ func handleNewPlugins(path string, platform string) {
 		m.Store(plugin.Config.Identity(), plugin_interface)
 
 		routine.Submit(func() {
-			lifetime(plugin_interface)
+			lifetime(config, plugin_interface)
 		})
 	}
 }

+ 8 - 0
internal/types/app/config.go

@@ -13,6 +13,14 @@ type Config struct {
 	Platform string `envconfig:"PLATFORM"`
 
 	RoutinePoolSize int `envconfig:"ROUTINE_POOL_SIZE"`
+
+	RedisHost string `envconfig:"REDIS_HOST"`
+	RedisPort int16  `envconfig:"REDIS_PORT"`
+	RedisPass string `envconfig:"REDIS_PASS"`
+
+	LifetimeCollectionHeartbeatInterval int `envconfig:"LIFETIME_COLLECTION_HEARTBEAT_INTERVAL"`
+	LifetimeCollectionGCInterval        int `envconfig:"LIFETIME_COLLECTION_GC_INTERVAL"`
+	LifetimeStateGCInterval             int `envconfig:"LIFETIME_STATE_GC_INTERVAL"`
 }
 
 const (
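
The envconfig struct tags above suggest the fields are populated with github.com/kelseyhightower/envconfig (an assumption; the loading code is not part of this diff). A minimal sketch of reading these settings from the environment:

// hypothetical loader; assumes kelseyhightower/envconfig, which matches the struct tags
package main

import (
	"log"

	"github.com/kelseyhightower/envconfig"
	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
)

func loadConfig() *app.Config {
	var config app.Config
	// Process reads REDIS_HOST, LIFETIME_COLLECTION_GC_INTERVAL, etc. into the tagged fields
	if err := envconfig.Process("", &config); err != nil {
		log.Fatalf("load config failed: %s", err)
	}
	return &config
}

func main() {
	config := loadConfig()
	log.Printf("redis at %s:%d", config.RedisHost, config.RedisPort)
}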

+ 5 - 0
internal/types/entities/runtime.go

@@ -22,6 +22,7 @@ type (
 		Stopped() bool
 		Stop()
 		Configuration() *PluginConfiguration
+		RuntimeState() *PluginRuntimeState
 	}
 
 	PluginRuntimeSessionIOInterface interface {
@@ -43,6 +44,10 @@ func (r *PluginRuntime) Configuration() *PluginConfiguration {
 	return &r.Config
 }
 
+func (r *PluginRuntime) RuntimeState() *PluginRuntimeState {
+	return &r.State
+}
+
 type PluginRuntimeState struct {
 	Restarts     int        `json:"restarts"`
 	Status       string     `json:"status"`

+ 199 - 0
internal/utils/cache/redis.go

@@ -0,0 +1,199 @@
+package cache
+
+import (
+	"context"
+	"errors"
+	"strings"
+	"time"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+	"github.com/redis/go-redis/v9"
+)
+
+var (
+	client *redis.Client
+	ctx    = context.Background()
+
+	ErrDBNotInit = errors.New("redis client not init")
+)
+
+func InitRedisClient(addr, password string) error {
+	client = redis.NewClient(&redis.Options{
+		Addr:     addr,
+		Password: password,
+	})
+
+	if _, err := client.Ping(ctx).Result(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func serialKey(keys ...string) string {
+	return strings.Join(append(
+		[]string{"plugin_daemon"},
+		keys...,
+	), ":")
+}
+
+func Store(key string, value any, time time.Duration) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	return client.Set(ctx, serialKey(key), value, time).Err()
+}
+
+func Get[T any](key string) (*T, error) {
+	if client == nil {
+		return nil, ErrDBNotInit
+	}
+
+	val, err := client.Get(ctx, serialKey(key)).Result()
+	if err != nil {
+		return nil, err
+	}
+
+	result, err := parser.UnmarshalJson[T](val)
+	return &result, err
+}
+
+func Del(key string) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	return client.Del(ctx, serialKey(key)).Err()
+}
+
+func Exist(key string) (int64, error) {
+	if client == nil {
+		return 0, ErrDBNotInit
+	}
+
+	return client.Exists(ctx, serialKey(key)).Result()
+}
+
+func Increase(key string) (int64, error) {
+	if client == nil {
+		return 0, ErrDBNotInit
+	}
+
+	return client.Incr(ctx, serialKey(key)).Result()
+}
+
+func Decrease(key string) (int64, error) {
+	if client == nil {
+		return 0, ErrDBNotInit
+	}
+
+	return client.Decr(ctx, serialKey(key)).Result()
+}
+
+func SetExpire(key string, time time.Duration) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	return client.Expire(ctx, serialKey(key), time).Err()
+}
+
+func SetMapField(key string, v map[string]any) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	return client.HMSet(ctx, serialKey(key), v).Err()
+}
+
+func SetMapOneField(key string, field string, value any) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	return client.HSet(ctx, serialKey(key), field, value).Err()
+}
+
+func GetMapField[T any](key string, field string) (*T, error) {
+	if client == nil {
+		return nil, ErrDBNotInit
+	}
+
+	val, err := client.HGet(ctx, serialKey(key), field).Result()
+	if err != nil {
+		return nil, err
+	}
+
+	result, err := parser.UnmarshalJson[T](val)
+	return &result, err
+}
+
+func DelMapField(key string, field string) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	return client.HDel(ctx, serialKey(key), field).Err()
+}
+
+func GetMap[V any](key string) (map[string]V, error) {
+	if client == nil {
+		return nil, ErrDBNotInit
+	}
+
+	val, err := client.HGetAll(ctx, serialKey(key)).Result()
+	if err != nil {
+		return nil, err
+	}
+
+	result := make(map[string]V)
+	for k, v := range val {
+		value, err := parser.UnmarshalJson[V](v)
+		if err != nil {
+			return nil, err
+		}
+
+		result[k] = value
+	}
+
+	return result, nil
+}
+
+var (
+	ErrLockTimeout = errors.New("lock timeout")
+)
+
+// Lock acquires a distributed lock on key. expire sets the lock's expiration time,
+// and try_lock_timeout bounds how long to keep retrying before giving up.
+func Lock(key string, expire time.Duration, try_lock_timeout time.Duration) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	const LOCK_DURATION = 20 * time.Millisecond
+
+	ticker := time.NewTicker(LOCK_DURATION)
+	defer ticker.Stop()
+
+	for range ticker.C {
+		if ok, err := client.SetNX(ctx, serialKey(key), "1", expire).Result(); err == nil && ok {
+			return nil
+		}
+
+		try_lock_timeout -= LOCK_DURATION
+		if try_lock_timeout <= 0 {
+			return ErrLockTimeout
+		}
+	}
+
+	return nil
+}
+
+func Unlock(key string) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	return client.Del(ctx, serialKey(key)).Err()
+}