Pārlūkot izejas kodu

feat: simulator for full duplex traffic control

Yeuoly 1 gadu atpakaļ
vecāks
revīzija
f8dfb5e751

+ 170 - 119
internal/core/plugin_manager/aws_manager/full_duplex_simulator.go

@@ -26,30 +26,45 @@ import (
 type FullDuplexSimulator struct {
 	baseurl *url.URL
 
+	// single connection max alive time
+	sending_connection_max_alive_time   time.Duration
+	receiving_connection_max_alive_time time.Duration
+
 	// how many transactions are alive
 	alive_transactions int32
 
 	// total transactions
 	total_transactions int32
 
+	// sent bytes
+	sent_bytes int64
+	// received bytes
+	received_bytes int64
+
 	// sending_connection_timeline_lock
 	sending_connection_timeline_lock sync.Mutex
 	// sending pipeline
 	sending_pipeline *io.PipeWriter
+	// sending pipe lock
+	sending_pipe_lock sync.RWMutex
 
 	// receiving_connection_timeline_lock
 	receiving_connection_timeline_lock sync.Mutex
-	// receiving reader
-	receiving_reader io.ReadCloser
+	// receiving context
+	receiving_cancel context.CancelFunc
+	// receiving context lock
+	receiving_context_lock sync.Mutex
 
 	// max retries
 	max_retries int
 
 	// is sending connection alive
-	sending_connection_alive int32
+	sending_connection_alive         int32
+	virtual_sending_connection_alive int32
 
 	// is receiving connection alive
-	receiving_connection_alive int32
+	receiving_connection_alive         int32
+	virtual_receiving_connection_alive int32
 
 	// listener for data
 	listeners []func(data []byte)
@@ -64,16 +79,22 @@ type FullDuplexSimulator struct {
 	client *http.Client
 }
 
-func NewFullDuplexSimulator(baseurl string) (*FullDuplexSimulator, error) {
+func NewFullDuplexSimulator(
+	baseurl string,
+	sending_connection_max_alive_time time.Duration,
+	receiving_connection_max_alive_time time.Duration,
+) (*FullDuplexSimulator, error) {
 	u, err := url.Parse(baseurl)
 	if err != nil {
 		return nil, err
 	}
 
 	return &FullDuplexSimulator{
-		baseurl:     u,
-		max_retries: 10,
-		request_id:  strings.RandomString(32),
+		baseurl:                             u,
+		sending_connection_max_alive_time:   sending_connection_max_alive_time,
+		receiving_connection_max_alive_time: receiving_connection_max_alive_time,
+		max_retries:                         10,
+		request_id:                          strings.RandomString(32),
 
 		// using keep alive to reduce the connection reset
 		client: &http.Client{
@@ -89,21 +110,41 @@ func NewFullDuplexSimulator(baseurl string) (*FullDuplexSimulator, error) {
 }
 
 // send data to server
-func (s *FullDuplexSimulator) Send(data []byte) error {
-	if atomic.LoadInt32(&s.sending_connection_alive) == 0 {
-		return errors.New("sending connection is not alive")
+func (s *FullDuplexSimulator) Send(data []byte, timeout ...time.Duration) error {
+	timeout_duration := time.Second * 10
+	if len(timeout) > 0 {
+		timeout_duration = timeout[0]
 	}
 
-	writer := s.sending_pipeline
-	if writer == nil {
-		return errors.New("sending pipeline is not alive")
-	}
+	started := time.Now()
 
-	if _, err := writer.Write(data); err != nil {
-		return err
+	for time.Since(started) < timeout_duration {
+		if atomic.LoadInt32(&s.sending_connection_alive) != 1 {
+			time.Sleep(time.Millisecond * 50)
+			continue
+		}
+
+		s.sending_pipe_lock.Lock()
+		writer := s.sending_pipeline
+		if writer == nil {
+			time.Sleep(time.Millisecond * 50)
+			s.sending_pipe_lock.Unlock()
+			continue
+		}
+
+		if n, err := writer.Write(data); err != nil {
+			time.Sleep(time.Millisecond * 50)
+			s.sending_pipe_lock.Unlock()
+			continue
+		} else {
+			atomic.AddInt64(&s.sent_bytes, int64(n))
+		}
+
+		s.sending_pipe_lock.Unlock()
+		return nil
 	}
 
-	return nil
+	return errors.New("send data timeout")
 }
 
 func (s *FullDuplexSimulator) On(f func(data []byte)) {
@@ -155,9 +196,7 @@ func (s *FullDuplexSimulator) startSendingConnection() error {
 		return err
 	}
 
-	pr, pw := io.Pipe()
-
-	req, err := http.NewRequest("POST", u, pr)
+	req, err := http.NewRequest("POST", u, nil)
 	if err != nil {
 		return err
 	}
@@ -170,24 +209,37 @@ func (s *FullDuplexSimulator) startSendingConnection() error {
 		s.sendingConnectionRoutine(req)
 	})
 
-	// mark sending connection as alive
-	atomic.StoreInt32(&s.sending_connection_alive, 1)
-
-	// set the sending pipeline
-	s.sending_pipeline = pw
-
 	return nil
 }
 
 func (s *FullDuplexSimulator) sendingConnectionRoutine(origin_req *http.Request) {
 	failed_times := 0
 	for {
-		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-		time.AfterFunc(5*time.Second, func() {
-			cancel()
+		// this real connection will be closed after single_connection_max_alive_time
+		// but the virtual connection will be established again and again
+		ctx, cancel := context.WithCancel(context.Background())
+		time.AfterFunc(s.sending_connection_max_alive_time, func() {
+			// reached max alive time, remove pipe writer
+			s.sending_pipe_lock.Lock()
+			if s.sending_pipeline != nil {
+				s.sending_pipeline.Close()
+				s.sending_pipeline = nil
+			}
+			s.sending_pipe_lock.Unlock()
+			time.AfterFunc(time.Second, cancel)
 		})
+
 		req := origin_req.Clone(ctx)
+		pr, pw := io.Pipe()
+		s.sending_pipe_lock.Lock()
+		req.Body = pr
+		s.sending_pipeline = pw
+		s.sending_pipe_lock.Unlock()
 		req = req.WithContext(ctx)
+
+		// mark sending connection as alive
+		atomic.StoreInt32(&s.sending_connection_alive, 1)
+
 		resp, err := s.client.Do(req)
 
 		if err != nil {
@@ -199,27 +251,25 @@ func (s *FullDuplexSimulator) sendingConnectionRoutine(origin_req *http.Request)
 			}
 
 			log.Error("failed to establish sending connection: %v", err)
-			continue
+		} else {
+			defer resp.Body.Close()
 		}
 
-		defer resp.Body.Close()
-
 		// mark sending connection as dead
 		atomic.StoreInt32(&s.sending_connection_alive, 0)
 
-		s.sending_connection_timeline_lock.Lock()
-		defer s.sending_connection_timeline_lock.Unlock()
-
+		s.sending_pipe_lock.Lock()
 		// close the sending pipeline
 		if s.sending_pipeline != nil {
 			s.sending_pipeline.Close()
 			s.sending_pipeline = nil
 		}
+		s.sending_pipe_lock.Unlock()
 	}
 }
 
 func (s *FullDuplexSimulator) stopSendingConnection() error {
-	if atomic.LoadInt32(&s.sending_connection_alive) == 0 {
+	if atomic.LoadInt32(&s.virtual_sending_connection_alive) == 0 {
 		return nil
 	}
 
@@ -233,135 +283,136 @@ func (s *FullDuplexSimulator) stopSendingConnection() error {
 	}
 
 	// mark sending connection as dead
-	atomic.StoreInt32(&s.sending_connection_alive, 0)
+	atomic.StoreInt32(&s.virtual_sending_connection_alive, 0)
 
 	return nil
 }
 
 func (s *FullDuplexSimulator) startReceivingConnection() error {
-	if atomic.LoadInt32(&s.receiving_connection_alive) == 1 {
+	if atomic.LoadInt32(&s.virtual_receiving_connection_alive) == 1 {
 		return nil
 	}
 
+	// virtual receiving connection is alive
+	atomic.StoreInt32(&s.virtual_receiving_connection_alive, 1)
+
 	// lock the receiving connection
 	s.receiving_connection_timeline_lock.Lock()
 	defer s.receiving_connection_timeline_lock.Unlock()
 
-	// start a new receiving connection
-	u, err := url.JoinPath(s.baseurl.String(), "/response")
-	if err != nil {
-		return err
-	}
-
-	req, err := http.NewRequest("GET", u, nil)
-	if err != nil {
-		return err
-	}
-	req.Header.Set("Content-Type", "octet-stream")
-	req.Header.Set("Connection", "keep-alive")
-	req.Header.Set("x-dify-plugin-request-id", s.request_id)
-
-	req = req.Clone(context.Background())
-	resp, err := s.client.Do(req)
-	if err != nil {
-		return errors.Join(err, errors.New("failed to establish receiving connection"))
-	}
-
 	routine.Submit(func() {
-		s.receivingConnectionRoutine(req, resp.Body)
+		s.receivingConnectionRoutine()
 	})
 
-	// mark receiving connection as alive
-	atomic.StoreInt32(&s.receiving_connection_alive, 1)
-
 	return nil
 }
 
-func (s *FullDuplexSimulator) receivingConnectionRoutine(req *http.Request, reader io.ReadCloser) {
-	failed_times := 0
-	for {
-		s.receiving_reader = reader
+func (s *FullDuplexSimulator) receivingConnectionRoutine() {
+	// close the virtual receiving connection
+	defer atomic.StoreInt32(&s.virtual_receiving_connection_alive, 0)
+
+	for atomic.LoadInt32(&s.virtual_receiving_connection_alive) == 1 {
 		recved_pong := false
 		buf := make([]byte, 0)
 		buf_len := 0
 
+		// start a new receiving connection
+		u, err := url.JoinPath(s.baseurl.String(), "/response")
+		if err != nil {
+			continue
+		}
+
+		req, err := http.NewRequest("GET", u, nil)
+		if err != nil {
+			continue
+		}
+		req.Header.Set("Content-Type", "octet-stream")
+		req.Header.Set("Connection", "keep-alive")
+		req.Header.Set("x-dify-plugin-request-id", s.request_id)
+
+		ctx, cancel := context.WithCancel(context.Background())
+		req = req.Clone(ctx)
+		resp, err := s.client.Do(req)
+		if err != nil {
+			continue
+		}
+
+		s.receiving_context_lock.Lock()
+		s.receiving_cancel = cancel
+		s.receiving_context_lock.Unlock()
+
+		time.AfterFunc(s.receiving_connection_max_alive_time, func() {
+			cancel()
+			resp.Body.Close()
+		})
+
+		reader := resp.Body
 		for {
 			data := make([]byte, 1024)
 			n, err := reader.Read(data)
-			if err != nil {
-				break
-			}
-
-			// check if pong\n is at the beginning of the data
-			if !recved_pong {
-				data = append(data, buf[:buf_len]...)
-				buf = make([]byte, 0)
-				buf_len = 0
-
-				if n >= 5 {
-					if string(data[:5]) == "pong\n" {
-						recved_pong = true
-						// remove pong\n from the beginning of the data
-						data = data[5:]
-						n -= 5
-					} else {
-						// not pong\n, break
-						break
+			if n != 0 {
+				// check if pong\n is at the beginning of the data
+				if !recved_pong {
+					data = append(buf[:buf_len], data[:n]...)
+					buf = make([]byte, 0)
+					buf_len = 0
+
+					if n >= 5 {
+						if string(data[:5]) == "pong\n" {
+							recved_pong = true
+							// remove pong\n from the beginning of the data
+							data = data[5:]
+							n -= 5
+						} else {
+							// not pong\n, break
+							break
+						}
+					} else if n < 5 {
+						// save the data to the buffer
+						buf = append(buf, data[:n]...)
+						buf_len += n
+						continue
 					}
-				} else if n < 5 {
-					// save the data to the buffer
-					buf = append(buf, data[:n]...)
-					buf_len += n
-					continue
 				}
 			}
 
 			for _, listener := range s.listeners[:] {
 				listener(data[:n])
 			}
-		}
 
-		s.receiving_reader = nil
+			atomic.AddInt64(&s.received_bytes, int64(n))
+
+			if err != nil {
+				break
+			}
+		}
 
 		s.receiving_connection_timeline_lock.Lock()
-		if atomic.LoadInt32(&s.receiving_connection_alive) == 0 {
+		if atomic.LoadInt32(&s.virtual_receiving_connection_alive) == 0 {
 			s.receiving_connection_timeline_lock.Unlock()
 			return
 		}
 		s.receiving_connection_timeline_lock.Unlock()
-
-		req = req.Clone(context.Background())
-		resp, err := s.client.Do(req)
-		if err != nil {
-			failed_times++
-			if failed_times > s.max_retries {
-				log.Error("failed to establish receiving connection: %v", err)
-				s.stopReceivingConnection()
-				return
-			}
-
-			log.Error("failed to establish receiving connection: %v", err)
-			continue
-		}
-
-		reader = resp.Body
 	}
 }
 
 func (s *FullDuplexSimulator) stopReceivingConnection() {
-	if atomic.LoadInt32(&s.receiving_connection_alive) == 0 {
+	if atomic.LoadInt32(&s.virtual_receiving_connection_alive) == 0 {
 		return
 	}
 
 	// mark receiving connection as dead
-	atomic.StoreInt32(&s.receiving_connection_alive, 0)
+	atomic.StoreInt32(&s.virtual_receiving_connection_alive, 0)
 
-	s.receiving_connection_timeline_lock.Lock()
-	defer s.receiving_connection_timeline_lock.Unlock()
-
-	// close the receiving reader
-	reader := s.receiving_reader
-	if reader != nil {
-		reader.Close()
+	// cancel the receiving context
+	s.receiving_context_lock.Lock()
+	if s.receiving_cancel != nil {
+		s.receiving_cancel()
 	}
+	s.receiving_context_lock.Unlock()
+}
+
+// GetStats, returns the sent and received bytes
+func (s *FullDuplexSimulator) GetStats() (sent_bytes, received_bytes int64) {
+	return atomic.LoadInt64(&s.sent_bytes), atomic.LoadInt64(&s.received_bytes)
 }

+ 79 - 8
internal/core/plugin_manager/aws_manager/full_duplex_simulator_test.go

@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"fmt"
 	"net/http"
+	"strings"
 	"sync"
 	"testing"
 	"time"
@@ -13,7 +14,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
 
-func server(timeout time.Duration) (string, func(), error) {
+func server(recv_timeout time.Duration, send_timeout time.Duration) (string, func(), error) {
 	routine.InitPool(1024)
 
 	port, err := network.GetRandomPort()
@@ -24,8 +25,11 @@ func server(timeout time.Duration) (string, func(), error) {
 	data := map[string]chan []byte{}
 	data_mu := sync.Mutex{}
 
+	recved := 0
+
 	eng := gin.New()
 	eng.POST("/invoke", func(c *gin.Context) {
+		// fmt.Println("new send request")
 		id := c.Request.Header.Get("x-dify-plugin-request-id")
 		var ch chan []byte
 
@@ -38,7 +42,7 @@ func server(timeout time.Duration) (string, func(), error) {
 		}
 		data_mu.Unlock()
 
-		time.AfterFunc(timeout, func() {
+		time.AfterFunc(send_timeout, func() {
 			c.Request.Body.Close()
 		})
 
@@ -46,16 +50,25 @@ func server(timeout time.Duration) (string, func(), error) {
 		for {
 			buf := make([]byte, 1024)
 			n, err := c.Request.Body.Read(buf)
+			if n != 0 {
+				recved += n
+				ch <- buf[:n]
+			}
 			if err != nil {
 				break
 			}
-			ch <- buf[:n]
 		}
 
-		c.Status(http.StatusOK)
+		// output closed
+		c.Writer.WriteHeader(http.StatusOK)
+		c.Writer.Write([]byte("closed\n"))
+		c.Writer.Flush()
 	})
 
+	response := 0
+
 	eng.GET("/response", func(ctx *gin.Context) {
+		// fmt.Println("new recv request")
 		id := ctx.Request.Header.Get("x-dify-plugin-request-id")
 		var ch chan []byte
 		data_mu.Lock()
@@ -74,16 +87,19 @@ func server(timeout time.Duration) (string, func(), error) {
 		ctx.Writer.Write([]byte("pong\n"))
 		ctx.Writer.Flush()
 
+		timer := time.NewTimer(recv_timeout)
+
 		for {
 			select {
 			case data := <-ch:
 				ctx.Writer.Write(data)
 				ctx.Writer.Flush()
+				response += len(data)
 			case <-ctx.Done():
 				return
 			case <-ctx.Writer.CloseNotify():
 				return
-			case <-time.After(timeout):
+			case <-timer.C:
 				ctx.Status(http.StatusOK)
 				return
 			}
@@ -101,11 +117,12 @@ func server(timeout time.Duration) (string, func(), error) {
 
 	return fmt.Sprintf("http://localhost:%d", port), func() {
 		srv.Close()
+		fmt.Printf("recved: %d, responsed: %d\n", recved, response)
 	}, nil
 }
 
-func TestFullDuplexSimulator_Send(t *testing.T) {
-	url, cleanup, err := server(time.Second * 100)
+func TestFullDuplexSimulator_SingleSendAndReceive(t *testing.T) {
+	url, cleanup, err := server(time.Second*100, time.Second*100)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -113,7 +130,7 @@ func TestFullDuplexSimulator_Send(t *testing.T) {
 
 	time.Sleep(time.Second)
 
-	simulator, err := NewFullDuplexSimulator(url)
+	simulator, err := NewFullDuplexSimulator(url, time.Second*100, time.Second*100)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -147,3 +164,57 @@ func TestFullDuplexSimulator_Send(t *testing.T) {
 		t.Fatal(fmt.Sprintf("recved: %s", string(recved)))
 	}
 }
+
+func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
+	url, cleanup, err := server(time.Millisecond*700, time.Second*10)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cleanup()
+
+	time.Sleep(time.Second)
+
+	simulator, err := NewFullDuplexSimulator(url, time.Millisecond*700, time.Second*10)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	l := 0
+	recved := strings.Builder{}
+	simulator.On(func(data []byte) {
+		l += len(data)
+		recved.Write(data)
+	})
+
+	done, err := simulator.StartTransaction()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer done()
+
+	ticker := time.NewTicker(time.Millisecond * 1)
+	counter := 0
+
+	for range ticker.C {
+		if err := simulator.Send([]byte(fmt.Sprintf("%05d", counter))); err != nil {
+			t.Fatal(err)
+		}
+		counter++
+		if counter == 3000 {
+			break
+		}
+	}
+
+	time.Sleep(time.Millisecond * 500)
+
+	if l != 3000*5 {
+		sent, received := simulator.GetStats()
+		t.Errorf(fmt.Sprintf("expected: %d, actual: %d, sent: %d, received: %d", 3000*5, l, sent, received))
+		// to find which one is missing
+		for i := 0; i < 3000; i++ {
+			if !strings.Contains(recved.String(), fmt.Sprintf("%05d", i)) {
+				t.Errorf(fmt.Sprintf("missing: %d", i))
+			}
+		}
+	}
+}