|
@@ -4,48 +4,89 @@ import (
|
|
|
"bytes"
|
|
|
"fmt"
|
|
|
"net/http"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
"sync"
|
|
|
+ "sync/atomic"
|
|
|
"testing"
|
|
|
"time"
|
|
|
|
|
|
"github.com/gin-gonic/gin"
|
|
|
+ "github.com/langgenius/dify-plugin-daemon/internal/utils/debugging"
|
|
|
"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
|
|
|
"github.com/langgenius/dify-plugin-daemon/internal/utils/network"
|
|
|
"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
|
|
|
)
|
|
|
|
|
|
-func server(recv_timeout time.Duration, send_timeout time.Duration) (string, func(), error) {
|
|
|
+func init() {
|
|
|
routine.InitPool(1024)
|
|
|
+}
|
|
|
+
|
|
|
+type S struct {
|
|
|
+ srv *http.Server
|
|
|
+ url string
|
|
|
+ port int
|
|
|
+
|
|
|
+ send_count int32
|
|
|
+ recv_buffered_count int32
|
|
|
+ recv_count int32
|
|
|
+
|
|
|
+ send_request int32
|
|
|
+ recv_request int32
|
|
|
+
|
|
|
+ data_mu sync.Mutex
|
|
|
+ data map[string]chan []byte
|
|
|
+
|
|
|
+ current_recv_request_id string
|
|
|
+ current_send_request_id string
|
|
|
+}
|
|
|
+
|
|
|
+func (s *S) Stop() {
|
|
|
+ s.srv.Close()
|
|
|
+}
|
|
|
|
|
|
+func server(recv_timeout time.Duration, send_timeout time.Duration) (*S, error) {
|
|
|
port, err := network.GetRandomPort()
|
|
|
if err != nil {
|
|
|
- return "", nil, err
|
|
|
+ return nil, err
|
|
|
}
|
|
|
|
|
|
- data := map[string]chan []byte{}
|
|
|
- data_mu := sync.Mutex{}
|
|
|
+ eng := gin.New()
|
|
|
+ srv := &http.Server{
|
|
|
+ Addr: fmt.Sprintf(":%d", port),
|
|
|
+ Handler: eng,
|
|
|
+ }
|
|
|
|
|
|
- recved := 0
|
|
|
+ s := &S{
|
|
|
+ srv: srv,
|
|
|
+ url: fmt.Sprintf("http://localhost:%d", port),
|
|
|
+ data: make(map[string]chan []byte),
|
|
|
|
|
|
- eng := gin.New()
|
|
|
+ send_count: 0,
|
|
|
+ recv_count: 0,
|
|
|
+ }
|
|
|
|
|
|
// avoid log
|
|
|
gin.SetMode(gin.ReleaseMode)
|
|
|
|
|
|
eng.POST("/invoke", func(c *gin.Context) {
|
|
|
+ atomic.AddInt32(&s.send_request, 1)
|
|
|
+ defer atomic.AddInt32(&s.send_request, -1)
|
|
|
+
|
|
|
// fmt.Println("new send request")
|
|
|
id := c.Request.Header.Get("x-dify-plugin-request-id")
|
|
|
+ s.current_send_request_id = id
|
|
|
+
|
|
|
var ch chan []byte
|
|
|
|
|
|
- data_mu.Lock()
|
|
|
- if _, ok := data[id]; !ok {
|
|
|
- ch = make(chan []byte, 1024)
|
|
|
- data[id] = ch
|
|
|
+ s.data_mu.Lock()
|
|
|
+ if _, ok := s.data[id]; !ok {
|
|
|
+ ch = make(chan []byte)
|
|
|
+ s.data[id] = ch
|
|
|
} else {
|
|
|
- ch = data[id]
|
|
|
+ ch = s.data[id]
|
|
|
}
|
|
|
- data_mu.Unlock()
|
|
|
+ s.data_mu.Unlock()
|
|
|
|
|
|
time.AfterFunc(send_timeout, func() {
|
|
|
c.Request.Body.Close()
|
|
@@ -56,8 +97,9 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (string, fun
|
|
|
buf := make([]byte, 1024)
|
|
|
n, err := c.Request.Body.Read(buf)
|
|
|
if n != 0 {
|
|
|
- recved += n
|
|
|
+ atomic.AddInt32(&s.recv_buffered_count, int32(n))
|
|
|
ch <- buf[:n]
|
|
|
+ atomic.AddInt32(&s.recv_count, int32(n))
|
|
|
}
|
|
|
if err != nil {
|
|
|
break
|
|
@@ -70,20 +112,23 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (string, fun
|
|
|
c.Writer.Flush()
|
|
|
})
|
|
|
|
|
|
- response := 0
|
|
|
-
|
|
|
eng.GET("/response", func(ctx *gin.Context) {
|
|
|
+ atomic.AddInt32(&s.recv_request, 1)
|
|
|
+ defer atomic.AddInt32(&s.recv_request, -1)
|
|
|
+
|
|
|
// fmt.Println("new recv request")
|
|
|
id := ctx.Request.Header.Get("x-dify-plugin-request-id")
|
|
|
+ s.current_recv_request_id = id
|
|
|
+
|
|
|
var ch chan []byte
|
|
|
- data_mu.Lock()
|
|
|
- if _, ok := data[id]; ok {
|
|
|
- ch = data[id]
|
|
|
+ s.data_mu.Lock()
|
|
|
+ if _, ok := s.data[id]; ok {
|
|
|
+ ch = s.data[id]
|
|
|
} else {
|
|
|
- ch = make(chan []byte, 1024)
|
|
|
- data[id] = ch
|
|
|
+ ch = make(chan []byte)
|
|
|
+ s.data[id] = ch
|
|
|
}
|
|
|
- data_mu.Unlock()
|
|
|
+ s.data_mu.Unlock()
|
|
|
|
|
|
ctx.Writer.WriteHeader(http.StatusOK)
|
|
|
ctx.Writer.Header().Set("Content-Type", "application/octet-stream")
|
|
@@ -99,7 +144,7 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (string, fun
|
|
|
case data := <-ch:
|
|
|
ctx.Writer.Write(data)
|
|
|
ctx.Writer.Flush()
|
|
|
- response += len(data)
|
|
|
+ atomic.AddInt32(&s.send_count, int32(len(data)))
|
|
|
case <-ctx.Done():
|
|
|
return
|
|
|
case <-ctx.Writer.CloseNotify():
|
|
@@ -111,34 +156,26 @@ func server(recv_timeout time.Duration, send_timeout time.Duration) (string, fun
|
|
|
}
|
|
|
})
|
|
|
|
|
|
- srv := &http.Server{
|
|
|
- Addr: fmt.Sprintf(":%d", port),
|
|
|
- Handler: eng,
|
|
|
- }
|
|
|
-
|
|
|
go func() {
|
|
|
srv.ListenAndServe()
|
|
|
}()
|
|
|
|
|
|
- return fmt.Sprintf("http://localhost:%d", port), func() {
|
|
|
- srv.Close()
|
|
|
- // fmt.Printf("recved: %d, responsed: %d\n", recved, response)
|
|
|
- }, nil
|
|
|
+ return s, nil
|
|
|
}
|
|
|
|
|
|
func TestFullDuplexSimulator_SingleSendAndReceive(t *testing.T) {
|
|
|
log.SetShowLog(false)
|
|
|
defer log.SetShowLog(true)
|
|
|
|
|
|
- url, cleanup, err := server(time.Second*100, time.Second*100)
|
|
|
+ srv, err := server(time.Second*100, time.Second*100)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- defer cleanup()
|
|
|
+ defer srv.Stop()
|
|
|
|
|
|
time.Sleep(time.Second)
|
|
|
|
|
|
- simulator, err := NewFullDuplexSimulator(url, time.Second*100, time.Second*100)
|
|
|
+ simulator, err := NewFullDuplexSimulator(srv.url, time.Second*100, time.Second*100)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
@@ -178,22 +215,22 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
|
|
|
defer log.SetShowLog(true)
|
|
|
|
|
|
// hmmm, to ensure the server is stable, we need to run the test 100 times
|
|
|
- // don't ask me why, just trust me, I have spent 1 days to handle this race condition
|
|
|
+ // don't ask me why, just trust me, I have spent 1 days to correctly handle this race condition
|
|
|
wg := sync.WaitGroup{}
|
|
|
wg.Add(100)
|
|
|
for i := 0; i < 100; i++ {
|
|
|
go func() {
|
|
|
defer wg.Done()
|
|
|
|
|
|
- url, cleanup, err := server(time.Millisecond*700, time.Second*10)
|
|
|
+ srv, err := server(time.Millisecond*700, time.Second*10)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
|
- defer cleanup()
|
|
|
+ defer srv.Stop()
|
|
|
|
|
|
time.Sleep(time.Second)
|
|
|
|
|
|
- simulator, err := NewFullDuplexSimulator(url, time.Millisecond*700, time.Second*10)
|
|
|
+ simulator, err := NewFullDuplexSimulator(srv.url, time.Millisecond*700, time.Second*10)
|
|
|
if err != nil {
|
|
|
t.Fatal(err)
|
|
|
}
|
|
@@ -227,8 +264,8 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
|
|
|
time.Sleep(time.Millisecond * 500)
|
|
|
|
|
|
if l != 3000*5 {
|
|
|
- sent, received := simulator.GetStats()
|
|
|
- t.Errorf(fmt.Sprintf("expected: %d, actual: %d, sent: %d, received: %d", 3000*5, l, sent, received))
|
|
|
+ sent, received, restarts := simulator.GetStats()
|
|
|
+ t.Errorf(fmt.Sprintf("expected: %d, actual: %d, sent: %d, received: %d, restarts: %d", 3000*5, l, sent, received, restarts))
|
|
|
// to find which one is missing
|
|
|
for i := 0; i < 3000; i++ {
|
|
|
if !strings.Contains(recved.String(), fmt.Sprintf("%05d", i)) {
|
|
@@ -241,3 +278,127 @@ func TestFullDuplexSimulator_AutoReconnect(t *testing.T) {
|
|
|
|
|
|
wg.Wait()
|
|
|
}
|
|
|
+
|
|
|
+func TestFullDuplexSimulator_MultipleTransactions(t *testing.T) {
|
|
|
+ log.SetShowLog(false)
|
|
|
+ defer log.SetShowLog(true)
|
|
|
+
|
|
|
+ // avoid too many test cases, it will cause too many goroutines
|
|
|
+ // finally, os will run into busy, and requests can not be handled correctly in time
|
|
|
+ const NUM_CASES = 30
|
|
|
+
|
|
|
+ w := sync.WaitGroup{}
|
|
|
+ w.Add(NUM_CASES)
|
|
|
+
|
|
|
+ for j := 0; j < NUM_CASES; j++ {
|
|
|
+ // j := j
|
|
|
+ go func() {
|
|
|
+ defer w.Done()
|
|
|
+
|
|
|
+ srv, err := server(time.Millisecond*700, time.Second*10)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ defer srv.Stop()
|
|
|
+
|
|
|
+ time.Sleep(time.Second)
|
|
|
+
|
|
|
+ simulator, err := NewFullDuplexSimulator(srv.url, time.Millisecond*700, time.Second*10)
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ l := int32(0)
|
|
|
+
|
|
|
+ dones := make(map[int]func())
|
|
|
+ dones_lock := sync.Mutex{}
|
|
|
+
|
|
|
+ buf := bytes.Buffer{}
|
|
|
+ simulator.On(func(data []byte) {
|
|
|
+ debugging.PossibleBlocking(
|
|
|
+ func() any {
|
|
|
+ atomic.AddInt32(&l, int32(len(data)))
|
|
|
+
|
|
|
+ buf.Write(data)
|
|
|
+
|
|
|
+ bytes := buf.Bytes()
|
|
|
+ buf.Reset()
|
|
|
+
|
|
|
+ i := 0
|
|
|
+ for i < len(bytes) {
|
|
|
+ num, err := strconv.Atoi(string(bytes[i : i+5]))
|
|
|
+ if err != nil {
|
|
|
+ t.Fatalf("invalid data: %s", string(bytes))
|
|
|
+ }
|
|
|
+
|
|
|
+ dones_lock.Lock()
|
|
|
+
|
|
|
+ if done, ok := dones[num]; ok {
|
|
|
+ done()
|
|
|
+ } else {
|
|
|
+ t.Fatalf("done not found: %d", num)
|
|
|
+ }
|
|
|
+
|
|
|
+ dones_lock.Unlock()
|
|
|
+
|
|
|
+ i += 5
|
|
|
+ }
|
|
|
+
|
|
|
+ if buf.Len() != i {
|
|
|
+ // write the rest of the data
|
|
|
+ b := make([]byte, len(bytes)-i)
|
|
|
+ copy(b, bytes[i:])
|
|
|
+ buf.Write(b)
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil
|
|
|
+ },
|
|
|
+ time.Second*1,
|
|
|
+ func() {
|
|
|
+ t.Fatal("possible blocking triggered")
|
|
|
+ },
|
|
|
+ )
|
|
|
+ })
|
|
|
+
|
|
|
+ wg := sync.WaitGroup{}
|
|
|
+ wg.Add(100)
|
|
|
+
|
|
|
+ for i := 0; i < 100; i++ {
|
|
|
+ i := i
|
|
|
+ time.Sleep(time.Millisecond * 20)
|
|
|
+ go func() {
|
|
|
+ done, err := simulator.StartTransaction()
|
|
|
+ if err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+
|
|
|
+ dones_lock.Lock()
|
|
|
+ dones[i] = func() {
|
|
|
+ done()
|
|
|
+ wg.Done()
|
|
|
+ }
|
|
|
+ dones_lock.Unlock()
|
|
|
+
|
|
|
+ if err := simulator.Send([]byte(fmt.Sprintf("%05d", i))); err != nil {
|
|
|
+ t.Fatal(err)
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ }
|
|
|
+
|
|
|
+ // time.AfterFunc(time.Second*5, func() {
|
|
|
+ // // fmt.Println("server recv count: ", srv.recv_count, "server send count: ", srv.send_count, "j: ", j,
|
|
|
+ // // "server recv request: ", srv.recv_request, "server send request: ", srv.send_request)
|
|
|
+ // // fmt.Println("current_recv_request_id: ", srv.current_recv_request_id, "current_send_request_id: ", srv.current_send_request_id)
|
|
|
+ // })
|
|
|
+
|
|
|
+ wg.Wait()
|
|
|
+
|
|
|
+ if l != 100*5 {
|
|
|
+ sent, received, restarts := simulator.GetStats()
|
|
|
+ t.Errorf(fmt.Sprintf("expected: %d, actual: %d, sent: %d, received: %d, restarts: %d", 100*5, l, sent, received, restarts))
|
|
|
+ }
|
|
|
+ }()
|
|
|
+ }
|
|
|
+
|
|
|
+ w.Wait()
|
|
|
+}
|