瀏覽代碼

feat: aws lambda function invoke lifetime

Yeuoly 11 月之前
父節點
當前提交
32d0160262

+ 3 - 0
cmd/server/main.go

@@ -1,6 +1,8 @@
 package main
 
 import (
+	"time"
+
 	"github.com/joho/godotenv"
 	"github.com/kelseyhightower/envconfig"
 	"github.com/langgenius/dify-plugin-daemon/internal/server"
@@ -41,6 +43,7 @@ func setDefault(config *app.Config) {
 	setDefaultInt(&config.PluginRemoteInstallServerEventLoopNums, 8)
 	setDefaultInt(&config.PluginRemoteInstallingMaxConn, 128)
 	setDefaultInt(&config.MaxPluginPackageSize, 52428800)
+	setDefaultInt(&config.MaxAWSLambdaTransactionTimeout, time.Second*150)
 	setDefaultBool(&config.PluginRemoteInstallingEnabled, true)
 	setDefaultBool(&config.PluginWebhookEnabled, true)
 	setDefaultString(&config.DBSslMode, "disable")

+ 78 - 5
internal/core/plugin_daemon/backwards_invocation/transaction/aws_event_handler.go

@@ -1,14 +1,87 @@
 package transaction
 
-import "github.com/gin-gonic/gin"
+import (
+	"io"
+	"net/http"
+	"sync/atomic"
+	"time"
 
-type AWSEventHandler struct {
+	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/aws_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+)
+
+type AWSTransactionHandler struct {
+	max_timeout time.Duration
+}
+
+func NewAWSTransactionHandler(max_timeout time.Duration) *AWSTransactionHandler {
+	return &AWSTransactionHandler{
+		max_timeout: max_timeout,
+	}
+}
+
+type awsTransactionWriteCloser struct {
+	gin.ResponseWriter
+	done   chan bool
+	closed int32
 }
 
-func NewAWSEventHandler() *AWSEventHandler {
-	return &AWSEventHandler{}
+func (w *awsTransactionWriteCloser) Close() error {
+	if atomic.CompareAndSwapInt32(&w.closed, 0, 1) {
+		close(w.done)
+	}
+	return nil
 }
 
-func (h *AWSEventHandler) Handle(ctx *gin.Context) {
+func (h *AWSTransactionHandler) Handle(
+	ctx *gin.Context,
+	session_id string,
+	runtime *aws_manager.AWSPluginRuntime,
+) {
+	writer := &awsTransactionWriteCloser{
+		ResponseWriter: ctx.Writer,
+		done:           make(chan bool),
+	}
+
+	body := ctx.Request.Body
+	// read at most 6MB
+	bytes, err := io.ReadAll(io.LimitReader(body, 6*1024*1024))
+	if err != nil {
+		writer.WriteHeader(http.StatusBadRequest)
+		writer.Write([]byte(err.Error()))
+		return
+	}
+
+	writer.WriteHeader(http.StatusOK)
+	writer.Header().Set("Content-Type", "text/event-stream")
+
+	// parse the data
+	data, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](bytes)
+	if err != nil {
+		log.Error("unmarshal json failed: %s, failed to parse session message", err.Error())
+		writer.WriteHeader(http.StatusBadRequest)
+		writer.Write([]byte(err.Error()))
+		return
+	}
+
+	data.RuntimeType = plugin_entities.PLUGIN_RUNTIME_TYPE_AWS
+	data.SessionWriter = writer
+
+	// send the data to the plugin runtime
+	if err := runtime.PushRequest(session_id, data); err != nil {
+		log.Error("push request failed: %s", err.Error())
+		writer.WriteHeader(http.StatusInternalServerError)
+		writer.Write([]byte(err.Error()))
+		return
+	}
 
+	select {
+	case <-writer.done:
+		return
+	case <-time.After(h.max_timeout):
+		return
+	}
 }

+ 14 - 0
internal/core/plugin_manager/aws_manager/io.go

@@ -80,3 +80,17 @@ func (r *AWSPluginRuntime) Write(session_id string, data []byte) {
 		l.Close()
 	})
 }
+
+func (r *AWSPluginRuntime) PushRequest(session_id string, data plugin_entities.SessionMessage) error {
+	if r.Type() != data.RuntimeType {
+		return fmt.Errorf("runtime type mismatch")
+	}
+
+	broadcast, ok := r.listeners.Load(session_id)
+	if !ok {
+		return fmt.Errorf("session %s not found", session_id)
+	}
+
+	broadcast.Send(data)
+	return nil
+}

+ 6 - 2
internal/core/session_manager/session.go

@@ -19,7 +19,7 @@ var (
 // session need to implement the backwards_invocation.BackwardsInvocationWriter interface
 type Session struct {
 	id      string
-	runtime plugin_entities.PluginRuntimeSessionIOInterface
+	runtime plugin_entities.PluginRuntimeInterface
 
 	tenant_id       string
 	user_id         string
@@ -108,10 +108,14 @@ func (s *Session) PluginIdentity() string {
 	return s.plugin_identity
 }
 
-func (s *Session) BindRuntime(runtime plugin_entities.PluginRuntimeSessionIOInterface) {
+func (s *Session) BindRuntime(runtime plugin_entities.PluginRuntimeInterface) {
 	s.runtime = runtime
 }
 
+func (s *Session) Runtime() plugin_entities.PluginRuntimeInterface {
+	return s.runtime
+}
+
 type PLUGIN_IN_STREAM_EVENT string
 
 const (

+ 1 - 1
internal/server/app.go

@@ -16,5 +16,5 @@ type App struct {
 
 	// aws transaction handler
 	// accept aws transaction request and forward to the plugin daemon
-	aws_transaction_handler *transaction.AWSEventHandler
+	aws_transaction_handler *transaction.AWSTransactionHandler
 }

+ 7 - 2
internal/server/http_server.go

@@ -8,6 +8,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction"
 	"github.com/langgenius/dify-plugin-daemon/internal/server/controllers"
+	"github.com/langgenius/dify-plugin-daemon/internal/service"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 )
@@ -75,7 +76,11 @@ func (app *App) webhookGroup(group *gin.RouterGroup, config *app.Config) {
 
 func (appRef *App) awsLambdaTransactionGroup(group *gin.RouterGroup, config *app.Config) {
 	if config.Platform == app.PLATFORM_AWS_LAMBDA {
-		appRef.aws_transaction_handler = transaction.NewAWSEventHandler()
-		group.POST("/transaction", appRef.RedirectAWSLambdaTransaction, appRef.aws_transaction_handler.Handle)
+		appRef.aws_transaction_handler = transaction.NewAWSTransactionHandler(config.MaxAWSLambdaTransactionTimeout)
+		group.POST(
+			"/transaction",
+			appRef.RedirectAWSLambdaTransaction,
+			service.HandleAWSPluginTransaction(appRef.aws_transaction_handler),
+		)
 	}
 }

+ 1 - 1
internal/server/middleware.go

@@ -119,7 +119,7 @@ func (app *App) redirectPluginInvokeByPluginID(ctx *gin.Context, plugin_id strin
 }
 
 func (app *App) RedirectAWSLambdaTransaction(ctx *gin.Context) {
-	session_id := ctx.GetString("session_id")
+	session_id := ctx.GetHeader("Dify-Plugin-Session-ID")
 	if session_id == "" {
 		ctx.AbortWithStatusJSON(404, gin.H{"error": "Session not found"})
 		return

+ 37 - 0
internal/service/aws_transaction.go

@@ -0,0 +1,37 @@
+package service
+
+import (
+	"net/http"
+
+	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/aws_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
+)
+
+func HandleAWSPluginTransaction(handler *transaction.AWSTransactionHandler) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		// get session id from the context
+		session_id := c.GetString("session_id")
+		session := session_manager.GetSession(session_id)
+		if session == nil {
+			c.JSON(http.StatusBadRequest, gin.H{"error": "session not found"})
+			return
+		}
+
+		// get runtime from the session
+		runtime := session.Runtime()
+		if runtime == nil {
+			c.JSON(http.StatusBadRequest, gin.H{"error": "runtime not found"})
+			return
+		}
+
+		aws_runtime, ok := runtime.(*aws_manager.AWSPluginRuntime)
+		if !ok {
+			c.JSON(http.StatusBadRequest, gin.H{"error": "runtime is not aws plugin runtime"})
+			return
+		}
+
+		handler.Handle(c, session_id, aws_runtime)
+	}
+}

+ 7 - 0
internal/types/app/config.go

@@ -2,6 +2,7 @@ package app
 
 import (
 	"fmt"
+	"time"
 
 	"github.com/go-playground/validator/v10"
 )
@@ -49,6 +50,8 @@ type Config struct {
 	DifyPluginServerlessConnectorAPIKey *string `envconfig:"DIFY_PLUGIN_SERVERLESS_CONNECTOR_API_KEY"`
 
 	MaxPluginPackageSize int64 `envconfig:"MAX_PLUGIN_PACKAGE_SIZE" validate:"required"`
+
+	MaxAWSLambdaTransactionTimeout time.Duration `envconfig:"MAX_AWS_LAMBDA_TRANSACTION_TIMEOUT"`
 }
 
 func (c *Config) Validate() error {
@@ -81,6 +84,10 @@ func (c *Config) Validate() error {
 		if c.DifyPluginServerlessConnectorAPIKey == nil {
 			return fmt.Errorf("dify plugin serverless connector api key is empty")
 		}
+
+		if c.MaxAWSLambdaTransactionTimeout == 0 {
+			return fmt.Errorf("max aws lambda transaction timeout is empty")
+		}
 	} else if c.Platform == PLATFORM_LOCAL {
 		if c.PluginWorkingPath == "" {
 			return fmt.Errorf("plugin working path is empty")