Yeuoly 10 meses atrás
pai
commit
408025591d

+ 27 - 30
internal/core/plugin_manager/remote_manager/hooks.go

@@ -3,6 +3,7 @@ package remote_manager
 import (
 	"encoding/hex"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/basic_manager"
@@ -15,6 +16,10 @@ import (
 	"github.com/panjf2000/gnet/v2"
 )
 
+var (
+	_mode pluginRuntimeMode
+)
+
 type DifyServer struct {
 	gnet.BuiltinEventEngine
 
@@ -100,12 +105,11 @@ func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
 	plugin.ClearAssets()
 
 	// uninstall plugin
-	if plugin.handshake && plugin.registration_transferred &&
-		plugin.endpoints_registration_transferred &&
-		plugin.models_registration_transferred &&
-		plugin.tools_registration_transferred {
-		if err := plugin.Unregister(); err != nil {
-			log.Error("unregister plugin failed, error: %v", err)
+	if plugin.assets_transferred {
+		if _mode != _PLUGIN_RUNTIME_MODE_CI {
+			if err := plugin.Unregister(); err != nil {
+				log.Error("unregister plugin failed, error: %v", err)
+			}
 		}
 	}
 
@@ -150,20 +154,25 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 		return
 	}
 
+	close := func(message []byte) {
+		if atomic.CompareAndSwapInt32(&runtime.closed, 0, 1) {
+			runtime.conn.Write(message)
+			runtime.conn.Close()
+		}
+	}
+
 	if !runtime.handshake {
 		key := string(message)
 
 		info, err := GetConnectionInfo(key)
 		if err == cache.ErrNotFound {
 			// close connection if handshake failed
-			runtime.conn.Write([]byte("handshake failed, invalid key\n"))
-			runtime.conn.Close()
+			close([]byte("handshake failed, invalid key\n"))
 			runtime.handshake_failed = true
 			return
 		} else if err != nil {
 			// close connection if handshake failed
-			runtime.conn.Write([]byte("internal error\n"))
-			runtime.conn.Close()
+			close([]byte("internal error\n"))
 			return
 		}
 
@@ -176,8 +185,7 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 		declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
 		if err != nil {
 			// close connection if handshake failed
-			runtime.conn.Write([]byte("handshake failed\n"))
-			runtime.conn.Close()
+			close([]byte("handshake failed, invalid plugin declaration\n"))
 			return
 		}
 
@@ -188,9 +196,8 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 	} else if !runtime.tools_registration_transferred {
 		tools, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ToolProviderDeclaration](message)
 		if err != nil {
-			runtime.conn.Write([]byte("tools register failed\n"))
 			log.Error("tools register failed, error: %v", err)
-			runtime.conn.Close()
+			close([]byte("tools register failed, invalid tools declaration\n"))
 			return
 		}
 
@@ -204,9 +211,8 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 	} else if !runtime.models_registration_transferred {
 		models, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.ModelProviderDeclaration](message)
 		if err != nil {
-			runtime.conn.Write([]byte("models register failed\n"))
 			log.Error("models register failed, error: %v", err)
-			runtime.conn.Close()
+			close([]byte("models register failed, invalid models declaration\n"))
 			return
 		}
 
@@ -220,9 +226,8 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 	} else if !runtime.endpoints_registration_transferred {
 		endpoints, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.EndpointProviderDeclaration](message)
 		if err != nil {
-			runtime.conn.Write([]byte("endpoints register failed\n"))
 			log.Error("endpoints register failed, error: %v", err)
-			runtime.conn.Close()
+			close([]byte("endpoints register failed, invalid endpoints declaration\n"))
 			return
 		}
 
@@ -236,9 +241,8 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 	} else if !runtime.assets_transferred {
 		assets, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.RemoteAssetPayload](message)
 		if err != nil {
-			runtime.conn.Write([]byte("assets register failed\n"))
 			log.Error("assets register failed, error: %v", err)
-			runtime.conn.Close()
+			close([]byte("assets register failed, invalid assets declaration\n"))
 			return
 		}
 
@@ -246,18 +250,16 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 		for _, asset := range assets {
 			files[asset.Filename], err = hex.DecodeString(asset.Data)
 			if err != nil {
-				runtime.conn.Write([]byte("assets decode failed\n"))
 				log.Error("assets decode failed, error: %v", err)
-				runtime.conn.Close()
+				close([]byte("assets decode failed, invalid assets data, cannot decode file\n"))
 				return
 			}
 		}
 
 		// remap assets
 		if err := runtime.RemapAssets(&runtime.Config, files); err != nil {
-			runtime.conn.Write([]byte("assets remap failed\n"))
 			log.Error("assets remap failed, error: %v", err)
-			runtime.conn.Close()
+			close([]byte("assets remap failed, invalid assets data, cannot remap\n"))
 			return
 		}
 
@@ -269,9 +271,8 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 
 		// trigger registration event
 		if err := runtime.Register(); err != nil {
-			runtime.conn.Write([]byte("register failed\n"))
 			log.Error("register failed, error: %v", err)
-			runtime.conn.Close()
+			close([]byte("register failed, cannot register\n"))
 			return
 		}
 
@@ -282,7 +283,3 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 		runtime.response.Write(message)
 	}
 }
-
-func (s *DifyServer) onAssets(runtime *RemotePluginRuntime, assets []plugin_entities.RemoteAssetPayload) {
-
-}

+ 14 - 2
internal/core/plugin_manager/remote_manager/server_test.go

@@ -8,6 +8,7 @@ import (
 	"testing"
 	"time"
 
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/db"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/constants"
@@ -17,6 +18,10 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 )
 
+func init() {
+	_mode = _PLUGIN_RUNTIME_MODE_CI
+}
+
 func preparePluginServer(t *testing.T) (*RemotePluginServer, uint16) {
 	db.Init(&app.Config{
 		DBUsername: "postgres",
@@ -39,7 +44,7 @@ func preparePluginServer(t *testing.T) (*RemotePluginServer, uint16) {
 		PluginRemoteInstallingPort:             port,
 		PluginRemoteInstallingMaxConn:          1,
 		PluginRemoteInstallServerEventLoopNums: 8,
-	}, nil), port
+	}, media_manager.NewMediaManager("./storage/assets", 10)), port
 }
 
 // TestLaunchAndClosePluginServer tests the launch and close of the plugin server
@@ -140,6 +145,7 @@ func TestAcceptConnection(t *testing.T) {
 			Type:    plugin_entities.PluginType,
 			Author:  "Yeuoly",
 			Name:    "ci_test",
+			Icon:    "test.svg",
 			Label: plugin_entities.I18nObject{
 				EnUS: "ci_test",
 			},
@@ -171,7 +177,13 @@ func TestAcceptConnection(t *testing.T) {
 	conn.Write([]byte("[]\n")) // transfer tool
 	conn.Write([]byte("[]\n")) // transfer model
 	conn.Write([]byte("[]\n")) // transfer endpoint
-	conn.Write([]byte("[]\n")) // transfer file
+	conn.Write(parser.MarshalJsonBytes([]plugin_entities.RemoteAssetPayload{
+		{
+			Filename: "test.svg",
+			Data:     "a2a2",
+		},
+	}))
+	conn.Write([]byte("\n"))
 	closed_chan := make(chan bool)
 
 	msg := ""

+ 6 - 1
internal/core/plugin_manager/remote_manager/type.go

@@ -10,12 +10,17 @@ import (
 	"github.com/panjf2000/gnet/v2"
 )
 
+type pluginRuntimeMode string
+
+const _PLUGIN_RUNTIME_MODE_CI pluginRuntimeMode = "ci"
+
 type RemotePluginRuntime struct {
 	basic_manager.BasicPluginRuntime
 	plugin_entities.PluginRuntime
 
 	// connection
-	conn gnet.Conn
+	conn   gnet.Conn
+	closed int32
 
 	// response entity to accept new events
 	response *stream.StreamResponse[[]byte]