
fix: full duplex simulator lifetime control

Yeuoly, 11 months ago
Commit 3d10c85354
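
This commit ties the lifetime of the underlying sending and receiving connections to the number of live transactions: only the first StartTransaction call establishes the connections (and tears the sending connection back down if the receiving connection fails to come up), and the request id is regenerated at that point. Below is a minimal sketch of that reference-counting idiom, not the actual implementation; startConnections/stopConnections are hypothetical stand-ins for the real start/stop methods, and the stop function is a plausible counterpart of stopTransaction, which this diff does not show in full.

package main

import (
	"fmt"
	"sync/atomic"
)

type simulator struct {
	aliveTransactions int32 // number of transactions currently using the connections
}

// startConnections and stopConnections are hypothetical stand-ins for
// startSendingConnection/startReceivingConnection and their stop counterparts.
func (s *simulator) startConnections() error { fmt.Println("connections up"); return nil }
func (s *simulator) stopConnections()        { fmt.Println("connections down") }

// StartTransaction sketches the pattern from this commit: only the transition
// from zero to one live transactions brings the connections up.
func (s *simulator) StartTransaction() (func(), error) {
	if atomic.AddInt32(&s.aliveTransactions, 1) == 1 {
		if err := s.startConnections(); err != nil {
			// simplification: roll the counter back on failure
			atomic.AddInt32(&s.aliveTransactions, -1)
			return nil, err
		}
	}
	// the returned stop function mirrors stopTransaction:
	// only the last transaction out tears the connections down
	return func() {
		if atomic.AddInt32(&s.aliveTransactions, -1) == 0 {
			s.stopConnections()
		}
	}, nil
}

func main() {
	s := &simulator{}
	stop1, _ := s.StartTransaction() // first transaction: connections come up
	stop2, _ := s.StartTransaction() // second transaction reuses them
	stop1()
	stop2() // last transaction: connections go down
}
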

+ 53 - 30
internal/core/plugin_manager/aws_manager/full_duplex_simulator.go

@@ -53,17 +53,19 @@ type FullDuplexSimulator struct {
 	// receiving context
 	receiving_cancel context.CancelFunc
 	// receiving context lock
-	receiving_context_lock sync.Mutex
+	receiving_cancel_lock sync.Mutex
 
 	// max retries
 	max_retries int
 
 	// is sending connection alive
 	sending_connection_alive         int32
+	sending_routine_lock             sync.Mutex
 	virtual_sending_connection_alive int32
 
 	// is receiving connection alive
 	receiving_connection_alive         int32
+	receiving_routine_lock             sync.Mutex
 	virtual_receiving_connection_alive int32
 
 	// listener for data
@@ -94,7 +96,6 @@ func NewFullDuplexSimulator(
 		sending_connection_max_alive_time:   sending_connection_max_alive_time,
 		receiving_connection_max_alive_time: receiving_connection_max_alive_time,
 		max_retries:                         10,
-		request_id:                          strings.RandomString(32),
 
 		// using keep alive to reduce the connection reset
 		client: &http.Client{
@@ -157,19 +158,24 @@ func (s *FullDuplexSimulator) On(f func(data []byte)) {
 // returns a function to stop the transaction
 func (s *FullDuplexSimulator) StartTransaction() (func(), error) {
 	// start a transaction
-	atomic.AddInt32(&s.alive_transactions, 1)
-	atomic.AddInt32(&s.total_transactions, 1)
+	if atomic.AddInt32(&s.alive_transactions, 1) == 1 {
+		// reset request id
+		s.request_id = strings.RandomString(32)
 
-	// start sending connection
-	if err := s.startSendingConnection(); err != nil {
-		return nil, err
-	}
+		// start sending connection
+		if err := s.startSendingConnection(); err != nil {
+			return nil, err
+		}
 
-	// start receiving connection
-	if err := s.startReceivingConnection(); err != nil {
-		return nil, err
+		// start receiving connection
+		if err := s.startReceivingConnection(); err != nil {
+			s.stopSendingConnection()
+			return nil, err
+		}
 	}
 
+	atomic.AddInt32(&s.total_transactions, 1)
+
 	return s.stopTransaction, nil
 }
 
@@ -182,10 +188,14 @@ func (s *FullDuplexSimulator) stopTransaction() {
 }
 
 func (s *FullDuplexSimulator) startSendingConnection() error {
-	if atomic.LoadInt32(&s.sending_connection_alive) == 1 {
+	// if virtual sending connection is already alive, do nothing
+	if atomic.LoadInt32(&s.virtual_sending_connection_alive) == 1 {
 		return nil
 	}
 
+	// set virtual sending connection as alive
+	atomic.StoreInt32(&s.virtual_sending_connection_alive, 1)
+
 	// lock the sending connection
 	s.sending_connection_timeline_lock.Lock()
 	defer s.sending_connection_timeline_lock.Unlock()
@@ -213,10 +223,13 @@ func (s *FullDuplexSimulator) startSendingConnection() error {
 }
 
 func (s *FullDuplexSimulator) sendingConnectionRoutine(origin_req *http.Request) {
+	// lock the sending routine so that only one goroutine at a time tries to establish the sending connection
+	s.sending_routine_lock.Lock()
+	// release the lock when this routine exits
+	defer s.sending_routine_lock.Unlock()
+
 	failed_times := 0
-	for {
-		// this real connection will be closed after single_connection_max_alive_time
-		// but the virtual connection will be established again and again
+	for atomic.LoadInt32(&s.virtual_sending_connection_alive) == 1 {
 		ctx, cancel := context.WithCancel(context.Background())
 		time.AfterFunc(s.sending_connection_max_alive_time, func() {
 			// reached max alive time, remove pipe writer
@@ -241,8 +254,19 @@ func (s *FullDuplexSimulator) sendingConnectionRoutine(origin_req *http.Request)
 		atomic.StoreInt32(&s.sending_connection_alive, 1)
 
 		resp, err := s.client.Do(req)
-
 		if err != nil {
+			// if virtual sending connection is not alive, clear the sending pipeline and return
+			if atomic.LoadInt32(&s.virtual_sending_connection_alive) == 0 {
+				// clear the sending pipeline
+				s.sending_pipe_lock.Lock()
+				if s.sending_pipeline != nil {
+					s.sending_pipeline.Close()
+					s.sending_pipeline = nil
+				}
+				s.sending_pipe_lock.Unlock()
+				return
+			}
+
 			failed_times++
 			if failed_times > s.max_retries {
 				log.Error("failed to establish sending connection: %v", err)
@@ -276,6 +300,9 @@ func (s *FullDuplexSimulator) stopSendingConnection() error {
 	s.sending_connection_timeline_lock.Lock()
 	defer s.sending_connection_timeline_lock.Unlock()
 
+	s.sending_pipe_lock.Lock()
+	defer s.sending_pipe_lock.Unlock()
+
 	// close the sending pipeline
 	if s.sending_pipeline != nil {
 		s.sending_pipeline.Close()
@@ -289,11 +316,12 @@ func (s *FullDuplexSimulator) stopSendingConnection() error {
 }
 
 func (s *FullDuplexSimulator) startReceivingConnection() error {
+	// if virtual receiving connection is already alive, do nothing
 	if atomic.LoadInt32(&s.virtual_receiving_connection_alive) == 1 {
 		return nil
 	}
 
-	// virtual receiving connection is alive
+	// set virtual receiving connection as alive
 	atomic.StoreInt32(&s.virtual_receiving_connection_alive, 1)
 
 	// lock the receiving connection
@@ -308,8 +336,10 @@ func (s *FullDuplexSimulator) startReceivingConnection() error {
 }
 
 func (s *FullDuplexSimulator) receivingConnectionRoutine() {
-	// close the virtual receiving connection
-	defer atomic.StoreInt32(&s.virtual_receiving_connection_alive, 0)
+	// lock the receiving routine so that only one goroutine at a time tries to establish the receiving connection
+	s.receiving_routine_lock.Lock()
+	// release the lock when this routine exits
+	defer s.receiving_routine_lock.Unlock()
 
 	for atomic.LoadInt32(&s.virtual_receiving_connection_alive) == 1 {
 		recved_pong := false
@@ -337,9 +367,9 @@ func (s *FullDuplexSimulator) receivingConnectionRoutine() {
 			continue
 		}
 
-		s.receiving_context_lock.Lock()
+		s.receiving_cancel_lock.Lock()
 		s.receiving_cancel = cancel
-		s.receiving_context_lock.Unlock()
+		s.receiving_cancel_lock.Unlock()
 
 		time.AfterFunc(s.receiving_connection_max_alive_time, func() {
 			cancel()
@@ -386,13 +416,6 @@ func (s *FullDuplexSimulator) receivingConnectionRoutine() {
 				break
 			}
 		}
-
-		s.receiving_connection_timeline_lock.Lock()
-		if atomic.LoadInt32(&s.virtual_receiving_connection_alive) == 0 {
-			s.receiving_connection_timeline_lock.Unlock()
-			return
-		}
-		s.receiving_connection_timeline_lock.Unlock()
 	}
 }
 
@@ -405,11 +428,11 @@ func (s *FullDuplexSimulator) stopReceivingConnection() {
 	atomic.StoreInt32(&s.virtual_receiving_connection_alive, 0)
 
 	// cancel the receiving context
-	s.receiving_context_lock.Lock()
+	s.receiving_cancel_lock.Lock()
 	if s.receiving_cancel != nil {
 		s.receiving_cancel()
 	}
-	s.receiving_context_lock.Unlock()
+	s.receiving_cancel_lock.Unlock()
 }
 
 // GetStats, returns the sent and received bytes
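
The routine changes above follow a "virtual versus real" connection pattern: the virtual_*_connection_alive flag records whether a connection should exist at all, the routine keeps re-establishing the real connection while that flag is set, and the new per-routine mutex ensures at most one goroutine maintains each direction. A rough sketch of that loop shape is below; it is an assumed simplification of the diff, and dial is a hypothetical placeholder for the real client.Do request.

package main

import (
	"errors"
	"sync"
	"sync/atomic"
	"time"
)

type duplexSide struct {
	virtualAlive int32      // intent: a connection should exist
	realAlive    int32      // state: a real connection is currently up
	routineLock  sync.Mutex // only one goroutine may run the reconnect loop
}

// dial is a hypothetical placeholder for the HTTP request issued by client.Do.
func dial() error { return errors.New("connection reset") }

func (d *duplexSide) connectionRoutine() {
	// serialize the loop so two start calls cannot race two maintainers
	d.routineLock.Lock()
	defer d.routineLock.Unlock()

	failed := 0
	// the real connection may drop repeatedly; the loop only keeps
	// reconnecting while the virtual connection is still wanted
	for atomic.LoadInt32(&d.virtualAlive) == 1 {
		atomic.StoreInt32(&d.realAlive, 1)
		if err := dial(); err != nil {
			atomic.StoreInt32(&d.realAlive, 0)
			if atomic.LoadInt32(&d.virtualAlive) == 0 {
				return // stopped from outside while dialing, give up immediately
			}
			failed++
			if failed > 10 {
				return // mirror the max_retries bail-out
			}
			time.Sleep(100 * time.Millisecond)
			continue
		}
		failed = 0
	}
	atomic.StoreInt32(&d.realAlive, 0)
}

func (d *duplexSide) stop() { atomic.StoreInt32(&d.virtualAlive, 0) }

func main() {
	d := &duplexSide{virtualAlive: 1}
	go d.connectionRoutine()
	time.Sleep(250 * time.Millisecond)
	d.stop()
	time.Sleep(50 * time.Millisecond)
}
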

+ 12 - 1
internal/core/plugin_manager/aws_manager/full_duplex_simulator_test.go

@@ -10,6 +10,7 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/network"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
@@ -28,6 +29,10 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (string, fun
 	recved := 0
 
 	eng := gin.New()
+
+	// suppress gin debug log output
+	gin.SetMode(gin.ReleaseMode)
+
 	eng.POST("/invoke", func(c *gin.Context) {
 		// fmt.Println("new send request")
 		id := c.Request.Header.Get("x-dify-plugin-request-id")
@@ -117,11 +122,14 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (string, fun
 
 	return fmt.Sprintf("http://localhost:%d", port), func() {
 		srv.Close()
-		fmt.Printf("recved: %d, responsed: %d\n", recved, response)
+		// fmt.Printf("recved: %d, responsed: %d\n", recved, response)
 	}, nil
 }
 
 func TestFullDuplexSimulator_SingleSendAndReceive(t *testing.T) {
+	log.SetShowLog(false)
+	defer log.SetShowLog(true)
+
 	url, cleanup, err := server(time.Second*100, time.Second*100)
 	if err != nil {
 		t.Fatal(err)
@@ -166,6 +174,9 @@ func TestFullDuplexSimulator_SingleSendAndReceive(t *testing.T) {
 }
 
 func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
+	log.SetShowLog(false)
+	defer log.SetShowLog(true)
+
 	// hmmm, to ensure the server is stable, we need to run the test 100 times
 	// don't ask me why, just trust me, I have spent 1 day to handle this race condition
 	wg := sync.WaitGroup{}