Parcourir la source

feat: full duplex simulator

Yeuoly il y a 11 mois
Parent
commit
47f2fc9082

+ 367 - 0
internal/core/plugin_manager/aws_manager/full_duplex_simulator.go

@@ -0,0 +1,367 @@
+package aws_manager
+
+import (
+	"context"
+	"errors"
+	"io"
+	"net"
+	"net/http"
+	"net/url"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/strings"
+)
+
+// Full duplex simulator, using http protocol to simulate the full duplex communication
+// 1. during free time, no connection will be established
+// 2. when there is a virtual connection need to be established, 2 http transactions will be sent to the server
+// 3. one is used to send data chunk by chunk to simulate the data stream and the other is used to receive data using event stream
+// 4. after all data is sent, the connection will be closed to reduce the network traffic
+//
+// When http connection is closed, the simulator will restart it immediately until it has reached max_retries
+type FullDuplexSimulator struct {
+	baseurl *url.URL
+
+	// how many transactions are alive
+	alive_transactions int32
+
+	// total transactions
+	total_transactions int32
+
+	// sending_connection_timeline_lock
+	sending_connection_timeline_lock sync.Mutex
+	// sending pipeline
+	sending_pipeline *io.PipeWriter
+
+	// receiving_connection_timeline_lock
+	receiving_connection_timeline_lock sync.Mutex
+	// receiving reader
+	receiving_reader io.ReadCloser
+
+	// max retries
+	max_retries int
+
+	// is sending connection alive
+	sending_connection_alive int32
+
+	// is receiving connection alive
+	receiving_connection_alive int32
+
+	// listener for data
+	listeners []func(data []byte)
+
+	// mutex for listeners
+	listeners_mu sync.RWMutex
+
+	// request id
+	request_id string
+
+	// http client
+	client *http.Client
+}
+
+func NewFullDuplexSimulator(baseurl string) (*FullDuplexSimulator, error) {
+	u, err := url.Parse(baseurl)
+	if err != nil {
+		return nil, err
+	}
+
+	return &FullDuplexSimulator{
+		baseurl:     u,
+		max_retries: 10,
+		request_id:  strings.RandomString(32),
+
+		// using keep alive to reduce the connection reset
+		client: &http.Client{
+			Transport: &http.Transport{
+				Dial: (&net.Dialer{
+					Timeout:   5 * time.Second,
+					KeepAlive: 15 * time.Second,
+				}).Dial,
+				IdleConnTimeout: 120 * time.Second,
+			},
+		},
+	}, nil
+}
+
+// send data to server
+func (s *FullDuplexSimulator) Send(data []byte) error {
+	if atomic.LoadInt32(&s.sending_connection_alive) == 0 {
+		return errors.New("sending connection is not alive")
+	}
+
+	writer := s.sending_pipeline
+	if writer == nil {
+		return errors.New("sending pipeline is not alive")
+	}
+
+	if _, err := writer.Write(data); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (s *FullDuplexSimulator) On(f func(data []byte)) {
+	s.listeners_mu.Lock()
+	defer s.listeners_mu.Unlock()
+	s.listeners = append(s.listeners, f)
+}
+
+// start a transaction
+// returns a function to stop the transaction
+func (s *FullDuplexSimulator) StartTransaction() (func(), error) {
+	// start a transaction
+	atomic.AddInt32(&s.alive_transactions, 1)
+	atomic.AddInt32(&s.total_transactions, 1)
+
+	// start sending connection
+	if err := s.startSendingConnection(); err != nil {
+		return nil, err
+	}
+
+	// start receiving connection
+	if err := s.startReceivingConnection(); err != nil {
+		return nil, err
+	}
+
+	return s.stopTransaction, nil
+}
+
+func (s *FullDuplexSimulator) stopTransaction() {
+	// close if no transaction is alive
+	if atomic.AddInt32(&s.alive_transactions, -1) == 0 {
+		s.stopSendingConnection()
+		s.stopReceivingConnection()
+	}
+}
+
+func (s *FullDuplexSimulator) startSendingConnection() error {
+	if atomic.LoadInt32(&s.sending_connection_alive) == 1 {
+		return nil
+	}
+
+	// lock the sending connection
+	s.sending_connection_timeline_lock.Lock()
+	defer s.sending_connection_timeline_lock.Unlock()
+
+	// start a new sending connection
+	u, err := url.JoinPath(s.baseurl.String(), "/invoke")
+	if err != nil {
+		return err
+	}
+
+	pr, pw := io.Pipe()
+
+	req, err := http.NewRequest("POST", u, pr)
+	if err != nil {
+		return err
+	}
+
+	req.Header.Set("Content-Type", "octet-stream")
+	req.Header.Set("Connection", "keep-alive")
+	req.Header.Set("x-dify-plugin-request-id", s.request_id)
+
+	routine.Submit(func() {
+		s.sendingConnectionRoutine(req)
+	})
+
+	// mark sending connection as alive
+	atomic.StoreInt32(&s.sending_connection_alive, 1)
+
+	// set the sending pipeline
+	s.sending_pipeline = pw
+
+	return nil
+}
+
+func (s *FullDuplexSimulator) sendingConnectionRoutine(origin_req *http.Request) {
+	failed_times := 0
+	for {
+		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+		time.AfterFunc(5*time.Second, func() {
+			cancel()
+		})
+		req := origin_req.Clone(ctx)
+		req = req.WithContext(ctx)
+		resp, err := s.client.Do(req)
+
+		if err != nil {
+			failed_times++
+			if failed_times > s.max_retries {
+				log.Error("failed to establish sending connection: %v", err)
+				s.stopSendingConnection()
+				return
+			}
+
+			log.Error("failed to establish sending connection: %v", err)
+			continue
+		}
+
+		defer resp.Body.Close()
+
+		// mark sending connection as dead
+		atomic.StoreInt32(&s.sending_connection_alive, 0)
+
+		s.sending_connection_timeline_lock.Lock()
+		defer s.sending_connection_timeline_lock.Unlock()
+
+		// close the sending pipeline
+		if s.sending_pipeline != nil {
+			s.sending_pipeline.Close()
+			s.sending_pipeline = nil
+		}
+	}
+}
+
+func (s *FullDuplexSimulator) stopSendingConnection() error {
+	if atomic.LoadInt32(&s.sending_connection_alive) == 0 {
+		return nil
+	}
+
+	s.sending_connection_timeline_lock.Lock()
+	defer s.sending_connection_timeline_lock.Unlock()
+
+	// close the sending pipeline
+	if s.sending_pipeline != nil {
+		s.sending_pipeline.Close()
+		s.sending_pipeline = nil
+	}
+
+	// mark sending connection as dead
+	atomic.StoreInt32(&s.sending_connection_alive, 0)
+
+	return nil
+}
+
+func (s *FullDuplexSimulator) startReceivingConnection() error {
+	if atomic.LoadInt32(&s.receiving_connection_alive) == 1 {
+		return nil
+	}
+
+	// lock the receiving connection
+	s.receiving_connection_timeline_lock.Lock()
+	defer s.receiving_connection_timeline_lock.Unlock()
+
+	// start a new receiving connection
+	u, err := url.JoinPath(s.baseurl.String(), "/response")
+	if err != nil {
+		return err
+	}
+
+	req, err := http.NewRequest("GET", u, nil)
+	if err != nil {
+		return err
+	}
+	req.Header.Set("Content-Type", "octet-stream")
+	req.Header.Set("Connection", "keep-alive")
+	req.Header.Set("x-dify-plugin-request-id", s.request_id)
+
+	req = req.Clone(context.Background())
+	resp, err := s.client.Do(req)
+	if err != nil {
+		return errors.Join(err, errors.New("failed to establish receiving connection"))
+	}
+
+	routine.Submit(func() {
+		s.receivingConnectionRoutine(req, resp.Body)
+	})
+
+	// mark receiving connection as alive
+	atomic.StoreInt32(&s.receiving_connection_alive, 1)
+
+	return nil
+}
+
+func (s *FullDuplexSimulator) receivingConnectionRoutine(req *http.Request, reader io.ReadCloser) {
+	failed_times := 0
+	for {
+		s.receiving_reader = reader
+		recved_pong := false
+		buf := make([]byte, 0)
+		buf_len := 0
+
+		for {
+			data := make([]byte, 1024)
+			n, err := reader.Read(data)
+			if err != nil {
+				break
+			}
+
+			// check if pong\n is at the beginning of the data
+			if !recved_pong {
+				data = append(data, buf[:buf_len]...)
+				buf = make([]byte, 0)
+				buf_len = 0
+
+				if n >= 5 {
+					if string(data[:5]) == "pong\n" {
+						recved_pong = true
+						// remove pong\n from the beginning of the data
+						data = data[5:]
+						n -= 5
+					} else {
+						// not pong\n, break
+						break
+					}
+				} else if n < 5 {
+					// save the data to the buffer
+					buf = append(buf, data[:n]...)
+					buf_len += n
+					continue
+				}
+			}
+
+			for _, listener := range s.listeners[:] {
+				listener(data[:n])
+			}
+		}
+
+		s.receiving_reader = nil
+
+		s.receiving_connection_timeline_lock.Lock()
+		if atomic.LoadInt32(&s.receiving_connection_alive) == 0 {
+			s.receiving_connection_timeline_lock.Unlock()
+			return
+		}
+		s.receiving_connection_timeline_lock.Unlock()
+
+		req = req.Clone(context.Background())
+		resp, err := s.client.Do(req)
+		if err != nil {
+			failed_times++
+			if failed_times > s.max_retries {
+				log.Error("failed to establish receiving connection: %v", err)
+				s.stopReceivingConnection()
+				return
+			}
+
+			log.Error("failed to establish receiving connection: %v", err)
+			continue
+		}
+
+		reader = resp.Body
+	}
+}
+
+func (s *FullDuplexSimulator) stopReceivingConnection() {
+	if atomic.LoadInt32(&s.receiving_connection_alive) == 0 {
+		return
+	}
+
+	// mark receiving connection as dead
+	atomic.StoreInt32(&s.receiving_connection_alive, 0)
+
+	s.receiving_connection_timeline_lock.Lock()
+	defer s.receiving_connection_timeline_lock.Unlock()
+
+	// close the receiving reader
+	reader := s.receiving_reader
+	if reader != nil {
+		reader.Close()
+	}
+}

+ 149 - 0
internal/core/plugin_manager/aws_manager/full_duplex_simulator_test.go

@@ -0,0 +1,149 @@
+package aws_manager
+
+import (
+	"bytes"
+	"fmt"
+	"net/http"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/network"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
+)
+
+func server(timeout time.Duration) (string, func(), error) {
+	routine.InitPool(1024)
+
+	port, err := network.GetRandomPort()
+	if err != nil {
+		return "", nil, err
+	}
+
+	data := map[string]chan []byte{}
+	data_mu := sync.Mutex{}
+
+	eng := gin.New()
+	eng.POST("/invoke", func(c *gin.Context) {
+		id := c.Request.Header.Get("x-dify-plugin-request-id")
+		var ch chan []byte
+
+		data_mu.Lock()
+		if _, ok := data[id]; !ok {
+			ch = make(chan []byte, 1024)
+			data[id] = ch
+		} else {
+			ch = data[id]
+		}
+		data_mu.Unlock()
+
+		time.AfterFunc(timeout, func() {
+			c.Request.Body.Close()
+		})
+
+		// read data asynchronously
+		for {
+			buf := make([]byte, 1024)
+			n, err := c.Request.Body.Read(buf)
+			if err != nil {
+				break
+			}
+			ch <- buf[:n]
+		}
+
+		c.Status(http.StatusOK)
+	})
+
+	eng.GET("/response", func(ctx *gin.Context) {
+		id := ctx.Request.Header.Get("x-dify-plugin-request-id")
+		var ch chan []byte
+		data_mu.Lock()
+		if _, ok := data[id]; ok {
+			ch = data[id]
+		} else {
+			ch = make(chan []byte, 1024)
+			data[id] = ch
+		}
+		data_mu.Unlock()
+
+		ctx.Writer.WriteHeader(http.StatusOK)
+		ctx.Writer.Header().Set("Content-Type", "application/octet-stream")
+		ctx.Writer.Header().Set("Transfer-Encoding", "chunked")
+		ctx.Writer.Header().Set("Connection", "keep-alive")
+		ctx.Writer.Write([]byte("pong\n"))
+		ctx.Writer.Flush()
+
+		for {
+			select {
+			case data := <-ch:
+				ctx.Writer.Write(data)
+				ctx.Writer.Flush()
+			case <-ctx.Done():
+				return
+			case <-ctx.Writer.CloseNotify():
+				return
+			case <-time.After(timeout):
+				ctx.Status(http.StatusOK)
+				return
+			}
+		}
+	})
+
+	srv := &http.Server{
+		Addr:    fmt.Sprintf(":%d", port),
+		Handler: eng,
+	}
+
+	go func() {
+		srv.ListenAndServe()
+	}()
+
+	return fmt.Sprintf("http://localhost:%d", port), func() {
+		srv.Close()
+	}, nil
+}
+
+func TestFullDuplexSimulator_Send(t *testing.T) {
+	url, cleanup, err := server(time.Second * 100)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer cleanup()
+
+	time.Sleep(time.Second)
+
+	simulator, err := NewFullDuplexSimulator(url)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	recved := make([]byte, 0)
+
+	simulator.On(func(data []byte) {
+		if len(bytes.TrimSpace(data)) == 0 {
+			return
+		}
+
+		recved = append(recved, data...)
+	})
+
+	if done, err := simulator.StartTransaction(); err != nil {
+		t.Fatal(err)
+	} else {
+		defer done()
+	}
+
+	if err := simulator.Send([]byte("hello\n")); err != nil {
+		t.Fatal(err)
+	}
+	if err := simulator.Send([]byte("world\n")); err != nil {
+		t.Fatal(err)
+	}
+
+	time.Sleep(time.Millisecond * 500)
+
+	if string(recved) != "hello\nworld\n" {
+		t.Fatal(fmt.Sprintf("recved: %s", string(recved)))
+	}
+}