Bläddra i källkod

feat: handshake

Yeuoly 1 år sedan
förälder
incheckning
de325de8f9

+ 3 - 5
internal/core/plugin_manager/local_manager/io.go

@@ -1,22 +1,20 @@
 package local_manager
 
 import (
-	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/local_manager/stdio_holder"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 )
 
 func (r *LocalPluginRuntime) Listen(session_id string) *entities.BytesIOListener {
 	listener := entities.NewIOListener[[]byte]()
 	listener.OnClose(func() {
-		stdio_holder.RemoveListener(r.io_identity, session_id)
+		RemoveStdioListener(r.io_identity, session_id)
 	})
-	stdio_holder.OnEvent(r.io_identity, session_id, func(b []byte) {
+	OnStdioEvent(r.io_identity, session_id, func(b []byte) {
 		listener.Emit(b)
 	})
-
 	return listener
 }
 
 func (r *LocalPluginRuntime) Write(session_id string, data []byte) {
-	stdio_holder.Write(r.io_identity, append(data, '\n'))
+	WriteToStdio(r.io_identity, append(data, '\n'))
 }

+ 2 - 3
internal/core/plugin_manager/local_manager/run.go

@@ -6,7 +6,6 @@ import (
 	"os/exec"
 	"sync"
 
-	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/local_manager/stdio_holder"
 	"github.com/langgenius/dify-plugin-daemon/internal/process"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
@@ -15,7 +14,7 @@ import (
 
 func (r *LocalPluginRuntime) gc() {
 	if r.io_identity != "" {
-		stdio_holder.Remove(r.io_identity)
+		RemoveStdio(r.io_identity)
 	}
 
 	if r.w != nil {
@@ -84,7 +83,7 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 	log.Info("plugin %s started", r.Config.Identity())
 
 	// setup stdio
-	stdio := stdio_holder.Put(r.Config.Identity(), stdin, stdout, stderr)
+	stdio := PutStdioIo(r.Config.Identity(), stdin, stdout, stderr)
 	r.io_identity = stdio.GetID()
 	defer stdio.Stop()
 

+ 1 - 1
internal/core/plugin_manager/local_manager/stdio_holder/io.go

@@ -1,4 +1,4 @@
-package stdio_holder
+package local_manager
 
 import (
 	"bufio"

+ 6 - 6
internal/core/plugin_manager/local_manager/stdio_holder/store.go

@@ -1,4 +1,4 @@
-package stdio_holder
+package local_manager
 
 import (
 	"io"
@@ -7,7 +7,7 @@ import (
 	"github.com/google/uuid"
 )
 
-func Put(
+func PutStdioIo(
 	plugin_identity string, writer io.WriteCloser,
 	reader io.ReadCloser, err_reader io.ReadCloser,
 ) *stdioHolder {
@@ -39,11 +39,11 @@ func Get(id string) *stdioHolder {
 	return nil
 }
 
-func Remove(id string) {
+func RemoveStdio(id string) {
 	stdio_holder.Delete(id)
 }
 
-func OnEvent(id string, session_id string, listener func([]byte)) {
+func OnStdioEvent(id string, session_id string, listener func([]byte)) {
 	if v, ok := stdio_holder.Load(id); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 			holder.l.Lock()
@@ -72,7 +72,7 @@ func OnError(id string, session_id string, listener func([]byte)) {
 
 }
 
-func RemoveListener(id string, listener string) {
+func RemoveStdioListener(id string, listener string) {
 	if v, ok := stdio_holder.Load(id); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 			holder.l.Lock()
@@ -89,7 +89,7 @@ func OnGlobalEvent(listener func(string, []byte)) {
 	listeners[uuid.New().String()] = listener
 }
 
-func Write(id string, data []byte) error {
+func WriteToStdio(id string, data []byte) error {
 	if v, ok := stdio_holder.Load(id); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 			_, err := holder.writer.Write(data)

+ 36 - 2
internal/core/plugin_manager/remote_manager/hooks.go

@@ -2,7 +2,10 @@ package remote_manager
 
 import (
 	"sync"
+	"time"
 
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 	"github.com/panjf2000/gnet/v2"
 )
@@ -54,7 +57,13 @@ func (s *DifyServer) OnOpen(c gnet.Conn) (out []byte, action gnet.Action) {
 	s.plugins[c.Fd()] = runtime
 	s.plugins_lock.Unlock()
 
-	s.response.Write(runtime)
+	// start a timer to check if handshake is completed in 10 seconds
+	time.AfterFunc(time.Second*10, func() {
+		if !runtime.handshake {
+			// close connection
+			c.Close()
+		}
+	})
 
 	// verified
 	verified := true
@@ -99,8 +108,33 @@ func (s *DifyServer) OnTraffic(c gnet.Conn) (action gnet.Action) {
 
 	// handle messages
 	for _, message := range messages {
-		runtime.response.Write(message)
+		s.onMessage(runtime, message)
 	}
 
 	return gnet.None
 }
+
+func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
+	// handle message
+	if !runtime.handshake {
+		// process handle shake if not completed
+		declaration, err := parser.UnmarshalJsonBytes[plugin_entities.PluginDeclaration](message)
+		if err != nil {
+			// close connection if handshake failed
+			runtime.conn.Write([]byte("handshake failed\n"))
+			runtime.conn.Close()
+			return
+		}
+
+		runtime.Config = declaration
+
+		// handshake completed
+		runtime.handshake = true
+
+		// publish runtime to watcher
+		s.response.Write(runtime)
+	} else {
+		// continue handle messages if handshake completed
+		runtime.response.Write(message)
+	}
+}

+ 34 - 2
internal/core/plugin_manager/remote_manager/server_test.go

@@ -1,12 +1,15 @@
 package remote_manager
 
 import (
+	"errors"
 	"fmt"
 	"net"
 	"testing"
 	"time"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 )
 
 func preparePluginServer(t *testing.T) (*RemotePluginServer, uint16) {
@@ -73,6 +76,7 @@ func TestAcceptConnection(t *testing.T) {
 	}()
 
 	got_connection := false
+	var connection_err error
 
 	go func() {
 		for server.Next() {
@@ -82,9 +86,11 @@ func TestAcceptConnection(t *testing.T) {
 				return
 			}
 
-			got_connection = true
+			if runtime.Config.Name != "ci_test" {
+				connection_err = errors.New("plugin name not matched")
+			}
 
-			time.Sleep(time.Second * 2)
+			got_connection = true
 			runtime.Stop()
 		}
 	}()
@@ -98,6 +104,28 @@ func TestAcceptConnection(t *testing.T) {
 		return
 	}
 
+	// send handshake
+	handle_shake_message := parser.MarshalJsonBytes(&plugin_entities.PluginDeclaration{
+		Version:   "1.0.0",
+		Type:      plugin_entities.PluginType,
+		Author:    "Yeuoly",
+		Name:      "ci_test",
+		CreatedAt: time.Now(),
+		Resource: plugin_entities.PluginResourceRequirement{
+			Memory:     1,
+			Storage:    1,
+			Permission: nil,
+		},
+		Plugins: []string{
+			"test",
+		},
+		Execution: plugin_entities.PluginDeclarationExecution{
+			Install: "echo 'hello'",
+			Launch:  "echo 'hello'",
+		},
+	})
+	conn.Write(handle_shake_message)
+	conn.Write([]byte("\n"))
 	closed_chan := make(chan bool)
 
 	go func() {
@@ -123,6 +151,10 @@ func TestAcceptConnection(t *testing.T) {
 			t.Errorf("failed to accept connection")
 			return
 		}
+		if connection_err != nil {
+			t.Errorf("failed to accept connection: %s", connection_err.Error())
+			return
+		}
 		return
 	}
 }

+ 3 - 0
internal/core/plugin_manager/remote_manager/type.go

@@ -28,6 +28,9 @@ type RemotePluginRuntime struct {
 	// heartbeat
 	last_active_at time.Time
 
+	// hand shake process completed
+	handshake bool
+
 	alive bool
 }