Преглед на файлове

refactor: control timeout

Yeuoly преди 1 година
родител
ревизия
578f2af8ac

+ 132 - 15
internal/core/plugin_manager/aws_manager/full_duplex_simulator.go

@@ -3,6 +3,7 @@ package aws_manager
 import (
 	"context"
 	"errors"
+	"fmt"
 	"io"
 	"net"
 	"net/http"
@@ -27,8 +28,10 @@ type FullDuplexSimulator struct {
 	baseurl *url.URL
 
 	// single connection max alive time
-	sending_connection_max_alive_time   time.Duration
-	receiving_connection_max_alive_time time.Duration
+	sending_connection_max_alive_time          time.Duration
+	receiving_connection_max_alive_time        time.Duration
+	target_sending_connection_max_alive_time   time.Duration
+	target_receiving_connection_max_alive_time time.Duration
 
 	// how many transactions are alive
 	alive_transactions int32
@@ -40,7 +43,8 @@ type FullDuplexSimulator struct {
 	connection_restarts int32
 
 	// sent bytes
-	sent_bytes int64
+	sent_bytes                 int64
+	current_request_sent_bytes int32
 	// received bytes
 	received_bytes int64
 
@@ -60,6 +64,10 @@ type FullDuplexSimulator struct {
 
 	// max retries
 	max_retries int
+	// max sending single request sending bytes
+	max_sending_bytes int32
+	// max receiving single request receiving bytes
+	max_receiving_bytes int32
 
 	// request id
 	request_id string
@@ -70,6 +78,7 @@ type FullDuplexSimulator struct {
 	// is sending connection alive
 	sending_connection_alive         int32
 	sending_routine_lock             sync.Mutex
+	sending_lock                     sync.Mutex
 	virtual_sending_connection_alive int32
 
 	// receiving routine lock
@@ -87,22 +96,92 @@ type FullDuplexSimulator struct {
 	client *http.Client
 }
 
+type FullDuplexSimulatorOption struct {
+	// MaxRetries, default 10
+	MaxRetries int
+	// SendingConnectionMaxAliveTime, default 60s
+	SendingConnectionMaxAliveTime time.Duration
+	// TargetSendingConnectionMaxAliveTime, default 80s
+	TargetSendingConnectionMaxAliveTime time.Duration
+	// ReceivingConnectionMaxAliveTime, default 80s
+	ReceivingConnectionMaxAliveTime time.Duration
+	// TargetReceivingConnectionMaxAliveTime, default 60s
+	TargetReceivingConnectionMaxAliveTime time.Duration
+	// MaxSingleRequestSendingBytes, default 5 * 1024 * 1024
+	MaxSingleRequestSendingBytes int32
+	// MaxSingleRequestReceivingBytes, default 5 * 1024 * 1024
+	MaxSingleRequestReceivingBytes int32
+}
+
+func (opt *FullDuplexSimulatorOption) defaultOption() error {
+	if opt.MaxRetries == 0 {
+		opt.MaxRetries = 10
+	}
+
+	if opt.SendingConnectionMaxAliveTime == 0 {
+		opt.SendingConnectionMaxAliveTime = 60 * time.Second
+	}
+
+	if opt.ReceivingConnectionMaxAliveTime == 0 {
+		opt.ReceivingConnectionMaxAliveTime = 80 * time.Second
+	}
+
+	if opt.TargetSendingConnectionMaxAliveTime == 0 {
+		opt.TargetSendingConnectionMaxAliveTime = 80 * time.Second
+	}
+
+	if opt.TargetReceivingConnectionMaxAliveTime == 0 {
+		opt.TargetReceivingConnectionMaxAliveTime = 60 * time.Second
+	}
+
+	if opt.MaxSingleRequestSendingBytes == 0 {
+		opt.MaxSingleRequestSendingBytes = 5 * 1024 * 1024
+	}
+
+	if opt.MaxSingleRequestReceivingBytes == 0 {
+		opt.MaxSingleRequestReceivingBytes = 5 * 1024 * 1024
+	}
+
+	// target receiving connection max alive time should be larger than receiving connection max alive time
+	if opt.TargetReceivingConnectionMaxAliveTime < opt.ReceivingConnectionMaxAliveTime {
+		return errors.New("target receiving connection max alive time should be larger than receiving connection max alive time")
+	}
+
+	// sending connection max alive time should be larger than target sending connection max alive time
+	if opt.SendingConnectionMaxAliveTime < opt.TargetSendingConnectionMaxAliveTime {
+		return errors.New("sending connection max alive time should be larger than target sending connection max alive time")
+	}
+
+	return nil
+}
+
 func NewFullDuplexSimulator(
 	baseurl string,
-	sending_connection_max_alive_time time.Duration,
-	receiving_connection_max_alive_time time.Duration,
+	opt *FullDuplexSimulatorOption,
 ) (*FullDuplexSimulator, error) {
 	u, err := url.Parse(baseurl)
 	if err != nil {
 		return nil, err
 	}
 
+	if opt == nil {
+		opt = &FullDuplexSimulatorOption{}
+	}
+
+	if err := opt.defaultOption(); err != nil {
+		return nil, err
+	}
+
 	return &FullDuplexSimulator{
-		baseurl:                             u,
-		sending_connection_max_alive_time:   sending_connection_max_alive_time,
-		receiving_connection_max_alive_time: receiving_connection_max_alive_time,
-		max_retries:                         10,
-		request_id:                          strings.RandomString(32),
+		baseurl:                                    u,
+		sending_connection_max_alive_time:          opt.SendingConnectionMaxAliveTime,
+		target_sending_connection_max_alive_time:   opt.TargetSendingConnectionMaxAliveTime,
+		receiving_connection_max_alive_time:        opt.ReceivingConnectionMaxAliveTime,
+		target_receiving_connection_max_alive_time: opt.TargetReceivingConnectionMaxAliveTime,
+		max_sending_bytes:                          opt.MaxSingleRequestSendingBytes,
+		max_receiving_bytes:                        opt.MaxSingleRequestReceivingBytes,
+		max_retries:                                opt.MaxRetries,
+		request_id:                                 strings.RandomString(32),
 
 		// using keep alive to reduce the connection reset
 		client: &http.Client{
@@ -117,21 +196,52 @@ func NewFullDuplexSimulator(
 	}, nil
 }
 
-// send data to server
+// send data to server, it's thread-safe
 func (s *FullDuplexSimulator) Send(data []byte, timeout ...time.Duration) error {
+	s.sending_lock.Lock()
+	defer s.sending_lock.Unlock()
+
+	// split data into max 1024 bytes
+	for len(data) > 0 {
+		chunk := data
+		if len(chunk) > 1024 {
+			chunk = chunk[:1024]
+		}
+
+		data = data[len(chunk):]
+		if err := s.send(chunk, timeout...); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (s *FullDuplexSimulator) send(data []byte, timeout ...time.Duration) error {
+	started := time.Now()
+
 	timeout_duration := time.Second * 10
 	if len(timeout) > 0 {
 		timeout_duration = timeout[0]
 	}
 
-	started := time.Now()
-
 	for time.Since(started) < timeout_duration {
 		if atomic.LoadInt32(&s.sending_connection_alive) != 1 {
 			time.Sleep(time.Millisecond * 50)
 			continue
 		}
 
+		if atomic.AddInt32(&s.current_request_sent_bytes, int32(len(data))) > s.max_sending_bytes {
+			// reached max sending bytes, close current connection, and start a new one
+			s.sending_pipe_lock.Lock()
+			if s.sending_pipeline != nil {
+				s.sending_pipeline.Close()
+			}
+			s.sending_pipe_lock.Unlock()
+			atomic.StoreInt32(&s.current_request_sent_bytes, 0)
+			continue
+		}
+
 		s.sending_pipe_lock.Lock()
 		writer := s.sending_pipeline
 		if writer == nil {
@@ -146,13 +256,18 @@ func (s *FullDuplexSimulator) Send(data []byte, timeout ...time.Duration) error
 			continue
 		} else {
 			atomic.AddInt64(&s.sent_bytes, int64(n))
+			atomic.AddInt32(&s.current_request_sent_bytes, int32(n))
 		}
 
 		s.sending_pipe_lock.Unlock()
-		return nil
+		break
 	}
 
-	return errors.New("send data timeout")
+	if time.Since(started) > timeout_duration {
+		return errors.New("send data timeout")
+	}
+
+	return nil
 }
 
 func (s *FullDuplexSimulator) On(f func(data []byte)) {
@@ -224,6 +339,7 @@ func (s *FullDuplexSimulator) startSendingConnection(routine_id string) error {
 	req.Header.Set("Content-Type", "octet-stream")
 	req.Header.Set("Connection", "keep-alive")
 	req.Header.Set("x-dify-plugin-request-id", s.request_id)
+	req.Header.Set("x-dify-plugin-max-alive-time", fmt.Sprintf("%d", s.target_receiving_connection_max_alive_time.Milliseconds()))
 
 	routine.Submit(func() {
 		s.sendingConnectionRoutine(req, routine_id)
@@ -379,6 +495,7 @@ func (s *FullDuplexSimulator) receivingConnectionRoutine(routine_id string) {
 		req.Header.Set("Content-Type", "octet-stream")
 		req.Header.Set("Connection", "keep-alive")
 		req.Header.Set("x-dify-plugin-request-id", s.request_id)
+		req.Header.Set("x-dify-plugin-max-alive-time", fmt.Sprintf("%d", s.target_sending_connection_max_alive_time.Milliseconds()))
 
 		ctx, cancel := context.WithCancel(context.Background())
 		req = req.Clone(ctx)

+ 108 - 14
internal/core/plugin_manager/aws_manager/full_duplex_simulator_test.go

@@ -45,7 +45,7 @@ func (s *S) Stop() {
 	s.srv.Close()
 }
 
-func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error) {
+func server() (*S, error) {
 	port, err := network.GetRandomPort()
 	if err != nil {
 		return nil, err
@@ -75,6 +75,7 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error)
 
 		// fmt.Println("new send request")
 		id := c.Request.Header.Get("x-dify-plugin-request-id")
+		max_alive_time := c.Request.Header.Get("x-dify-plugin-max-alive-time")
 		s.current_send_request_id = id
 
 		var ch chan []byte
@@ -88,7 +89,12 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error)
 		}
 		s.data_mu.Unlock()
 
-		time.AfterFunc(send_timeout, func() {
+		timeout, err := strconv.ParseInt(max_alive_time, 10, 64)
+		if err != nil {
+			timeout = 60
+		}
+
+		time.AfterFunc(time.Millisecond*time.Duration(timeout), func() {
 			c.Request.Body.Close()
 		})
 
@@ -118,6 +124,7 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error)
 
 		// fmt.Println("new recv request")
 		id := ctx.Request.Header.Get("x-dify-plugin-request-id")
+		max_alive_time := ctx.Request.Header.Get("x-dify-plugin-max-alive-time")
 		s.current_recv_request_id = id
 
 		var ch chan []byte
@@ -137,7 +144,12 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error)
 		ctx.Writer.Write([]byte("pong\n"))
 		ctx.Writer.Flush()
 
-		timer := time.NewTimer(recv_timeout)
+		timeout, err := strconv.ParseInt(max_alive_time, 10, 64)
+		if err != nil {
+			timeout = 60
+		}
+
+		timer := time.NewTimer(time.Millisecond * time.Duration(timeout))
 
 		for {
 			select {
@@ -167,7 +179,7 @@ func TestFullDuplexSimulator_SingleSendAndReceive(t *testing.T) {
 	log.SetShowLog(false)
 	defer log.SetShowLog(true)
 
-	srv, err := server(time.Second*100, time.Second*100)
+	srv, err := server()
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -175,7 +187,16 @@ func TestFullDuplexSimulator_SingleSendAndReceive(t *testing.T) {
 
 	time.Sleep(time.Second)
 
-	simulator, err := NewFullDuplexSimulator(srv.url, time.Second*100, time.Second*100)
+	simulator, err := NewFullDuplexSimulator(
+		srv.url, &FullDuplexSimulatorOption{
+			SendingConnectionMaxAliveTime:         time.Second * 100,
+			ReceivingConnectionMaxAliveTime:       time.Second * 100,
+			TargetSendingConnectionMaxAliveTime:   time.Second * 99,
+			TargetReceivingConnectionMaxAliveTime: time.Second * 101,
+			MaxSingleRequestSendingBytes:          1024 * 1024,
+			MaxSingleRequestReceivingBytes:        1024 * 1024,
+		},
+	)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -222,7 +243,7 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
 		go func() {
 			defer wg.Done()
 
-			srv, err := server(time.Millisecond*700, time.Second*10)
+			srv, err := server()
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -230,7 +251,16 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
 
 			time.Sleep(time.Second)
 
-			simulator, err := NewFullDuplexSimulator(srv.url, time.Millisecond*700, time.Second*10)
+			simulator, err := NewFullDuplexSimulator(
+				srv.url, &FullDuplexSimulatorOption{
+					SendingConnectionMaxAliveTime:         time.Millisecond * 700,
+					TargetSendingConnectionMaxAliveTime:   time.Millisecond * 700,
+					ReceivingConnectionMaxAliveTime:       time.Millisecond * 10000,
+					TargetReceivingConnectionMaxAliveTime: time.Millisecond * 10000,
+					MaxSingleRequestSendingBytes:          1024 * 1024,
+					MaxSingleRequestReceivingBytes:        1024 * 1024,
+				},
+			)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -266,12 +296,15 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
 			if l != 3000*5 {
 				sent, received, restarts := simulator.GetStats()
 				t.Errorf(fmt.Sprintf("expected: %d, actual: %d, sent: %d, received: %d, restarts: %d", 3000*5, l, sent, received, restarts))
+				server_recv_count := srv.recv_count
+				server_send_count := srv.send_count
+				t.Errorf(fmt.Sprintf("server recv count: %d, server send count: %d", server_recv_count, server_send_count))
 				// to find which one is missing
-				for i := 0; i < 3000; i++ {
-					if !strings.Contains(recved.String(), fmt.Sprintf("%05d", i)) {
-						t.Errorf(fmt.Sprintf("missing: %d", i))
-					}
-				}
+				// for i := 0; i < 3000; i++ {
+				// 	if !strings.Contains(recved.String(), fmt.Sprintf("%05d", i)) {
+				// 		t.Errorf(fmt.Sprintf("missing: %d", i))
+				// 	}
+				// }
 			}
 		}()
 	}
@@ -295,7 +328,7 @@ func TestFullDuplexSimulator_MultipleTransactions(t *testing.T) {
 		go func() {
 			defer w.Done()
 
-			srv, err := server(time.Millisecond*700, time.Second*10)
+			srv, err := server()
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -303,7 +336,16 @@ func TestFullDuplexSimulator_MultipleTransactions(t *testing.T) {
 
 			time.Sleep(time.Second)
 
-			simulator, err := NewFullDuplexSimulator(srv.url, time.Millisecond*700, time.Second*10)
+			simulator, err := NewFullDuplexSimulator(
+				srv.url, &FullDuplexSimulatorOption{
+					SendingConnectionMaxAliveTime:         time.Millisecond * 700,
+					TargetSendingConnectionMaxAliveTime:   time.Millisecond * 700,
+					ReceivingConnectionMaxAliveTime:       time.Millisecond * 1000,
+					TargetReceivingConnectionMaxAliveTime: time.Millisecond * 1000,
+					MaxSingleRequestSendingBytes:          1024 * 1024,
+					MaxSingleRequestReceivingBytes:        1024 * 1024,
+				},
+			)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -402,3 +444,55 @@ func TestFullDuplexSimulator_MultipleTransactions(t *testing.T) {
 
 	w.Wait()
 }
+
+func TestFullDuplexSimulator_SendLargeData(t *testing.T) {
+	log.SetShowLog(false)
+	defer log.SetShowLog(true)
+
+	srv, err := server()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer srv.Stop()
+
+	time.Sleep(time.Second)
+
+	l := 0
+
+	simulator, err := NewFullDuplexSimulator(
+		srv.url, &FullDuplexSimulatorOption{
+			SendingConnectionMaxAliveTime:         time.Millisecond * 700,
+			TargetSendingConnectionMaxAliveTime:   time.Millisecond * 700,
+			ReceivingConnectionMaxAliveTime:       time.Millisecond * 1000,
+			TargetReceivingConnectionMaxAliveTime: time.Millisecond * 1000,
+			MaxSingleRequestSendingBytes:          5 * 1024 * 1024,
+			MaxSingleRequestReceivingBytes:        5 * 1024 * 1024,
+		},
+	)
+
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	simulator.On(func(data []byte) {
+		l += len(data)
+	})
+
+	done, err := simulator.StartTransaction()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer done()
+
+	for i := 0; i < 300; i++ { // 300MB, this process should be done in 20 seconds
+		if err := simulator.Send([]byte(strings.Repeat("a", 1024*1024))); err != nil {
+			t.Fatal(err)
+		}
+	}
+
+	time.Sleep(time.Second * 1)
+
+	if l != 300*1024*1024 { // 300MB
+		t.Fatal(fmt.Sprintf("expected: %d, actual: %d", 300*1024*1024, l))
+	}
+}