|
@@ -0,0 +1,183 @@
|
|
|
+package cluster
|
|
|
+
|
|
|
+import (
|
|
|
+ "context"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "net/http"
|
|
|
+ "sync"
|
|
|
+ "testing"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "github.com/gin-gonic/gin"
|
|
|
+ "github.com/langgenius/dify-plugin-daemon/internal/utils/network"
|
|
|
+)
|
|
|
+
|
|
|
+type SimulationCheckServer struct {
|
|
|
+ http.Server
|
|
|
+
|
|
|
+ port uint16
|
|
|
+}
|
|
|
+
|
|
|
+func createSimulationSevers(nums int, register_callback func(i int, c *gin.Engine)) ([]*SimulationCheckServer, error) {
|
|
|
+ gin.SetMode(gin.ReleaseMode)
|
|
|
+ engines := make([]*gin.Engine, nums)
|
|
|
+ servers := make([]*SimulationCheckServer, nums)
|
|
|
+ for i := 0; i < nums; i++ {
|
|
|
+ engines[i] = gin.Default()
|
|
|
+ register_callback(i, engines[i])
|
|
|
+ }
|
|
|
+
|
|
|
+ // get random port
|
|
|
+ ports := make([]uint16, nums)
|
|
|
+ for i := 0; i < nums; i++ {
|
|
|
+ port, err := network.GetRandomPort()
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+ ports[i] = port
|
|
|
+ }
|
|
|
+
|
|
|
+ for i := 0; i < nums; i++ {
|
|
|
+ srv := &SimulationCheckServer{
|
|
|
+ Server: http.Server{
|
|
|
+ Addr: fmt.Sprintf(":%d", ports[i]),
|
|
|
+ Handler: engines[i],
|
|
|
+ },
|
|
|
+ port: ports[i],
|
|
|
+ }
|
|
|
+ servers[i] = srv
|
|
|
+
|
|
|
+ go func(i int) {
|
|
|
+ srv.ListenAndServe()
|
|
|
+ }(i)
|
|
|
+ }
|
|
|
+
|
|
|
+ return servers, nil
|
|
|
+}
|
|
|
+
|
|
|
+func closeSimulationHealthCheckSevers(servers []*SimulationCheckServer) {
|
|
|
+ for _, server := range servers {
|
|
|
+ server.Shutdown(context.Background())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func TestRedirectTraffic(t *testing.T) {
|
|
|
+ clearClusterState()
|
|
|
+
|
|
|
+ // create 2 nodes cluster
|
|
|
+ cluster, err := createSimulationCluster(2)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ // wait for voting
|
|
|
+ wg := sync.WaitGroup{}
|
|
|
+ wg.Add(len(cluster))
|
|
|
+ // wait for all voting processes complete
|
|
|
+ for _, node := range cluster {
|
|
|
+ node := node
|
|
|
+ go func() {
|
|
|
+ defer wg.Done()
|
|
|
+ <-node.NotifyVotingCompleted()
|
|
|
+ }()
|
|
|
+ }
|
|
|
+
|
|
|
+ node1_recv_reqs := make(chan struct{})
|
|
|
+ node1_recv_correct_reqs := make(chan struct{})
|
|
|
+ defer close(node1_recv_reqs)
|
|
|
+ defer close(node1_recv_correct_reqs)
|
|
|
+
|
|
|
+ // create 2 simulated servers
|
|
|
+ servers, err := createSimulationSevers(2, func(i int, c *gin.Engine) {
|
|
|
+ c.GET("/plugin/invoke/tool", func(c *gin.Context) {
|
|
|
+ if i == 0 {
|
|
|
+ // redirect to node 1
|
|
|
+ status_code, headers, reader, err := cluster[i].RedirectRequest(cluster[1].id, c.Request)
|
|
|
+ if err != nil {
|
|
|
+ c.String(http.StatusInternalServerError, err.Error())
|
|
|
+ return
|
|
|
+ }
|
|
|
+ c.Status(status_code)
|
|
|
+ for k, v := range headers {
|
|
|
+ for _, vv := range v {
|
|
|
+ c.Header(k, vv)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ io.Copy(c.Writer, reader)
|
|
|
+ } else {
|
|
|
+ c.String(http.StatusOK, "ok")
|
|
|
+ node1_recv_reqs <- struct{}{}
|
|
|
+ }
|
|
|
+ })
|
|
|
+ c.GET("/health/check", func(c *gin.Context) {
|
|
|
+ c.JSON(http.StatusOK, gin.H{
|
|
|
+ "status": "ok",
|
|
|
+ })
|
|
|
+ })
|
|
|
+ })
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ defer closeSimulationHealthCheckSevers(servers)
|
|
|
+
|
|
|
+ // change port
|
|
|
+ for i, node := range cluster {
|
|
|
+ node.port = servers[i].port
|
|
|
+ }
|
|
|
+
|
|
|
+ // launch cluster
|
|
|
+ launchSimulationCluster(cluster, t)
|
|
|
+ defer closeSimulationCluster(cluster, t)
|
|
|
+
|
|
|
+ // wait for all nodes to be ready
|
|
|
+ wg.Wait()
|
|
|
+
|
|
|
+ // wait for node status to by synchronized
|
|
|
+ wg = sync.WaitGroup{}
|
|
|
+ wg.Add(len(cluster))
|
|
|
+ // wait for all voting processes complete
|
|
|
+ for _, node := range cluster {
|
|
|
+ node := node
|
|
|
+ go func() {
|
|
|
+ defer wg.Done()
|
|
|
+ <-node.NotifyNodeUpdateCompleted()
|
|
|
+ }()
|
|
|
+ }
|
|
|
+ wg.Wait()
|
|
|
+
|
|
|
+ // request to node 0
|
|
|
+ go func() {
|
|
|
+ for i := 0; i < 10; i++ {
|
|
|
+ resp, err := http.Get(fmt.Sprintf("http://localhost:%d/plugin/invoke/tool", servers[0].port))
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ content, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ t.Error(err)
|
|
|
+ }
|
|
|
+ if string(content) == "ok" {
|
|
|
+ node1_recv_correct_reqs <- struct{}{}
|
|
|
+ }
|
|
|
+ }
|
|
|
+ }()
|
|
|
+
|
|
|
+ // check if node 1 received the request
|
|
|
+ recv_count := 0
|
|
|
+ correct_count := 0
|
|
|
+ for {
|
|
|
+ select {
|
|
|
+ case <-node1_recv_reqs:
|
|
|
+ recv_count++
|
|
|
+ case <-node1_recv_correct_reqs:
|
|
|
+ correct_count++
|
|
|
+ if correct_count == 10 {
|
|
|
+ return
|
|
|
+ }
|
|
|
+ case <-time.After(5 * time.Second):
|
|
|
+ t.Fatal("node 1 did not receive correct requests")
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+}
|