Selaa lähdekoodia

refactor: aws lambda io

Yeuoly 11 kuukautta sitten
vanhempi
commit
0a9c71fbab

+ 4 - 1
internal/core/plugin_daemon/backwards_invocation/transaction/aws_event_writer.go

@@ -17,7 +17,10 @@ type AWSTransactionWriter struct {
 }
 
 // NewAWSTransactionWriter creates a new transaction writer
-func NewAWSTransactionWriter(session *session_manager.Session, writeCloser io.WriteCloser) *AWSTransactionWriter {
+func NewAWSTransactionWriter(
+	session *session_manager.Session,
+	writeCloser io.WriteCloser,
+) *AWSTransactionWriter {
 	return &AWSTransactionWriter{
 		session:     session,
 		writeCloser: writeCloser,

+ 48 - 71
internal/core/plugin_manager/aws_manager/io.go

@@ -1,104 +1,81 @@
 package aws_manager
 
 import (
+	"bufio"
+	"bytes"
 	"context"
 	"fmt"
-	"io"
 	"net/http"
 	"net/url"
 	"time"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
 
-// consume data from data stream
-func (r *AWSPluginRuntime) consume() {
-	for {
-		select {
-		case data := <-r.data_stream:
-			fmt.Println(data)
-		}
-	}
-}
-
 func (r *AWSPluginRuntime) Listen(session_id string) *entities.Broadcast[plugin_entities.SessionMessage] {
 	l := entities.NewBroadcast[plugin_entities.SessionMessage]()
-	l.OnClose(func() {
-		// close the pipe writer
-		writer, exists := r.session_pool.Load(session_id)
-		if exists {
-			writer.Close()
-		}
-	})
+	// store the listener
+	r.listeners.Store(session_id, l)
 	return l
 }
 
+// For AWS Lambda, write is equivalent to http request, it's not a normal stream like stdio and tcp
 func (r *AWSPluginRuntime) Write(session_id string, data []byte) {
-	// check if session exists
-	var pw *io.PipeWriter
-	var exists bool
+	l, ok := r.listeners.Load(session_id)
+	if !ok {
+		log.Error("session %s not found", session_id)
+		return
+	}
 
-	if pw, exists = r.session_pool.Load(session_id); !exists {
-		url, err := url.JoinPath(r.lambda_url, "invoke")
-		if err != nil {
-			r.Error(fmt.Sprintf("Error creating request: %v", err))
-			return
-		}
+	url, err := url.JoinPath(r.lambda_url, "invoke")
+	if err != nil {
+		r.Error(fmt.Sprintf("Error creating request: %v", err))
+		return
+	}
 
-		// create a new http request here
-		npr, npw := io.Pipe()
-		r.session_pool.Store(session_id, npw)
-		pw = npw
+	// create a new http request
+	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+	defer cancel()
+	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(data))
+	if err != nil {
+		r.Error(fmt.Sprintf("Error creating request: %v", err))
+		return
+	}
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("Accept", "text/event-stream")
+
+	routine.Submit(func() {
+		// remove the session from listeners
+		defer r.listeners.Delete(session_id)
 
-		// create a new http request
-		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
-		defer cancel()
-		req, err := http.NewRequestWithContext(ctx, "POST", url, npr)
+		response, err := r.client.Do(req)
 		if err != nil {
-			r.Error(fmt.Sprintf("Error creating request: %v", err))
+			r.Error(fmt.Sprintf("Error sending request to aws lambda: %v", err))
 			return
 		}
 
-		req.Header.Set("Content-Type", "application/octet-stream")
-		req.Header.Set("Accept", "application/octet-stream")
-
-		routine.Submit(func() {
-			response, err := http.DefaultClient.Do(req)
-			if err != nil {
-				r.Error(fmt.Sprintf("Error sending request to aws lambda: %v", err))
-				return
+		// write to data stream
+		scanner := bufio.NewScanner(response.Body)
+		for scanner.Scan() {
+			bytes := scanner.Bytes()
+			if len(bytes) == 0 {
+				continue
 			}
 
-			// write to data stream
-			for {
-				buf := make([]byte, 1024)
-				n, err := response.Body.Read(buf)
-				if err != nil {
-					if err == io.EOF {
-						break
-					} else {
-						r.Error(fmt.Sprintf("Error reading response from aws lambda: %v", err))
-						break
-					}
-				}
-				// write to data stream
-				select {
-				case r.data_stream <- buf[:n]:
-				default:
-					r.Error("Data stream is full")
-				}
+			data, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](bytes)
+			if err != nil {
+				log.Error("unmarshal json failed: %s, failed to parse session message", err.Error())
+				continue
 			}
 
-			// remove the session from the pool
-			r.session_pool.Delete(session_id)
-		})
-	}
-
-	if pw != nil {
-		if _, err := pw.Write(data); err != nil {
-			r.Error(fmt.Sprintf("Error writing to pipe writer: %v", err))
+			data.RuntimeType = r.Type()
+			l.Send(data)
 		}
-	}
+
+		l.Close()
+	})
 }

+ 4 - 6
internal/core/plugin_manager/aws_manager/type.go

@@ -1,10 +1,10 @@
 package aws_manager
 
 import (
-	"io"
 	"net/http"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/positive_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/mapping"
 )
@@ -17,10 +17,8 @@ type AWSPluginRuntime struct {
 	lambda_url  string
 	lambda_name string
 
-	client *http.Client
-
-	session_pool mapping.Map[string, *io.PipeWriter]
+	// listeners mapping session id to the listener
+	listeners mapping.Map[string, *entities.Broadcast[plugin_entities.SessionMessage]]
 
-	// data stream take responsibility of listen to response from lambda and redirect to runtime listener
-	data_stream chan []byte
+	client *http.Client
 }