Browse Source

refactor: support aws lambda using persistence storage

Yeuoly 11 months ago
parent
commit
1115a3f669

+ 2 - 1
internal/core/plugin_daemon/backwards_invocation/task.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
@@ -370,7 +371,7 @@ func executeDifyInvocationStorageTask(
 		return
 	}
 
-	persistence := handle.session.Storage()
+	persistence := persistence.GetPersistence()
 	if persistence == nil {
 		handle.WriteError(fmt.Errorf("persistence not found"))
 		return

+ 3 - 3
internal/core/plugin_daemon/backwards_invocation/transaction/aws_event_handler.go

@@ -68,10 +68,10 @@ func (h *AWSTransactionHandler) Handle(
 	}
 
 	session := session_manager.GetSession(session_id)
-	if err != nil {
-		log.Error("get session failed: %s", err.Error())
+	if session == nil {
+		log.Error("session not found: %s", session_id)
 		writer.WriteHeader(http.StatusInternalServerError)
-		writer.Write([]byte(err.Error()))
+		writer.Write([]byte("session not found"))
 		return
 	}
 

+ 13 - 8
internal/core/session_manager/session.go

@@ -47,7 +47,6 @@ func NewSession(
 	invoke_from access_types.PluginAccessType,
 	action access_types.PluginAccessAction,
 	declaration *plugin_entities.PluginDeclaration,
-	persistence *persistence.Persistence,
 ) *Session {
 	s := &Session{
 		ID:             uuid.New().String(),
@@ -58,7 +57,6 @@ func NewSession(
 		InvokeFrom:     invoke_from,
 		Action:         action,
 		Declaration:    declaration,
-		persistence:    persistence,
 	}
 
 	session_lock.Lock()
@@ -74,9 +72,20 @@ func NewSession(
 
 func GetSession(id string) *Session {
 	session_lock.RLock()
-	defer session_lock.RUnlock()
+	session := sessions[id]
+	session_lock.RUnlock()
+
+	if session == nil {
+		// if session not found, it may be generated by another node, try to get it from cache
+		session, err := cache.Get[Session](sessionKey(id))
+		if err != nil {
+			log.Error("get session info from cache failed, %s", err)
+			return nil
+		}
+		return session
+	}
 
-	return sessions[id]
+	return session
 }
 
 func DeleteSession(id string) {
@@ -101,10 +110,6 @@ func (s *Session) Runtime() plugin_entities.PluginRuntimeInterface {
 	return s.runtime
 }
 
-func (s *Session) Storage() *persistence.Persistence {
-	return s.persistence
-}
-
 type PLUGIN_IN_STREAM_EVENT string
 
 const (

+ 0 - 8
internal/service/aws_transaction.go

@@ -1,22 +1,14 @@
 package service
 
 import (
-	"net/http"
-
 	"github.com/gin-gonic/gin"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction"
-	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 )
 
 func HandleAWSPluginTransaction(handler *transaction.AWSTransactionHandler) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		// get session id from the context
 		session_id := c.Request.Header.Get("Dify-Plugin-Session-ID")
-		session := session_manager.GetSession(session_id)
-		if session == nil {
-			c.JSON(http.StatusBadRequest, gin.H{"error": "session not found"})
-			return
-		}
 
 		handler.Handle(c, session_id)
 	}

+ 0 - 8
internal/service/endpoint.go

@@ -8,7 +8,6 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
-	"github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
@@ -41,12 +40,6 @@ func Endpoint(
 		return
 	}
 
-	persistence := persistence.GetPersistence()
-	if persistence == nil {
-		ctx.JSON(500, gin.H{"error": "persistence not found"})
-		return
-	}
-
 	session := session_manager.NewSession(
 		endpoint.TenantID,
 		"",
@@ -55,7 +48,6 @@ func Endpoint(
 		access_types.PLUGIN_ACCESS_TYPE_Endpoint,
 		access_types.PLUGIN_ACCESS_ACTION_INVOKE_ENDPOINT,
 		runtime.Configuration(),
-		persistence,
 	)
 	defer session.Close()
 

+ 0 - 9
internal/service/invoke_tool.go

@@ -1,10 +1,7 @@
 package service
 
 import (
-	"errors"
-
 	"github.com/gin-gonic/gin"
-	"github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
@@ -22,11 +19,6 @@ func createSession[T any](
 	access_action access_types.PluginAccessAction,
 	cluster_id string,
 ) (*session_manager.Session, error) {
-	persistence := persistence.GetPersistence()
-	if persistence == nil {
-		return nil, errors.New("persistence not found")
-	}
-
 	plugin_identity := parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)
 	runtime := plugin_manager.GetGlobalPluginManager().Get(plugin_identity)
 
@@ -38,7 +30,6 @@ func createSession[T any](
 		access_type,
 		access_action,
 		runtime.Configuration(),
-		persistence,
 	)
 
 	session.BindRuntime(runtime)