Browse Source

improve: stream response

Yeuoly 11 months ago
parent
commit
17eb689a8c
2 changed files with 43 additions and 18 deletions
  1. 9 18
      internal/utils/stream/response.go
  2. 34 0
      tests/benchmark/stream/stream_response_test.go

+ 9 - 18
internal/utils/stream/response.go

@@ -3,6 +3,7 @@ package stream
 import (
 	"errors"
 	"sync"
+	"sync/atomic"
 
 	"github.com/gammazero/deque"
 )
@@ -11,7 +12,7 @@ type StreamResponse[T any] struct {
 	q         deque.Deque[T]
 	l         *sync.Mutex
 	sig       chan bool
-	closed    bool
+	closed    int32
 	max       int
 	listening bool
 	onClose   func()
@@ -36,7 +37,7 @@ func (r *StreamResponse[T]) OnClose(f func()) {
 // NOTE: even if the stream is closed, it will return true if there is data available
 func (r *StreamResponse[T]) Next() bool {
 	r.l.Lock()
-	if r.closed && r.q.Len() == 0 && r.err == nil {
+	if r.closed == 1 && r.q.Len() == 0 && r.err == nil {
 		r.l.Unlock()
 		return false
 	}
@@ -78,12 +79,9 @@ func (r *StreamResponse[T]) Read() (T, error) {
 
 // Wrap wraps the stream with a new stream, and allows customized operations
 func (r *StreamResponse[T]) Wrap(fn func(T)) error {
-	r.l.Lock()
-	if r.closed {
-		r.l.Unlock()
+	if atomic.LoadInt32(&r.closed) == 1 {
 		return errors.New("stream is closed")
 	}
-	r.l.Unlock()
 
 	for r.Next() {
 		data, err := r.Read()
@@ -99,12 +97,12 @@ func (r *StreamResponse[T]) Wrap(fn func(T)) error {
 // Write writes data to the stream
 // returns error if the buffer is full
 func (r *StreamResponse[T]) Write(data T) error {
-	r.l.Lock()
-	if r.closed {
-		r.l.Unlock()
+	if atomic.LoadInt32(&r.closed) == 1 {
 		return nil
 	}
 
+	r.l.Lock()
+
 	if r.q.Len() >= r.max {
 		r.l.Unlock()
 		return errors.New("queue is full")
@@ -122,13 +120,9 @@ func (r *StreamResponse[T]) Write(data T) error {
 
 // Close closes the stream
 func (r *StreamResponse[T]) Close() {
-	r.l.Lock()
-	if r.closed {
-		r.l.Unlock()
+	if !atomic.CompareAndSwapInt32(&r.closed, 0, 1) {
 		return
 	}
-	r.closed = true
-	r.l.Unlock()
 
 	select {
 	case r.sig <- false:
@@ -141,10 +135,7 @@ func (r *StreamResponse[T]) Close() {
 }
 
 func (r *StreamResponse[T]) IsClosed() bool {
-	r.l.Lock()
-	defer r.l.Unlock()
-
-	return r.closed
+	return atomic.LoadInt32(&r.closed) == 1
 }
 
 func (r *StreamResponse[T]) Size() int {

+ 34 - 0
tests/benchmark/stream/stream_response_test.go

@@ -0,0 +1,34 @@
+package stream
+
+import (
+	"sync"
+	"testing"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
+)
+
+func BenchmarkStreamResponse(b *testing.B) {
+	b.Run("Read", func(b *testing.B) {
+		wg_started := sync.WaitGroup{}
+		wg_started.Add(8)
+		resp := stream.NewStreamResponse[int](1024)
+
+		for i := 0; i < 8; i++ {
+			go func() {
+				wg_started.Done()
+				for !resp.IsClosed() {
+					resp.Write(1)
+				}
+			}()
+		}
+
+		// wait for the first element to be written
+		resp.Next()
+		b.ResetTimer()
+		for i := 0; i < b.N; i++ {
+			resp.Next()
+			resp.Read()
+		}
+		defer resp.Close()
+	})
+}