瀏覽代碼

feat: add handshake

Yeuoly 1 年之前
父節點
當前提交
d7093972a9

+ 23 - 2
internal/core/plugin_manager/remote_manager/hooks.go

@@ -5,6 +5,7 @@ import (
 	"time"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 	"github.com/panjf2000/gnet/v2"
@@ -117,6 +118,26 @@ func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
 func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 	// handle message
 	if !runtime.handshake {
+		key := string(message)
+
+		info, err := GetConnectionInfo(key)
+		if err == cache.ErrNotFound {
+			// close connection if handshake failed
+			runtime.conn.Write([]byte("handshake failed\n"))
+			runtime.conn.Close()
+			return
+		} else if err != nil {
+			// close connection if handshake failed
+			runtime.conn.Write([]byte("internal error\n"))
+			runtime.conn.Close()
+			return
+		}
+
+		runtime.State.TenantID = info.TenantId
+
+		// handshake completed
+		runtime.handshake = true
+	} else if !runtime.registration_transferred {
 		// process handle shake if not completed
 		declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
 		if err != nil {
@@ -128,8 +149,8 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 
 		runtime.Config = declaration
 
-		// handshake completed
-		runtime.handshake = true
+		// registration transferred
+		runtime.registration_transferred = true
 
 		// publish runtime to watcher
 		s.response.Write(runtime)

+ 29 - 0
internal/core/plugin_manager/remote_manager/server_test.go

@@ -10,6 +10,7 @@ import (
 
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 )
 
@@ -67,6 +68,21 @@ func TestLaunchAndClosePluginServer(t *testing.T) {
 
 // TestAcceptConnection tests the acceptance of the connection
 func TestAcceptConnection(t *testing.T) {
+	if cache.InitRedisClient("0.0.0.0:6379", "difyai123456") != nil {
+		t.Errorf("failed to init redis client")
+		return
+	}
+
+	defer cache.Close()
+	key, err := GetConnectionKey(ConnectionInfo{
+		TenantId: "test",
+	})
+	if err != nil {
+		t.Errorf("failed to get connection key: %s", err.Error())
+		return
+	}
+	defer ClearConnectionKey("test")
+
 	server, port := preparePluginServer(t)
 	if server == nil {
 		return
@@ -91,6 +107,10 @@ func TestAcceptConnection(t *testing.T) {
 				connection_err = errors.New("plugin name not matched")
 			}
 
+			if runtime.State.TenantID != "test" {
+				connection_err = errors.New("tenant id not matched")
+			}
+
 			got_connection = true
 			runtime.Stop()
 		}
@@ -125,6 +145,8 @@ func TestAcceptConnection(t *testing.T) {
 			Launch:  "echo 'hello'",
 		},
 	})
+	conn.Write([]byte(key))
+	conn.Write([]byte("\n"))
 	conn.Write(handle_shake_message)
 	conn.Write([]byte("\n"))
 	closed_chan := make(chan bool)
@@ -218,6 +240,13 @@ func TestNoHandleShakeIn10Seconds(t *testing.T) {
 }
 
 func TestIncorrectHandshake(t *testing.T) {
+	if cache.InitRedisClient("0.0.0.0:6379", "difyai123456") != nil {
+		t.Errorf("failed to init redis client")
+		return
+	}
+
+	defer cache.Close()
+
 	server, port := preparePluginServer(t)
 	if server == nil {
 		return

+ 3 - 0
internal/core/plugin_manager/remote_manager/type.go

@@ -31,6 +31,9 @@ type RemotePluginRuntime struct {
 	// hand shake process completed
 	handshake bool
 
+	// registration transferred
+	registration_transferred bool
+
 	alive bool
 }