Kaynağa Gözat

fix: connection closed unexpectedly

Yeuoly 7 ay önce
ebeveyn
işleme
0d51a59e9d

+ 1 - 0
internal/core/plugin_daemon/generic.go

@@ -81,6 +81,7 @@ func GenericInvokePlugin[Req any, Rsp any](
 		}
 	})
 
+	// close the listener if stream outside is closed due to close of connection
 	response.OnClose(func() {
 		listener.Close()
 	})

+ 6 - 4
internal/core/plugin_manager/remote_manager/hooks.go

@@ -66,10 +66,12 @@ func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
 			s.mediaManager,
 		),
 
-		conn:          c,
-		response:      stream.NewStream[[]byte](512),
-		callbacks:     make(map[string][]func([]byte)),
-		callbacksLock: &sync.RWMutex{},
+		conn:                      c,
+		response:                  stream.NewStream[[]byte](512),
+		messageCallbacks:          make(map[string][]func([]byte)),
+		messageCallbacksLock:      &sync.RWMutex{},
+		sessionMessageClosers:     make(map[string][]func()),
+		sessionMessageClosersLock: &sync.RWMutex{},
 
 		assets:      make(map[string]*bytes.Buffer),
 		assetsBytes: 0,

+ 25 - 2
internal/core/plugin_manager/remote_manager/io.go

@@ -1,20 +1,43 @@
 package remote_manager
 
 import (
+	"encoding/json"
+
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/exception"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 	"github.com/panjf2000/gnet/v2"
 )
 
 func (r *RemotePluginRuntime) Listen(session_id string) *entities.Broadcast[plugin_entities.SessionMessage] {
 	listener := entities.NewBroadcast[plugin_entities.SessionMessage]()
 	listener.OnClose(func() {
-		r.removeCallback(session_id)
+		// execute in new goroutine to avoid deadlock
+		routine.Submit(map[string]string{
+			"module": "remote_manager",
+			"method": "removeMessageCallbackHandler",
+		}, func() {
+			r.removeMessageCallbackHandler(session_id)
+			r.removeSessionMessageCloser(session_id)
+		})
+	})
+
+	// add session message closer to avoid unexpected connection closed
+	r.addSessionMessageCloser(session_id, func() {
+		listener.Send(plugin_entities.SessionMessage{
+			Type: plugin_entities.SESSION_MESSAGE_TYPE_ERROR,
+			Data: json.RawMessage(parser.MarshalJson(plugin_entities.ErrorResponse{
+				ErrorType: exception.PluginConnectionClosedError,
+				Message:   "Connection closed unexpectedly",
+				Args:      map[string]any{},
+			})),
+		})
 	})
 
-	r.addCallback(session_id, func(data []byte) {
+	r.addMessageCallbackHandler(session_id, func(data []byte) {
 		// unmarshal the session message
 		chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](data)
 		if err != nil {

+ 3 - 3
internal/core/plugin_manager/remote_manager/run.go

@@ -66,9 +66,9 @@ func (r *RemotePluginRuntime) StartPlugin() error {
 		plugin_entities.ParsePluginUniversalEvent(
 			data,
 			func(session_id string, data []byte) {
-				r.callbacksLock.RLock()
-				listeners := r.callbacks[session_id][:]
-				r.callbacksLock.RUnlock()
+				r.messageCallbacksLock.RLock()
+				listeners := r.messageCallbacks[session_id][:]
+				r.messageCallbacksLock.RUnlock()
 
 				// handle session event
 				for _, listener := range listeners {

+ 64 - 14
internal/core/plugin_manager/remote_manager/type.go

@@ -3,6 +3,7 @@ package remote_manager
 import (
 	"bytes"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_manager"
@@ -26,9 +27,13 @@ type RemotePluginRuntime struct {
 	// response entity to accept new events
 	response *stream.Stream[[]byte]
 
-	// callbacks for each session
-	callbacks     map[string][]func([]byte)
-	callbacksLock *sync.RWMutex
+	// messageCallbacks for each session
+	messageCallbacks     map[string][]func([]byte)
+	messageCallbacksLock *sync.RWMutex
+
+	// sessionMessageCloser for each session
+	sessionMessageClosers     map[string][]func()
+	sessionMessageClosersLock *sync.RWMutex
 
 	// channel to notify all waiting routines
 	shutdownChan chan bool
@@ -74,25 +79,70 @@ type RemotePluginRuntime struct {
 	waitLaunchedChanOnce sync.Once
 }
 
+// TODO: unify below methods to a standard interface
+
 // Listen creates a new listener for the given session_id
 // session id is an unique identifier for a request
-func (r *RemotePluginRuntime) addCallback(session_id string, fn func([]byte)) {
-	r.callbacksLock.Lock()
-	if _, ok := r.callbacks[session_id]; !ok {
-		r.callbacks[session_id] = make([]func([]byte), 0)
+func (r *RemotePluginRuntime) addMessageCallbackHandler(session_id string, fn func([]byte)) {
+	r.messageCallbacksLock.Lock()
+	if _, ok := r.messageCallbacks[session_id]; !ok {
+		r.messageCallbacks[session_id] = make([]func([]byte), 0)
+	}
+	r.messageCallbacks[session_id] = append(r.messageCallbacks[session_id], fn)
+	r.messageCallbacksLock.Unlock()
+}
+
+// removeMessageCallbackHandler removes the listener for the given session_id
+func (r *RemotePluginRuntime) removeMessageCallbackHandler(session_id string) {
+	r.messageCallbacksLock.Lock()
+	delete(r.messageCallbacks, session_id)
+	r.messageCallbacksLock.Unlock()
+}
+
+// addSessionMessageCloser adds a closer for the given session_id
+// once the session is closed or the connection is closed, the closer will be called
+func (r *RemotePluginRuntime) addSessionMessageCloser(session_id string, fn func()) {
+	// do nothing if the session is already closed
+	if atomic.LoadInt32(&r.closed) == 1 {
+		return
+	}
+
+	r.sessionMessageClosersLock.Lock()
+	if _, ok := r.sessionMessageClosers[session_id]; !ok {
+		r.sessionMessageClosers[session_id] = make([]func(), 0)
 	}
-	r.callbacks[session_id] = append(r.callbacks[session_id], fn)
-	r.callbacksLock.Unlock()
+	r.sessionMessageClosers[session_id] = append(r.sessionMessageClosers[session_id], fn)
+	r.sessionMessageClosersLock.Unlock()
 }
 
-// removeCallback removes the listener for the given session_id
-func (r *RemotePluginRuntime) removeCallback(session_id string) {
-	r.callbacksLock.Lock()
-	delete(r.callbacks, session_id)
-	r.callbacksLock.Unlock()
+// removeSessionMessageCloser removes the closer for the given session_id
+func (r *RemotePluginRuntime) removeSessionMessageCloser(session_id string) {
+	// do nothing if the session is already closed
+	if atomic.LoadInt32(&r.closed) == 1 {
+		return
+	}
+
+	r.sessionMessageClosersLock.Lock()
+	delete(r.sessionMessageClosers, session_id)
+	r.sessionMessageClosersLock.Unlock()
 }
 
 func (r *RemotePluginRuntime) onDisconnected() {
+	// call all session message closers
+	r.sessionMessageClosersLock.RLock()
+	for _, closer := range r.sessionMessageClosers {
+		for _, fn := range closer {
+			fn()
+		}
+	}
+	r.sessionMessageClosersLock.RUnlock()
+
+	// change the alive status
+	r.alive = false
+
+	// change the closed status
+	atomic.StoreInt32(&r.closed, 1)
+
 	// close shutdown channel to notify all waiting routines
 	close(r.shutdownChan)
 

+ 2 - 2
internal/server/controllers/base.go

@@ -37,13 +37,13 @@ func BindPluginDispatchRequest[T any](r *gin.Context, success func(
 	BindRequest(r, func(req plugin_entities.InvokePluginRequest[T]) {
 		pluginUniqueIdentifierAny, exists := r.Get(constants.CONTEXT_KEY_PLUGIN_UNIQUE_IDENTIFIER)
 		if !exists {
-			r.JSON(400, exception.PluginUniqueIdentifierError(errors.New("Plugin unique identifier is required")).ToResponse())
+			r.JSON(400, exception.UniqueIdentifierError(errors.New("Plugin unique identifier is required")).ToResponse())
 			return
 		}
 
 		pluginUniqueIdentifier, ok := pluginUniqueIdentifierAny.(plugin_entities.PluginUniqueIdentifier)
 		if !ok {
-			r.JSON(400, exception.PluginUniqueIdentifierError(errors.New("Plugin unique identifier is not valid")).ToResponse())
+			r.JSON(400, exception.UniqueIdentifierError(errors.New("Plugin unique identifier is not valid")).ToResponse())
 			return
 		}
 

+ 1 - 1
internal/server/endpoint.go

@@ -61,7 +61,7 @@ func (app *App) EndpointHandler(ctx *gin.Context, hookId string, path string) {
 		pluginInstallation.PluginUniqueIdentifier,
 	)
 	if err != nil {
-		ctx.JSON(400, exception.PluginUniqueIdentifierError(
+		ctx.JSON(400, exception.UniqueIdentifierError(
 			errors.New("invalid plugin unique identifier"),
 		).ToResponse())
 		return

+ 1 - 1
internal/server/middleware.go

@@ -57,7 +57,7 @@ func (app *App) FetchPluginInstallation() gin.HandlerFunc {
 
 		identity, err := plugin_entities.NewPluginUniqueIdentifier(installation.PluginUniqueIdentifier)
 		if err != nil {
-			ctx.AbortWithStatusJSON(400, exception.PluginUniqueIdentifierError(err).ToResponse())
+			ctx.AbortWithStatusJSON(400, exception.UniqueIdentifierError(err).ToResponse())
 			return
 		}
 

+ 3 - 3
internal/service/endpoint.go

@@ -73,7 +73,7 @@ func Endpoint(
 
 	identifier, err := plugin_entities.NewPluginUniqueIdentifier(pluginInstallation.PluginUniqueIdentifier)
 	if err != nil {
-		ctx.JSON(400, exception.PluginUniqueIdentifierError(err).ToResponse())
+		ctx.JSON(400, exception.UniqueIdentifierError(err).ToResponse())
 		return
 	}
 
@@ -260,7 +260,7 @@ func ListEndpoints(tenant_id string, page int, page_size int) *entities.Response
 			pluginInstallation.PluginUniqueIdentifier,
 		)
 		if err != nil {
-			return exception.PluginUniqueIdentifierError(
+			return exception.UniqueIdentifierError(
 				fmt.Errorf("failed to parse plugin unique identifier: %v", err),
 			).ToResponse()
 		}
@@ -350,7 +350,7 @@ func ListPluginEndpoints(tenant_id string, plugin_id string, page int, page_size
 		)
 
 		if err != nil {
-			return exception.PluginUniqueIdentifierError(
+			return exception.UniqueIdentifierError(
 				fmt.Errorf("failed to parse plugin unique identifier: %v", err),
 			).ToResponse()
 		}

+ 1 - 1
internal/service/install_plugin.go

@@ -567,7 +567,7 @@ func UninstallPlugin(
 
 	pluginUniqueIdentifier, err := plugin_entities.NewPluginUniqueIdentifier(installation.PluginUniqueIdentifier)
 	if err != nil {
-		return exception.PluginUniqueIdentifierError(err).ToResponse()
+		return exception.UniqueIdentifierError(err).ToResponse()
 	}
 
 	// Uninstall the plugin

+ 1 - 1
internal/service/manage_plugin.go

@@ -50,7 +50,7 @@ func ListPlugins(tenant_id string, page int, page_size int) *entities.Response {
 			plugin_installation.PluginUniqueIdentifier,
 		)
 		if err != nil {
-			return exception.PluginUniqueIdentifierError(err).ToResponse()
+			return exception.UniqueIdentifierError(err).ToResponse()
 		}
 
 		pluginDeclaration, err := helper.CombinedGetPluginDeclaration(

+ 1 - 1
internal/service/setup_endpoint.go

@@ -163,7 +163,7 @@ func UpdateEndpoint(endpoint_id string, tenant_id string, user_id string, name s
 		installation.PluginUniqueIdentifier,
 	)
 	if err != nil {
-		return exception.PluginUniqueIdentifierError(fmt.Errorf("failed to parse plugin unique identifier: %v", err)).ToResponse()
+		return exception.UniqueIdentifierError(fmt.Errorf("failed to parse plugin unique identifier: %v", err)).ToResponse()
 	}
 
 	// get plugin

+ 30 - 9
internal/types/exception/factory.go

@@ -1,35 +1,56 @@
 package exception
 
+const (
+	PluginDaemonInternalServerError   = "PluginDaemonInternalServerError"
+	PluginDaemonBadRequestError       = "PluginDaemonBadRequestError"
+	PluginDaemonNotFoundError         = "PluginDaemonNotFoundError"
+	PluginDaemonUnauthorizedError     = "PluginDaemonUnauthorizedError"
+	PluginDaemonPermissionDeniedError = "PluginDaemonPermissionDeniedError"
+	PluginDaemonInvokeError           = "PluginDaemonInvokeError"
+	PluginUniqueIdentifierError       = "PluginUniqueIdentifierError"
+	PluginNotFoundError               = "PluginNotFoundError"
+	PluginUnauthorizedError           = "PluginUnauthorizedError"
+	PluginPermissionDeniedError       = "PluginPermissionDeniedError"
+	PluginInvokeError                 = "PluginInvokeError"
+	PluginConnectionClosedError       = "ConnectionClosedError"
+)
+
 func InternalServerError(err error) PluginDaemonError {
-	return ErrorWithTypeAndCode(err.Error(), "PluginDaemonInternalServerError", -500)
+	return ErrorWithTypeAndCode(err.Error(), PluginDaemonInternalServerError, -500)
 }
 
 func BadRequestError(err error) PluginDaemonError {
-	return ErrorWithTypeAndCode(err.Error(), "PluginDaemonBadRequestError", -400)
+	return ErrorWithTypeAndCode(err.Error(), PluginDaemonBadRequestError, -400)
 }
 
 func NotFoundError(err error) PluginDaemonError {
-	return ErrorWithTypeAndCode(err.Error(), "PluginDaemonNotFoundError", -404)
+	return ErrorWithTypeAndCode(err.Error(), PluginDaemonNotFoundError, -404)
 }
 
-func PluginUniqueIdentifierError(err error) PluginDaemonError {
-	return ErrorWithTypeAndCode(err.Error(), "PluginUniqueIdentifierError", -400)
+func UniqueIdentifierError(err error) PluginDaemonError {
+	return ErrorWithTypeAndCode(err.Error(), PluginUniqueIdentifierError, -400)
 }
 
 // the difference between NotFoundError and ErrPluginNotFound is that the latter is used to notify
 // the caller that the plugin is not installed, while the former is a generic NotFound error.
 func ErrPluginNotFound() PluginDaemonError {
-	return ErrorWithTypeAndCode("plugin not found", "PluginNotFoundError", -404)
+	return ErrorWithTypeAndCode("plugin not found", PluginNotFoundError, -404)
 }
 
 func UnauthorizedError() PluginDaemonError {
-	return ErrorWithTypeAndCode("unauthorized", "PluginDaemonUnauthorizedError", -401)
+	return ErrorWithTypeAndCode("unauthorized", PluginDaemonUnauthorizedError, -401)
 }
 
 func PermissionDeniedError(msg string) PluginDaemonError {
-	return ErrorWithTypeAndCode(msg, "PluginPermissionDeniedError", -403)
+	return ErrorWithTypeAndCode(msg, PluginPermissionDeniedError, -403)
 }
 
 func InvokePluginError(err error) PluginDaemonError {
-	return ErrorWithTypeAndCode(err.Error(), "PluginInvokeError", -500)
+	return ErrorWithTypeAndCode(err.Error(), PluginInvokeError, -500)
+}
+
+// ConnectionClosedError is designed to be used when the connection was closed unexpectedly
+// but the session is not closed yet.
+func ConnectionClosedError() PluginDaemonError {
+	return ErrorWithTypeAndCode("connection closed", PluginConnectionClosedError, -500)
 }