浏览代码

feat: cluster managment

Yeuoly 1 年之前
父节点
当前提交
367cad55da

+ 11 - 0
internal/cluster/cluster_id/id.go

@@ -0,0 +1,11 @@
+package cluster_id
+
+import "github.com/google/uuid"
+
+var (
+	instanceId = uuid.New().String()
+)
+
+func GetInstanceID() string {
+	return instanceId
+}

+ 17 - 0
internal/cluster/entities.go

@@ -0,0 +1,17 @@
+package cluster
+
+type ip struct {
+	Address string `json:"address"`
+	Votes   []vote `json:"vote"`
+}
+
+type vote struct {
+	NodeID  string `json:"node_id"`
+	VotedAt int64  `json:"voted_at"`
+	Failed  bool   `json:"failed"`
+}
+
+type node struct {
+	Ips        []ip  `json:"ips"`
+	LastPingAt int64 `json:"last_ping_at"`
+}

+ 51 - 0
internal/cluster/gc.go

@@ -0,0 +1,51 @@
+package cluster
+
+import (
+	"errors"
+	"time"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+)
+
+// gc the nodes has already deactivated
+func (c *Cluster) gcNodes() error {
+	var total_errors error
+	add_error := func(err error) {
+		if err != nil {
+			if total_errors == nil {
+				total_errors = err
+			} else {
+				total_errors = errors.Join(total_errors, err)
+			}
+		}
+	}
+
+	// get all nodes status
+	nodes, err := cache.GetMap[node](CLUSTER_STATUS_HASH_MAP_KEY)
+	if err == cache.ErrNotFound {
+		return nil
+	}
+
+	for node_id, node_status := range nodes {
+		// delete the node if it is disconnected
+		if time.Since(time.Unix(node_status.LastPingAt, 0)) > NODE_DISCONNECTED_TIMEOUT {
+			// gc the node
+			if err := c.gcNode(node_id); err != nil {
+				add_error(err)
+				continue
+			}
+
+			// delete the node status
+			if err := cache.DelMapField(CLUSTER_STATUS_HASH_MAP_KEY, node_id); err != nil {
+				add_error(err)
+			}
+		}
+	}
+
+	return total_errors
+}
+
+// remove the resource associated with the node
+func (c *Cluster) gcNode(node_id string) error {
+	return nil
+}

+ 25 - 0
internal/cluster/init.go

@@ -0,0 +1,25 @@
+package cluster
+
+import "github.com/langgenius/dify-plugin-daemon/internal/types/app"
+
+type Cluster struct {
+	port uint16
+}
+
+var (
+	cluster *Cluster
+)
+
+func Launch(config *app.Config) {
+	cluster = &Cluster{
+		port: uint16(config.ServerPort),
+	}
+
+	go func() {
+		cluster.clusterLifetime()
+	}()
+}
+
+func GetCluster() *Cluster {
+	return cluster
+}

+ 44 - 0
internal/cluster/lock.go

@@ -0,0 +1,44 @@
+package cluster
+
+import (
+	"strings"
+	"time"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+)
+
+const (
+	CLUSTER_STATE_TENANT_LOCK_PREFIX       = "cluster_state_tenant_lock"
+	CLUSTER_STATE_PLUGIN_LOCK_PREFIX       = "cluster_state_plugin_lock"
+	CLUSTER_UPDATE_NODE_STATUS_LOCK_PREFIX = "cluster_update_node_status_lock"
+)
+
+func (c *Cluster) LockTenant(tenant_id string) error {
+	key := strings.Join([]string{CLUSTER_STATE_TENANT_LOCK_PREFIX, tenant_id}, ":")
+	return cache.Lock(key, time.Second*5, time.Second)
+}
+
+func (c *Cluster) UnlockTenant(tenant_id string) error {
+	key := strings.Join([]string{CLUSTER_STATE_TENANT_LOCK_PREFIX, tenant_id}, ":")
+	return cache.Unlock(key)
+}
+
+func (c *Cluster) LockPlugin(plugin_id string) error {
+	key := strings.Join([]string{CLUSTER_STATE_PLUGIN_LOCK_PREFIX, plugin_id}, ":")
+	return cache.Lock(key, time.Second*5, time.Second)
+}
+
+func (c *Cluster) UnlockPlugin(plugin_id string) error {
+	key := strings.Join([]string{CLUSTER_STATE_PLUGIN_LOCK_PREFIX, plugin_id}, ":")
+	return cache.Unlock(key)
+}
+
+func (c *Cluster) LockNodeStatus(node_id string) error {
+	key := strings.Join([]string{CLUSTER_UPDATE_NODE_STATUS_LOCK_PREFIX, node_id}, ":")
+	return cache.Lock(key, time.Second*5, time.Second)
+}
+
+func (c *Cluster) UnlockNodeStatus(node_id string) error {
+	key := strings.Join([]string{CLUSTER_UPDATE_NODE_STATUS_LOCK_PREFIX, node_id}, ":")
+	return cache.Unlock(key)
+}

+ 216 - 0
internal/cluster/preemptive.go

@@ -0,0 +1,216 @@
+package cluster
+
+import (
+	"errors"
+	"net"
+	"time"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/cluster/cluster_id"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/network"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+)
+
+// Plugin daemon will preemptively try to lock the slot to be the master of the cluster
+// and keep update current status of the whole cluster
+// once the master is no longer active, one of the slave will try to lock the slot again
+// and become the new master
+//
+// Once a node becomes master, It will take responsibility to gc the nodes has already deactivated
+// and all nodes should to maintenance their own status
+//
+// State:
+//	- hashmap[cluster-status]
+//		- node-id:
+//			- list[ip]:
+//				- address: string
+//				- vote: int
+//			- last_ping_at: int64
+//	- preemption-lock: node-id
+//	- node-status-upgrade-status
+//
+// A node will be removed from the cluster if it is no longer active
+
+var (
+	i_am_master = false
+)
+
+const (
+	CLUSTER_STATUS_HASH_MAP_KEY = "cluster-status-hash-map"
+	PREEMPTION_LOCK_KEY         = "cluster-master-preemption-lock"
+)
+
+const (
+	MASTER_LOCKING_INTERVAL     = time.Millisecond * 500 // interval to try to lock the slot to be the master
+	MASTER_LOCK_EXPIRED_TIME    = time.Second * 5        // expired time of master key
+	MASTER_GC_INTERVAL          = time.Second * 10       // interval to do garbage collection of nodes has already deactivated
+	NODE_VOTE_INTERVAL          = time.Second * 30       // interval to vote the ips of the nodes
+	UPDATE_NODE_STATUS_INTERVAL = time.Second * 5        // interval to update the status of the node
+	NODE_DISCONNECTED_TIMEOUT   = time.Second * 10       // once a node is no longer active, it will be removed from the cluster
+)
+
+// lifetime of the cluster
+func (c *Cluster) clusterLifetime() {
+	ticker_lock_master := time.NewTicker(MASTER_LOCKING_INTERVAL)
+	defer ticker_lock_master.Stop()
+
+	ticker_update_node_status := time.NewTicker(UPDATE_NODE_STATUS_INTERVAL)
+	defer ticker_update_node_status.Stop()
+
+	master_gc_ticker := time.NewTicker(MASTER_GC_INTERVAL)
+	defer master_gc_ticker.Stop()
+
+	node_vote_ticker := time.NewTicker(NODE_VOTE_INTERVAL)
+	defer node_vote_ticker.Stop()
+
+	if err := c.voteIps(); err != nil {
+		log.Error("failed to vote the ips of the nodes: %s", err.Error())
+	}
+
+	for {
+		select {
+		case <-ticker_lock_master.C:
+			if !i_am_master {
+				// try lock the slot
+				if success, err := c.lockMaster(); err != nil {
+					log.Error("failed to lock the slot to be the master of the cluster: %s", err.Error())
+				} else if success {
+					i_am_master = true
+					log.Info("current node has become the master of the cluster")
+				} else {
+					i_am_master = false
+					log.Info("current node lost the master slot")
+				}
+			} else {
+				// update the master
+				if err := c.updateMaster(); err != nil {
+					log.Error("failed to update the master: %s", err.Error())
+				}
+			}
+		case <-ticker_update_node_status.C:
+			if err := c.updateNodeStatus(); err != nil {
+				log.Error("failed to update the status of the node: %s", err.Error())
+			}
+		case <-master_gc_ticker.C:
+			if i_am_master {
+				if err := c.gcNodes(); err != nil {
+					log.Error("failed to gc the nodes has already deactivated: %s", err.Error())
+				}
+			}
+		case <-node_vote_ticker.C:
+			if err := c.voteIps(); err != nil {
+				log.Error("failed to vote the ips of the nodes: %s", err.Error())
+			}
+		}
+	}
+}
+
+// try lock the slot to be the master of the cluster
+// returns:
+//   - bool: true if the slot is locked by the node
+//   - error: error if any
+func (c *Cluster) lockMaster() (bool, error) {
+	var final_error error
+
+	for i := 0; i < 3; i++ {
+		if success, err := cache.SetNX(PREEMPTION_LOCK_KEY, cluster_id.GetInstanceID(), MASTER_LOCK_EXPIRED_TIME); err != nil {
+			// try again
+			if final_error == nil {
+				final_error = err
+			} else {
+				final_error = errors.Join(final_error, err)
+			}
+		} else if !success {
+			return false, nil
+		} else {
+			return true, nil
+		}
+	}
+
+	return false, final_error
+}
+
+// update master
+func (c *Cluster) updateMaster() error {
+	// update expired time of master key
+	if _, err := cache.Expire(PREEMPTION_LOCK_KEY, MASTER_LOCK_EXPIRED_TIME); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// update the status of the node
+func (c *Cluster) updateNodeStatus() error {
+	if err := c.LockNodeStatus(cluster_id.GetInstanceID()); err != nil {
+		return err
+	}
+	defer c.UnlockNodeStatus(cluster_id.GetInstanceID())
+
+	// update the status of the node
+	node_status, err := cache.GetMapField[node](CLUSTER_STATUS_HASH_MAP_KEY, cluster_id.GetInstanceID())
+	if err != nil {
+		if err == cache.ErrNotFound {
+			// try to get ips configs
+			ips, err := network.FetchCurrentIps()
+			if err != nil {
+				return err
+			}
+			node_status = &node{
+				Ips: parser.Map(func(from net.IP) ip {
+					return ip{
+						Address: from.String(),
+						Votes:   []vote{},
+					}
+				}, ips),
+			}
+		} else {
+			return err
+		}
+	} else {
+		ips, err := network.FetchCurrentIps()
+		if err != nil {
+			return err
+		}
+		// add new ip if not exist
+		for _, _ip := range ips {
+			found := false
+			for _, node_ip := range node_status.Ips {
+				if node_ip.Address == _ip.String() {
+					found = true
+					break
+				}
+			}
+			if !found {
+				node_status.Ips = append(node_status.Ips, ip{
+					Address: _ip.String(),
+					Votes:   []vote{},
+				})
+			}
+		}
+	}
+
+	// refresh the last ping time
+	node_status.LastPingAt = time.Now().Unix()
+
+	// update the status of the node
+	if err := cache.SetMapOneField(CLUSTER_STATUS_HASH_MAP_KEY, cluster_id.GetInstanceID(), node_status); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (c *Cluster) IsMaster() bool {
+	return i_am_master
+}
+
+func (c *Cluster) IsNodeAlive(node_id string) bool {
+	node_status, err := cache.GetMapField[node](CLUSTER_STATUS_HASH_MAP_KEY, node_id)
+	if err != nil {
+		return false
+	}
+
+	return time.Since(time.Unix(node_status.LastPingAt, 0)) < NODE_DISCONNECTED_TIMEOUT
+}

+ 1 - 0
internal/cluster/state.go

@@ -0,0 +1 @@
+package cluster

+ 145 - 0
internal/cluster/vote.go

@@ -0,0 +1,145 @@
+package cluster
+
+import (
+	"errors"
+	"fmt"
+	"net/http"
+	"net/url"
+	"sort"
+	"time"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/cluster/cluster_id"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/http_requests"
+)
+
+func (c *Cluster) voteIps() error {
+	var total_errors error
+	add_error := func(err error) {
+		if err != nil {
+			if total_errors == nil {
+				total_errors = err
+			} else {
+				total_errors = errors.Join(total_errors, err)
+			}
+		}
+	}
+
+	// get all nodes status
+	nodes, err := cache.GetMap[node](CLUSTER_STATUS_HASH_MAP_KEY)
+	if err == cache.ErrNotFound {
+		return nil
+	}
+
+	for node_id, node_status := range nodes {
+		if node_id == cluster_id.GetInstanceID() {
+			continue
+		}
+
+		// vote for ips
+		ips_voting := make(map[string]bool)
+		for _, ip := range node_status.Ips {
+			// skip ips which have already been voted by current node in the last 5 minutes
+			for _, vote := range ip.Votes {
+				if vote.NodeID == cluster_id.GetInstanceID() {
+					if time.Since(time.Unix(vote.VotedAt, 0)) < time.Minute*5 && !vote.Failed {
+						continue
+					} else if time.Since(time.Unix(vote.VotedAt, 0)) < time.Minute*30 && vote.Failed {
+						continue
+					}
+				}
+			}
+
+			ips_voting[ip.Address] = c.voteIp(ip) == nil
+		}
+
+		// lock the node status
+		if err := c.LockNodeStatus(node_id); err != nil {
+			add_error(err)
+			c.UnlockNodeStatus(node_id)
+			continue
+		}
+
+		// get the node status again
+		node_status, err := cache.GetMapField[node](CLUSTER_STATUS_HASH_MAP_KEY, node_id)
+		if err != nil {
+			add_error(err)
+			c.UnlockNodeStatus(node_id)
+			continue
+		}
+
+		// update the node status
+		for i, ip := range node_status.Ips {
+			// update voting time
+			if success, ok := ips_voting[ip.Address]; ok {
+				// check if the ip has already voted
+				already_voted := false
+				for j, vote := range ip.Votes {
+					if vote.NodeID == cluster_id.GetInstanceID() {
+						node_status.Ips[i].Votes[j].VotedAt = time.Now().Unix()
+						node_status.Ips[i].Votes[j].Failed = !success
+						already_voted = true
+						break
+					}
+				}
+				// add a new vote
+				if !already_voted {
+					node_status.Ips[i].Votes = append(node_status.Ips[i].Votes, vote{
+						NodeID:  cluster_id.GetInstanceID(),
+						VotedAt: time.Now().Unix(),
+						Failed:  !success,
+					})
+				}
+			}
+		}
+
+		// sync the node status
+		if err := cache.SetMapOneField(CLUSTER_STATUS_HASH_MAP_KEY, node_id, node_status); err != nil {
+			add_error(err)
+		}
+
+		// unlock the node status
+		if err := c.UnlockNodeStatus(node_id); err != nil {
+			add_error(err)
+		}
+	}
+
+	return total_errors
+}
+
+func (c *Cluster) voteIp(ip ip) error {
+	type healthcheck struct {
+		Status string `json:"status"`
+	}
+
+	healthcheck_endpoint, err := url.JoinPath(fmt.Sprintf("http://%s:%d", ip.Address, c.port), "health/check")
+	if err != nil {
+		return err
+	}
+
+	resp, err := http_requests.GetAndParse[healthcheck](
+		http.DefaultClient,
+		healthcheck_endpoint,
+		http_requests.HttpWriteTimeout(500),
+		http_requests.HttpReadTimeout(500),
+	)
+
+	if err != nil {
+		return err
+	}
+
+	if resp.Status != "ok" {
+		return errors.New("health check failed")
+	}
+
+	return nil
+}
+
+func (c *Cluster) SortIps(node_status node) []ip {
+	// sort by votes
+	sort.Slice(node_status.Ips, func(i, j int) bool {
+		return len(node_status.Ips[i].Votes) > len(node_status.Ips[j].Votes)
+	})
+
+	return node_status.Ips
+}

+ 4 - 0
internal/core/plugin_daemon/backwards_invocation/task_test.go

@@ -16,6 +16,10 @@ func (r *TPluginRuntime) InitEnvironment() error {
 	return nil
 }
 
+func (r *TPluginRuntime) Identity() (string, error) {
+	return "", nil
+}
+
 func (r *TPluginRuntime) StartPlugin() error {
 	return nil
 }

+ 1 - 1
internal/core/plugin_manager/remote_manager/hooks.go

@@ -140,7 +140,7 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 			return
 		}
 
-		runtime.State.TenantID = info.TenantId
+		runtime.tenant_id = info.TenantId
 
 		// handshake completed
 		runtime.handshake = true

+ 1 - 1
internal/core/plugin_manager/remote_manager/server_test.go

@@ -107,7 +107,7 @@ func TestAcceptConnection(t *testing.T) {
 				connection_err = errors.New("plugin name not matched")
 			}
 
-			if runtime.State.TenantID != "test" {
+			if runtime.tenant_id != "test" {
 				connection_err = errors.New("tenant id not matched")
 			}
 

+ 8 - 0
internal/core/plugin_manager/remote_manager/type.go

@@ -1,6 +1,7 @@
 package remote_manager
 
 import (
+	"strings"
 	"sync"
 	"time"
 
@@ -35,9 +36,16 @@ type RemotePluginRuntime struct {
 	// registration transferred
 	registration_transferred bool
 
+	// tenant id
+	tenant_id string
+
 	alive bool
 }
 
+func (r *RemotePluginRuntime) Identity() (string, error) {
+	return strings.Join([]string{r.Configuration().Identity(), r.tenant_id}, ":"), nil
+}
+
 // Listen creates a new listener for the given session_id
 // session id is an unique identifier for a request
 func (r *RemotePluginRuntime) addCallback(session_id string, fn func([]byte)) {

+ 4 - 0
internal/server/server.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"github.com/langgenius/dify-plugin-daemon/internal/cluster"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/db"
 	"github.com/langgenius/dify-plugin-daemon/internal/process"
@@ -21,6 +22,9 @@ func Run(config *app.Config) {
 	// init plugin daemon
 	plugin_manager.Init(config)
 
+	// init cluster
+	cluster.Launch(config)
+
 	// start http server
 	server(config)
 }

+ 6 - 2
internal/types/entities/runtime.go

@@ -19,11 +19,12 @@ type (
 	}
 
 	PluginRuntimeTimeLifeInterface interface {
+		Configuration() *plugin_entities.PluginDeclaration
+		Identity() (string, error)
 		InitEnvironment() error
 		StartPlugin() error
 		Stopped() bool
 		Stop()
-		Configuration() *plugin_entities.PluginDeclaration
 		RuntimeState() *PluginRuntimeState
 		Wait() (<-chan bool, error)
 		Type() PluginRuntimeType
@@ -47,6 +48,10 @@ func (r *PluginRuntime) Configuration() *plugin_entities.PluginDeclaration {
 	return &r.Config
 }
 
+func (r *PluginRuntime) Identity() (string, error) {
+	return r.Config.Identity(), nil
+}
+
 func (r *PluginRuntime) RuntimeState() *PluginRuntimeState {
 	return &r.State
 }
@@ -66,7 +71,6 @@ type PluginRuntimeState struct {
 	ActiveAt     *time.Time `json:"active_at"`
 	StoppedAt    *time.Time `json:"stopped_at"`
 	Verified     bool       `json:"verified"`
-	TenantID     string     `json:"tenant_id"`
 }
 
 const (

+ 12 - 0
internal/utils/cache/redis.go

@@ -55,6 +55,14 @@ func serialKey(keys ...string) string {
 }
 
 func Store(key string, value any, time time.Duration, context ...redis.Cmdable) error {
+	if client == nil {
+		return ErrDBNotInit
+	}
+
+	if _, ok := value.(string); !ok {
+		value = parser.MarshalJson(value)
+	}
+
 	return getCmdable(context...).Set(ctx, serialKey(key), value, time).Err()
 }
 
@@ -164,6 +172,10 @@ func SetMapOneField(key string, field string, value any, context ...redis.Cmdabl
 		return ErrDBNotInit
 	}
 
+	if _, ok := value.(string); !ok {
+		value = parser.MarshalJson(value)
+	}
+
 	return getCmdable(context...).HSet(ctx, serialKey(key), field, value).Err()
 }
 

+ 24 - 0
internal/utils/network/ip.go

@@ -0,0 +1,24 @@
+package network
+
+import "net"
+
+// FetchCurrentIps fetches the current IP addresses of the machine
+// only IPv4 addresses are returned
+func FetchCurrentIps() ([]net.IP, error) {
+	ips := []net.IP{}
+
+	addrs, err := net.InterfaceAddrs()
+	if err != nil {
+		return ips, err
+	}
+
+	for _, addr := range addrs {
+		if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
+			if ipNet.IP.To4() != nil {
+				ips = append(ips, ipNet.IP)
+			}
+		}
+	}
+
+	return ips, nil
+}

+ 19 - 0
internal/utils/network/ip_test.go

@@ -0,0 +1,19 @@
+package network
+
+import (
+	"net"
+	"testing"
+)
+
+func TestIsAllIpsAvailableIPv4(t *testing.T) {
+	ips, err := FetchCurrentIps()
+	if err != nil {
+		t.Error(err)
+	}
+
+	for _, ip := range ips {
+		if net.ParseIP(ip.String()).To4() == nil {
+			t.Errorf("invalid ipv4: %s", ip.String())
+		}
+	}
+}

+ 9 - 0
internal/utils/parser/map.go

@@ -0,0 +1,9 @@
+package parser
+
+func Map[From any, To any](f func(From) To, arr []From) []To {
+	var result []To
+	for _, v := range arr {
+		result = append(result, f(v))
+	}
+	return result
+}