Yeuoly пре 11 месеци
родитељ
комит
6b554a40f7

+ 8 - 0
internal/cluster/basic_test.go

@@ -0,0 +1,8 @@
+package cluster
+
+import "github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+
+func clearClusterState() {
+	cache.Del(CLUSTER_STATUS_HASH_MAP_KEY)
+	cache.Del(PREEMPTION_LOCK_KEY)
+}

+ 183 - 0
internal/cluster/redirect_test.go

@@ -0,0 +1,183 @@
+package cluster
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/network"
+)
+
+type SimulationCheckServer struct {
+	http.Server
+
+	port uint16
+}
+
+func createSimulationSevers(nums int, register_callback func(i int, c *gin.Engine)) ([]*SimulationCheckServer, error) {
+	gin.SetMode(gin.ReleaseMode)
+	engines := make([]*gin.Engine, nums)
+	servers := make([]*SimulationCheckServer, nums)
+	for i := 0; i < nums; i++ {
+		engines[i] = gin.Default()
+		register_callback(i, engines[i])
+	}
+
+	// get random port
+	ports := make([]uint16, nums)
+	for i := 0; i < nums; i++ {
+		port, err := network.GetRandomPort()
+		if err != nil {
+			return nil, err
+		}
+		ports[i] = port
+	}
+
+	for i := 0; i < nums; i++ {
+		srv := &SimulationCheckServer{
+			Server: http.Server{
+				Addr:    fmt.Sprintf(":%d", ports[i]),
+				Handler: engines[i],
+			},
+			port: ports[i],
+		}
+		servers[i] = srv
+
+		go func(i int) {
+			srv.ListenAndServe()
+		}(i)
+	}
+
+	return servers, nil
+}
+
+func closeSimulationHealthCheckSevers(servers []*SimulationCheckServer) {
+	for _, server := range servers {
+		server.Shutdown(context.Background())
+	}
+}
+
+func TestRedirectTraffic(t *testing.T) {
+	clearClusterState()
+
+	// create 2 nodes cluster
+	cluster, err := createSimulationCluster(2)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// wait for voting
+	wg := sync.WaitGroup{}
+	wg.Add(len(cluster))
+	// wait for all voting processes complete
+	for _, node := range cluster {
+		node := node
+		go func() {
+			defer wg.Done()
+			<-node.NotifyVotingCompleted()
+		}()
+	}
+
+	node1_recv_reqs := make(chan struct{})
+	node1_recv_correct_reqs := make(chan struct{})
+	defer close(node1_recv_reqs)
+	defer close(node1_recv_correct_reqs)
+
+	// create 2 simulated servers
+	servers, err := createSimulationSevers(2, func(i int, c *gin.Engine) {
+		c.GET("/plugin/invoke/tool", func(c *gin.Context) {
+			if i == 0 {
+				// redirect to node 1
+				status_code, headers, reader, err := cluster[i].RedirectRequest(cluster[1].id, c.Request)
+				if err != nil {
+					c.String(http.StatusInternalServerError, err.Error())
+					return
+				}
+				c.Status(status_code)
+				for k, v := range headers {
+					for _, vv := range v {
+						c.Header(k, vv)
+					}
+				}
+				io.Copy(c.Writer, reader)
+			} else {
+				c.String(http.StatusOK, "ok")
+				node1_recv_reqs <- struct{}{}
+			}
+		})
+		c.GET("/health/check", func(c *gin.Context) {
+			c.JSON(http.StatusOK, gin.H{
+				"status": "ok",
+			})
+		})
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer closeSimulationHealthCheckSevers(servers)
+
+	// change port
+	for i, node := range cluster {
+		node.port = servers[i].port
+	}
+
+	// launch cluster
+	launchSimulationCluster(cluster, t)
+	defer closeSimulationCluster(cluster, t)
+
+	// wait for all nodes to be ready
+	wg.Wait()
+
+	// wait for node status to by synchronized
+	wg = sync.WaitGroup{}
+	wg.Add(len(cluster))
+	// wait for all voting processes complete
+	for _, node := range cluster {
+		node := node
+		go func() {
+			defer wg.Done()
+			<-node.NotifyNodeUpdateCompleted()
+		}()
+	}
+	wg.Wait()
+
+	// request to node 0
+	go func() {
+		for i := 0; i < 10; i++ {
+			resp, err := http.Get(fmt.Sprintf("http://localhost:%d/plugin/invoke/tool", servers[0].port))
+			if err != nil {
+				t.Error(err)
+			}
+			content, err := io.ReadAll(resp.Body)
+			if err != nil {
+				t.Error(err)
+			}
+			if string(content) == "ok" {
+				node1_recv_correct_reqs <- struct{}{}
+			}
+		}
+	}()
+
+	// check if node 1 received the request
+	recv_count := 0
+	correct_count := 0
+	for {
+		select {
+		case <-node1_recv_reqs:
+			recv_count++
+		case <-node1_recv_correct_reqs:
+			correct_count++
+			if correct_count == 10 {
+				return
+			}
+		case <-time.After(5 * time.Second):
+			t.Fatal("node 1 did not receive correct requests")
+		}
+	}
+
+}

+ 1 - 1
internal/cluster/vote_test.go

@@ -66,7 +66,7 @@ func TestVoteAddresses(t *testing.T) {
 	// wait for all nodes to be ready
 	wg.Wait()
 
-	// wait for all ips to be voted
+	// wait for all addresses to be voted
 	time.Sleep(time.Second)
 
 	for _, node := range cluster {

+ 5 - 2
internal/core/plugin_manager/manager.go

@@ -76,6 +76,9 @@ func (p *PluginManager) Init(configuration *app.Config) {
 		log.Panic("init dify invocation daemon failed: %s", err.Error())
 	}
 
-	// start plugin watcher
-	p.startWatcher(configuration)
+	// start local watcher
+	p.startLocalWatcher(configuration)
+
+	// start remote watcher
+	p.startRemoteWatcher(configuration)
 }

+ 20 - 0
internal/core/plugin_manager/remote_manager/cs.go

@@ -0,0 +1,20 @@
+package remote_manager
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"encoding/binary"
+	"encoding/hex"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+)
+
+func (m *RemotePluginRuntime) calculateChecksum() string {
+	configuration := m.Configuration()
+	// calculate using sha256
+	buffer := bytes.Buffer{}
+	binary.Write(&buffer, binary.BigEndian, parser.MarshalJsonBytes(configuration))
+	hash := sha256.New()
+	hash.Write(append(buffer.Bytes(), []byte(m.tenant_id)...))
+	return hex.EncodeToString(hash.Sum(nil))
+}

+ 1 - 1
internal/core/plugin_manager/remote_manager/hooks.go

@@ -159,7 +159,7 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 		// registration transferred
 		runtime.registration_transferred = true
 
-		runtime.InitState()
+		runtime.InitState(runtime.calculateChecksum())
 		runtime.SetActiveAt(time.Now())
 
 		// publish runtime to watcher

+ 1 - 3
internal/core/plugin_manager/watcher.go

@@ -15,7 +15,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
 
-func (p *PluginManager) startWatcher(config *app.Config) {
+func (p *PluginManager) startLocalWatcher(config *app.Config) {
 	go func() {
 		log.Info("start to handle new plugins in path: %s", config.StoragePath)
 		p.handleNewPlugins(config)
@@ -23,8 +23,6 @@ func (p *PluginManager) startWatcher(config *app.Config) {
 			p.handleNewPlugins(config)
 		}
 	}()
-
-	p.startRemoteWatcher(config)
 }
 
 func (p *PluginManager) startRemoteWatcher(config *app.Config) {

+ 4 - 8
internal/types/entities/runtime.go

@@ -3,14 +3,12 @@ package entities
 import (
 	"bytes"
 	"crypto/sha256"
-	"encoding/binary"
 	"encoding/gob"
 	"encoding/hex"
 	"hash/fnv"
 	"time"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
-	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 )
 
 type (
@@ -111,7 +109,7 @@ func (r *PluginRuntime) UpdateScheduledAt(t time.Time) {
 	r.State.ScheduledAt = &t
 }
 
-func (r *PluginRuntime) InitState() {
+func (r *PluginRuntime) InitState(checksum string) {
 	r.State = PluginRuntimeState{
 		Restarts:    0,
 		Status:      PLUGIN_RUNTIME_STATUS_PENDING,
@@ -120,6 +118,7 @@ func (r *PluginRuntime) InitState() {
 		Verified:    false,
 		ScheduledAt: nil,
 		Logs:        []string{},
+		Checksum:    checksum,
 	}
 }
 
@@ -152,11 +151,7 @@ func (r *PluginRuntime) AddRestarts() {
 }
 
 func (r *PluginRuntime) Checksum() string {
-	buf := bytes.Buffer{}
-	binary.Write(&buf, binary.BigEndian, parser.MarshalJsonBytes(r.Config))
-	hash := sha256.New()
-	hash.Write(buf.Bytes())
-	return hex.EncodeToString(hash.Sum(nil))
+	return r.State.Checksum
 }
 
 func (r *PluginRuntime) OnStop(f func()) {
@@ -186,6 +181,7 @@ type PluginRuntimeState struct {
 	Verified     bool       `json:"verified"`
 	ScheduledAt  *time.Time `json:"scheduled_at"`
 	Logs         []string   `json:"logs"`
+	Checksum     string     `json:"checksum"`
 }
 
 func (s *PluginRuntimeState) Hash() (uint64, error) {