
feat: transaction tests

Yeuoly 11 months ago
commit 9aa40a23dd

+ 49 - 32
internal/core/plugin_manager/aws_manager/full_duplex_simulator.go

@@ -36,6 +36,9 @@ type FullDuplexSimulator struct {
 	// total transactions
 	total_transactions int32
 
+	// connection restarts
+	connection_restarts int32
+
 	// sent bytes
 	sent_bytes int64
 	// received bytes
@@ -58,24 +61,27 @@ type FullDuplexSimulator struct {
 	// max retries
 	max_retries int
 
+	// request id
+	request_id string
+
+	// latest routine id
+	latest_routine_id string
+
 	// is sending connection alive
 	sending_connection_alive         int32
 	sending_routine_lock             sync.Mutex
 	virtual_sending_connection_alive int32
 
+	// receiving routine lock
+	receiving_routine_lock sync.Mutex
 	// is receiving connection alive
-	receiving_connection_alive         int32
-	receiving_routine_lock             sync.Mutex
 	virtual_receiving_connection_alive int32
 
 	// listener for data
 	listeners []func(data []byte)
 
 	// mutex for listeners
-	listeners_mu sync.RWMutex
-
-	// request id
-	request_id string
+	listeners_lock sync.RWMutex
 
 	// http client
 	client *http.Client
@@ -96,6 +102,7 @@ func NewFullDuplexSimulator(
 		sending_connection_max_alive_time:   sending_connection_max_alive_time,
 		receiving_connection_max_alive_time: receiving_connection_max_alive_time,
 		max_retries:                         10,
+		request_id:                          strings.RandomString(32),
 
 		// use keep-alive to reduce connection resets
 		client: &http.Client{
@@ -149,8 +156,8 @@ func (s *FullDuplexSimulator) Send(data []byte, timeout ...time.Duration) error
 }
 
 func (s *FullDuplexSimulator) On(f func(data []byte)) {
-	s.listeners_mu.Lock()
-	defer s.listeners_mu.Unlock()
+	s.listeners_lock.Lock()
+	defer s.listeners_lock.Unlock()
 	s.listeners = append(s.listeners, f)
 }
 
@@ -159,16 +166,22 @@ func (s *FullDuplexSimulator) On(f func(data []byte)) {
 func (s *FullDuplexSimulator) StartTransaction() (func(), error) {
 	// start a transaction
 	if atomic.AddInt32(&s.alive_transactions, 1) == 1 {
+		// increase connection restarts
+		atomic.AddInt32(&s.connection_restarts, 1)
+
 		// generate a new routine id for this connection cycle
-		s.request_id = strings.RandomString(32)
+		routine_id := strings.RandomString(32)
+
+		// update the latest routine id
+		s.latest_routine_id = routine_id
 
 		// start sending connection
-		if err := s.startSendingConnection(); err != nil {
+		if err := s.startSendingConnection(routine_id); err != nil {
 			return nil, err
 		}
 
 		// start receiving connection
-		if err := s.startReceivingConnection(); err != nil {
+		if err := s.startReceivingConnection(routine_id); err != nil {
 			s.stopSendingConnection()
 			return nil, err
 		}
@@ -187,15 +200,12 @@ func (s *FullDuplexSimulator) stopTransaction() {
 	}
 }
 
-func (s *FullDuplexSimulator) startSendingConnection() error {
+func (s *FullDuplexSimulator) startSendingConnection(routine_id string) error {
 	// if virtual sending connection is already alive, do nothing
-	if atomic.LoadInt32(&s.virtual_sending_connection_alive) == 1 {
+	if !atomic.CompareAndSwapInt32(&s.virtual_sending_connection_alive, 0, 1) {
 		return nil
 	}
 
-	// set virtual sending connection as alive
-	atomic.StoreInt32(&s.virtual_sending_connection_alive, 1)
-
 	// lock the sending connection
 	s.sending_connection_timeline_lock.Lock()
 	defer s.sending_connection_timeline_lock.Unlock()
@@ -216,20 +226,26 @@ func (s *FullDuplexSimulator) startSendingConnection() error {
 	req.Header.Set("x-dify-plugin-request-id", s.request_id)
 
 	routine.Submit(func() {
-		s.sendingConnectionRoutine(req)
+		s.sendingConnectionRoutine(req, routine_id)
 	})
 
 	return nil
 }
 
-func (s *FullDuplexSimulator) sendingConnectionRoutine(origin_req *http.Request) {
+func (s *FullDuplexSimulator) sendingConnectionRoutine(origin_req *http.Request, routine_id string) {
 	// lock the sending routine, so that multiple routines do not try to establish the sending connection at once
 	s.sending_routine_lock.Lock()
+
 	// unlock the sending routine when this routine exits
 	defer s.sending_routine_lock.Unlock()
 
 	failed_times := 0
 	for atomic.LoadInt32(&s.virtual_sending_connection_alive) == 1 {
+		// check if the routine id is still the latest one, to avoid a stale routine serving an old request
+		if routine_id != s.latest_routine_id {
+			return
+		}
+
 		ctx, cancel := context.WithCancel(context.Background())
 		time.AfterFunc(s.sending_connection_max_alive_time, func() {
 			// reached max alive time, remove pipe writer
@@ -255,6 +271,8 @@ func (s *FullDuplexSimulator) sendingConnectionRoutine(origin_req *http.Request)
 
 		resp, err := s.client.Do(req)
 		if err != nil {
+			atomic.StoreInt32(&s.sending_connection_alive, 0)
+
 			// if virtual sending connection is not alive, clear the sending pipeline and return
 			if atomic.LoadInt32(&s.virtual_sending_connection_alive) == 0 {
 				// clear the sending pipeline
@@ -293,7 +311,7 @@ func (s *FullDuplexSimulator) sendingConnectionRoutine(origin_req *http.Request)
 }
 
 func (s *FullDuplexSimulator) stopSendingConnection() error {
-	if atomic.LoadInt32(&s.virtual_sending_connection_alive) == 0 {
+	if !atomic.CompareAndSwapInt32(&s.virtual_sending_connection_alive, 1, 0) {
 		return nil
 	}
 
@@ -315,33 +333,35 @@ func (s *FullDuplexSimulator) stopSendingConnection() error {
 	return nil
 }
 
-func (s *FullDuplexSimulator) startReceivingConnection() error {
+func (s *FullDuplexSimulator) startReceivingConnection(routine_id string) error {
 	// if virtual receiving connection is already alive, do nothing
-	if atomic.LoadInt32(&s.virtual_receiving_connection_alive) == 1 {
+	if !atomic.CompareAndSwapInt32(&s.virtual_receiving_connection_alive, 0, 1) {
 		return nil
 	}
 
-	// set virtual receiving connection as alive
-	atomic.StoreInt32(&s.virtual_receiving_connection_alive, 1)
-
 	// lock the receiving connection
 	s.receiving_connection_timeline_lock.Lock()
 	defer s.receiving_connection_timeline_lock.Unlock()
 
 	routine.Submit(func() {
-		s.receivingConnectionRoutine()
+		s.receivingConnectionRoutine(routine_id)
 	})
 
 	return nil
 }
 
-func (s *FullDuplexSimulator) receivingConnectionRoutine() {
+func (s *FullDuplexSimulator) receivingConnectionRoutine(routine_id string) {
 	// lock the receiving routine, so that multiple routines do not try to establish the receiving connection at once
 	s.receiving_routine_lock.Lock()
 	// unlock the receiving routine when this routine exits
 	defer s.receiving_routine_lock.Unlock()
 
 	for atomic.LoadInt32(&s.virtual_receiving_connection_alive) == 1 {
+		// check if the routine id is still the latest one, to avoid a stale routine serving an old request
+		if routine_id != s.latest_routine_id {
+			return
+		}
+
 		recved_pong := false
 		buf := make([]byte, 0)
 		buf_len := 0
@@ -420,13 +440,10 @@ func (s *FullDuplexSimulator) receivingConnectionRoutine() {
 }
 
 func (s *FullDuplexSimulator) stopReceivingConnection() {
-	if atomic.LoadInt32(&s.virtual_receiving_connection_alive) == 0 {
+	if !atomic.CompareAndSwapInt32(&s.virtual_receiving_connection_alive, 1, 0) {
 		return
 	}
 
-	// mark receiving connection as dead
-	atomic.StoreInt32(&s.virtual_receiving_connection_alive, 0)
-
 	// cancel the receiving context
 	s.receiving_cancel_lock.Lock()
 	if s.receiving_cancel != nil {
@@ -436,6 +453,6 @@ func (s *FullDuplexSimulator) stopReceivingConnection() {
 }
 
 // GetStats returns the sent and received byte counts and the number of connection restarts
-func (s *FullDuplexSimulator) GetStats() (sent_bytes, received_bytes int64) {
-	return atomic.LoadInt64(&s.sent_bytes), atomic.LoadInt64(&s.received_bytes)
+func (s *FullDuplexSimulator) GetStats() (sent_bytes, received_bytes int64, connection_restarts int32) {
+	return atomic.LoadInt64(&s.sent_bytes), atomic.LoadInt64(&s.received_bytes), atomic.LoadInt32(&s.connection_restarts)
 }
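
The CompareAndSwap rewrites above close a check-then-act race: with a separate Load and Store, two goroutines can both observe the flag as 0 and both start a connection. A minimal, self-contained sketch of the difference (not the daemon's code; alive, startRacy, and startCAS are illustrative names):

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

var alive int32

// startRacy mirrors the old Load-then-Store pattern: the check and the
// update are separate steps, so two goroutines can both pass the check
// and both "start" the connection.
func startRacy() bool {
	if atomic.LoadInt32(&alive) == 1 {
		return false
	}
	atomic.StoreInt32(&alive, 1) // another goroutine may already be here
	return true
}

// startCAS mirrors the new pattern: the 0 -> 1 transition is a single
// atomic step, so at most one caller wins per start/stop cycle.
func startCAS() bool {
	return atomic.CompareAndSwapInt32(&alive, 0, 1)
}

func main() {
	var started int32
	var wg sync.WaitGroup
	for i := 0; i < 100; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			if startCAS() { // swap in startRacy to observe duplicate starts
				atomic.AddInt32(&started, 1)
			}
		}()
	}
	wg.Wait()
	fmt.Println("winners:", started) // always 1 with startCAS
}

The mirrored 1 -> 0 swap in stopSendingConnection and stopReceivingConnection gives the same guarantee on teardown: the cleanup path runs at most once per started connection.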

+ 201 - 40
internal/core/plugin_manager/aws_manager/full_duplex_simulator_test.go

@@ -4,48 +4,89 @@ import (
 	"bytes"
 	"fmt"
 	"net/http"
+	"strconv"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/debugging"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/network"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
 
-func server(recv_timeout time.Duration, send_timeout time.Duration) (string, func(), error) {
+func init() {
 	routine.InitPool(1024)
+}
+
+type S struct {
+	srv  *http.Server
+	url  string
+	port int
+
+	send_count          int32
+	recv_buffered_count int32
+	recv_count          int32
+
+	send_request int32
+	recv_request int32
+
+	data_mu sync.Mutex
+	data    map[string]chan []byte
+
+	current_recv_request_id string
+	current_send_request_id string
+}
+
+func (s *S) Stop() {
+	s.srv.Close()
+}
 
+func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error) {
 	port, err := network.GetRandomPort()
 	if err != nil {
-		return "", nil, err
+		return nil, err
 	}
 
-	data := map[string]chan []byte{}
-	data_mu := sync.Mutex{}
+	eng := gin.New()
+	srv := &http.Server{
+		Addr:    fmt.Sprintf(":%d", port),
+		Handler: eng,
+	}
 
-	recved := 0
+	s := &S{
+		srv:  srv,
+		url:  fmt.Sprintf("http://localhost:%d", port),
+		data: make(map[string]chan []byte),
 
-	eng := gin.New()
+		send_count: 0,
+		recv_count: 0,
+	}
 
 	// silence gin's debug logging
 	gin.SetMode(gin.ReleaseMode)
 
 	eng.POST("/invoke", func(c *gin.Context) {
+		atomic.AddInt32(&s.send_request, 1)
+		defer atomic.AddInt32(&s.send_request, -1)
+
 		// fmt.Println("new send request")
 		id := c.Request.Header.Get("x-dify-plugin-request-id")
+		s.current_send_request_id = id
+
 		var ch chan []byte
 
-		data_mu.Lock()
-		if _, ok := data[id]; !ok {
-			ch = make(chan []byte, 1024)
-			data[id] = ch
+		s.data_mu.Lock()
+		if _, ok := s.data[id]; !ok {
+			ch = make(chan []byte)
+			s.data[id] = ch
 		} else {
-			ch = data[id]
+			ch = s.data[id]
 		}
-		data_mu.Unlock()
+		s.data_mu.Unlock()
 
 		time.AfterFunc(send_timeout, func() {
 			c.Request.Body.Close()
@@ -56,8 +97,9 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (string, fun
 			buf := make([]byte, 1024)
 			n, err := c.Request.Body.Read(buf)
 			if n != 0 {
-				recved += n
+				atomic.AddInt32(&s.recv_buffered_count, int32(n))
 				ch <- buf[:n]
+				atomic.AddInt32(&s.recv_count, int32(n))
 			}
 			if err != nil {
 				break
@@ -70,20 +112,23 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (string, fun
 		c.Writer.Flush()
 	})
 
-	response := 0
-
 	eng.GET("/response", func(ctx *gin.Context) {
+		atomic.AddInt32(&s.recv_request, 1)
+		defer atomic.AddInt32(&s.recv_request, -1)
+
 		// fmt.Println("new recv request")
 		id := ctx.Request.Header.Get("x-dify-plugin-request-id")
+		s.current_recv_request_id = id
+
 		var ch chan []byte
-		data_mu.Lock()
-		if _, ok := data[id]; ok {
-			ch = data[id]
+		s.data_mu.Lock()
+		if _, ok := s.data[id]; ok {
+			ch = s.data[id]
 		} else {
-			ch = make(chan []byte, 1024)
-			data[id] = ch
+			ch = make(chan []byte)
+			s.data[id] = ch
 		}
-		data_mu.Unlock()
+		s.data_mu.Unlock()
 
 		ctx.Writer.WriteHeader(http.StatusOK)
 		ctx.Writer.Header().Set("Content-Type", "application/octet-stream")
@@ -99,7 +144,7 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (string, fun
 			case data := <-ch:
 				ctx.Writer.Write(data)
 				ctx.Writer.Flush()
-				response += len(data)
+				atomic.AddInt32(&s.send_count, int32(len(data)))
 			case <-ctx.Done():
 				return
 			case <-ctx.Writer.CloseNotify():
@@ -111,34 +156,26 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (string, fun
 		}
 	})
 
-	srv := &http.Server{
-		Addr:    fmt.Sprintf(":%d", port),
-		Handler: eng,
-	}
-
 	go func() {
 		srv.ListenAndServe()
 	}()
 
-	return fmt.Sprintf("http://localhost:%d", port), func() {
-		srv.Close()
-		// fmt.Printf("recved: %d, responsed: %d\n", recved, response)
-	}, nil
+	return s, nil
 }
 
 func TestFullDuplexSimulator_SingleSendAndReceive(t *testing.T) {
 	log.SetShowLog(false)
 	defer log.SetShowLog(true)
 
-	url, cleanup, err := server(time.Second*100, time.Second*100)
+	srv, err := server(time.Second*100, time.Second*100)
 	if err != nil {
 		t.Fatal(err)
 	}
-	defer cleanup()
+	defer srv.Stop()
 
 	time.Sleep(time.Second)
 
-	simulator, err := NewFullDuplexSimulator(url, time.Second*100, time.Second*100)
+	simulator, err := NewFullDuplexSimulator(srv.url, time.Second*100, time.Second*100)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -178,22 +215,22 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
 	defer log.SetShowLog(true)
 
 	// hmmm, to ensure the server is stable, we need to run the test 100 times
-	// don't ask me why, just trust me, I have spent 1 days to handle this race condition
+	// don't ask me why, just trust me, I spent a day getting this race condition handled correctly
 	wg := sync.WaitGroup{}
 	wg.Add(100)
 	for i := 0; i < 100; i++ {
 		go func() {
 			defer wg.Done()
 
-			url, cleanup, err := server(time.Millisecond*700, time.Second*10)
+			srv, err := server(time.Millisecond*700, time.Second*10)
 			if err != nil {
 				t.Fatal(err)
 			}
-			defer cleanup()
+			defer srv.Stop()
 
 			time.Sleep(time.Second)
 
-			simulator, err := NewFullDuplexSimulator(url, time.Millisecond*700, time.Second*10)
+			simulator, err := NewFullDuplexSimulator(srv.url, time.Millisecond*700, time.Second*10)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -227,8 +264,8 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
 			time.Sleep(time.Millisecond * 500)
 
 			if l != 3000*5 {
-				sent, received := simulator.GetStats()
-				t.Errorf(fmt.Sprintf("expected: %d, actual: %d, sent: %d, received: %d", 3000*5, l, sent, received))
+				sent, received, restarts := simulator.GetStats()
+				t.Errorf("expected: %d, actual: %d, sent: %d, received: %d, restarts: %d", 3000*5, l, sent, received, restarts)
 				// to find which one is missing
 				for i := 0; i < 3000; i++ {
 					if !strings.Contains(recved.String(), fmt.Sprintf("%05d", i)) {
@@ -241,3 +278,127 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
 
 	wg.Wait()
 }
+
+func TestFullDuplexSimulator_MultipleTransactions(t *testing.T) {
+	log.SetShowLog(false)
+	defer log.SetShowLog(true)
+
+	// avoid spawning too many test cases: each one creates many goroutines,
+	// and an overloaded OS cannot handle the requests in time
+	const NUM_CASES = 30
+
+	w := sync.WaitGroup{}
+	w.Add(NUM_CASES)
+
+	for j := 0; j < NUM_CASES; j++ {
+		// j := j
+		go func() {
+			defer w.Done()
+
+			srv, err := server(time.Millisecond*700, time.Second*10)
+			if err != nil {
+				t.Fatal(err)
+			}
+			defer srv.Stop()
+
+			time.Sleep(time.Second)
+
+			simulator, err := NewFullDuplexSimulator(srv.url, time.Millisecond*700, time.Second*10)
+			if err != nil {
+				t.Fatal(err)
+			}
+
+			l := int32(0)
+
+			dones := make(map[int]func())
+			dones_lock := sync.Mutex{}
+
+			buf := bytes.Buffer{}
+			simulator.On(func(data []byte) {
+				debugging.PossibleBlocking(
+					func() any {
+						atomic.AddInt32(&l, int32(len(data)))
+
+						buf.Write(data)
+
+						// note: buf.Bytes() aliases the buffer's storage, so the
+						// trailing remainder is copied before being written back
+						frames := buf.Bytes()
+						buf.Reset()
+
+						i := 0
+						for i+5 <= len(frames) {
+							num, err := strconv.Atoi(string(frames[i : i+5]))
+							if err != nil {
+								t.Fatalf("invalid data: %s", string(frames))
+							}
+
+							dones_lock.Lock()
+
+							if done, ok := dones[num]; ok {
+								done()
+							} else {
+								t.Fatalf("done not found: %d", num)
+							}
+
+							dones_lock.Unlock()
+
+							i += 5
+						}
+
+						if len(frames) != i {
+							// keep the incomplete trailing frame for the next callback
+							b := make([]byte, len(frames)-i)
+							copy(b, frames[i:])
+							buf.Write(b)
+						}
+
+						return nil
+					},
+					time.Second*1,
+					func() {
+						t.Fatal("possible blocking triggered")
+					},
+				)
+			})
+
+			wg := sync.WaitGroup{}
+			wg.Add(100)
+
+			for i := 0; i < 100; i++ {
+				i := i
+				time.Sleep(time.Millisecond * 20)
+				go func() {
+					done, err := simulator.StartTransaction()
+					if err != nil {
+						t.Fatal(err)
+					}
+
+					dones_lock.Lock()
+					dones[i] = func() {
+						done()
+						wg.Done()
+					}
+					dones_lock.Unlock()
+
+					if err := simulator.Send([]byte(fmt.Sprintf("%05d", i))); err != nil {
+						t.Fatal(err)
+					}
+				}()
+			}
+
+			// time.AfterFunc(time.Second*5, func() {
+			// 	// fmt.Println("server recv count: ", srv.recv_count, "server send count: ", srv.send_count, "j: ", j,
+			// 	// 	"server recv request: ", srv.recv_request, "server send request: ", srv.send_request)
+			// 	// fmt.Println("current_recv_request_id: ", srv.current_recv_request_id, "current_send_request_id: ", srv.current_send_request_id)
+			// })
+
+			wg.Wait()
+
+			if l != 100*5 {
+				sent, received, restarts := simulator.GetStats()
+				t.Errorf("expected: %d, actual: %d, sent: %d, received: %d, restarts: %d", 100*5, l, sent, received, restarts)
+			}
+		}()
+	}
+
+	w.Wait()
+}
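
The new test drives the simulator's public surface in a fixed order: On registers a listener, StartTransaction returns a done handle (the first live transaction opens both connections, the last done tears them down), and Send writes a fixed-width "%05d" frame so the listener can reassemble the stream in 5-byte chunks. A hedged sketch of that flow, assuming it lives in the same package as the simulator; the error handling is illustrative:

func exampleTransaction(simulator *FullDuplexSimulator) error {
	// listeners see the raw byte stream; the tests reassemble it
	// into fixed-width 5-byte frames
	simulator.On(func(data []byte) {
		fmt.Printf("received %d bytes\n", len(data))
	})

	done, err := simulator.StartTransaction()
	if err != nil {
		return err
	}
	// releasing the last live transaction tears the connections down
	defer done()

	return simulator.Send([]byte(fmt.Sprintf("%05d", 42)))
}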

+ 27 - 0
internal/utils/debugging/wait.go

@@ -0,0 +1,27 @@
+package debugging
+
+import (
+	"time"
+)
+
+// PossibleBlocking runs the function f in a goroutine and returns its result.
+// If f has not returned within timeout, trigger is called once and the call
+// then keeps waiting for f to finish.
+func PossibleBlocking[T any](f func() T, timeout time.Duration, trigger func()) T {
+	d := make(chan T)
+
+	go func() {
+		d <- f()
+	}()
+
+	timer := time.NewTimer(timeout)
+	defer timer.Stop()
+
+	for {
+		select {
+		case <-timer.C:
+			trigger()
+		case v := <-d:
+			return v
+		}
+	}
+}
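
A usage sketch for the helper: wrap a receive that may never complete, get notified once when it overruns the budget, and still collect the eventual result. The channel and durations below are hypothetical:

package main

import (
	"fmt"
	"time"

	"github.com/langgenius/dify-plugin-daemon/internal/utils/debugging"
)

func main() {
	ch := make(chan string) // nothing ever sends on this channel

	got := debugging.PossibleBlocking(func() string {
		select {
		case v := <-ch:
			return v
		case <-time.After(2 * time.Second): // give up eventually
			return "timed out"
		}
	}, 500*time.Millisecond, func() {
		fmt.Println("still waiting after 500ms") // fires once
	})

	fmt.Println("result:", got)
}

If f truly never returns, the inner goroutine leaks; that trade-off is acceptable for a test-only diagnostic.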

+ 35 - 0
internal/utils/debugging/wait_test.go

@@ -0,0 +1,35 @@
+package debugging
+
+import (
+	"testing"
+	"time"
+)
+
+func TestPossibleBlocking(t *testing.T) {
+	triggered := false
+
+	PossibleBlocking(func() any {
+		time.Sleep(time.Second * 1)
+		return nil
+	}, time.Millisecond*500, func() {
+		triggered = true
+	})
+
+	if !triggered {
+		t.Fatal("possible blocking not triggered")
+	}
+}
+
+func TestPossibleBlocking_Blocking(t *testing.T) {
+	triggered := false
+
+	PossibleBlocking(func() any {
+		return nil
+	}, time.Second*1, func() {
+		triggered = true
+	})
+
+	if triggered {
+		t.Fatal("possible blocking triggered")
+	}
+}

+ 14 - 0
internal/utils/routine/pool.go

@@ -1,15 +1,29 @@
 package routine
 
 import (
+	"sync"
+
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"github.com/panjf2000/ants"
 )
 
 var (
 	p *ants.Pool
+	l sync.Mutex
 )
 
+func IsInit() bool {
+	l.Lock()
+	defer l.Unlock()
+	return p != nil
+}
+
 func InitPool(size int) {
+	l.Lock()
+	defer l.Unlock()
+	if p != nil {
+		return
+	}
 	log.Info("init routine pool, size: %d", size)
 	p, _ = ants.NewPool(size, ants.WithNonblocking(false))
 }
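
With the lock and nil check, InitPool is now idempotent, so each test file can call it from its own init function without clobbering a pool another package already created. A small sketch of concurrent initialization, assuming module-internal access to the routine package (internal import paths are not importable from outside the module):

package main

import (
	"sync"

	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
)

func main() {
	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			routine.InitPool(1024) // only the first call creates the pool
		}()
	}
	wg.Wait()

	if !routine.IsInit() {
		panic("pool should be initialized")
	}
}

sync.Once would express the same guarantee; the explicit mutex additionally backs IsInit, which callers can poll.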