Selaa lähdekoodia

enhancement: add testcase for remote plugin manager

Yeuoly 9 kuukautta sitten
vanhempi
commit
0ecb41e62c

+ 4 - 0
internal/core/plugin_manager/manager.go

@@ -9,6 +9,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation/real"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/remote_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/serverless"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
 	"github.com/langgenius/dify-plugin-daemon/internal/db"
@@ -51,6 +52,9 @@ type PluginManager struct {
 
 	// python interpreter path
 	pythonInterpreterPath string
+
+	// remote plugin server
+	remotePluginServer remote_manager.RemotePluginServerInterface
 }
 
 var (

+ 1 - 1
internal/core/plugin_manager/remote_manager/hooks.go

@@ -40,7 +40,7 @@ type DifyServer struct {
 	num_loops int
 
 	// read new connections
-	response *stream.Stream[*RemotePluginRuntime]
+	response *stream.Stream[plugin_entities.PluginFullDuplexLifetime]
 
 	plugins      map[int]*RemotePluginRuntime
 	plugins_lock *sync.RWMutex

+ 12 - 3
internal/core/plugin_manager/remote_manager/server.go

@@ -10,6 +10,7 @@ import (
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 	"github.com/panjf2000/gnet/v2"
 
@@ -20,8 +21,16 @@ type RemotePluginServer struct {
 	server *DifyServer
 }
 
+type RemotePluginServerInterface interface {
+	Read() (plugin_entities.PluginFullDuplexLifetime, error)
+	Next() bool
+	Wrap(f func(plugin_entities.PluginFullDuplexLifetime))
+	Stop() error
+	Launch() error
+}
+
 // continue accepting new connections
-func (r *RemotePluginServer) Read() (*RemotePluginRuntime, error) {
+func (r *RemotePluginServer) Read() (plugin_entities.PluginFullDuplexLifetime, error) {
 	if r.server.response == nil {
 		return nil, errors.New("plugin server not started")
 	}
@@ -39,7 +48,7 @@ func (r *RemotePluginServer) Next() bool {
 }
 
 // Wrap wraps the wrap method of stream response
-func (r *RemotePluginServer) Wrap(f func(*RemotePluginRuntime)) {
+func (r *RemotePluginServer) Wrap(f func(plugin_entities.PluginFullDuplexLifetime)) {
 	r.server.response.Async(f)
 }
 
@@ -85,7 +94,7 @@ func NewRemotePluginServer(config *app.Config, media_manager *media_manager.Medi
 		config.PluginRemoteInstallingPort,
 	)
 
-	response := stream.NewStream[*RemotePluginRuntime](
+	response := stream.NewStream[plugin_entities.PluginFullDuplexLifetime](
 		config.PluginRemoteInstallingMaxConn,
 	)
 

+ 5 - 2
internal/core/plugin_manager/remote_manager/server_test.go

@@ -119,11 +119,14 @@ func TestAcceptConnection(t *testing.T) {
 				return
 			}
 
-			if runtime.Config.Name != "ci_test" {
+			remote_runtime := runtime.(*RemotePluginRuntime)
+
+			config := remote_runtime.Configuration()
+			if config.Name != "ci_test" {
 				connection_err = errors.New("plugin name not matched")
 			}
 
-			if runtime.tenant_id != tenant_id {
+			if remote_runtime.tenant_id != tenant_id {
 				connection_err = errors.New("tenant id not matched")
 			}
 

+ 10 - 3
internal/core/plugin_manager/watcher.go

@@ -35,18 +35,25 @@ func (p *PluginManager) startLocalWatcher() {
 	}()
 }
 
+func (p *PluginManager) initRemotePluginServer(config *app.Config) {
+	if p.remotePluginServer != nil {
+		return
+	}
+	p.remotePluginServer = remote_manager.NewRemotePluginServer(config, p.mediaManager)
+}
+
 func (p *PluginManager) startRemoteWatcher(config *app.Config) {
 	// launch TCP debugging server if enabled
 	if config.PluginRemoteInstallingEnabled {
-		server := remote_manager.NewRemotePluginServer(config, p.mediaManager)
+		p.initRemotePluginServer(config)
 		go func() {
-			err := server.Launch()
+			err := p.remotePluginServer.Launch()
 			if err != nil {
 				log.Error("start remote plugin server failed: %s", err.Error())
 			}
 		}()
 		go func() {
-			server.Wrap(func(rpr *remote_manager.RemotePluginRuntime) {
+			p.remotePluginServer.Wrap(func(rpr plugin_entities.PluginFullDuplexLifetime) {
 				identity, err := rpr.Identity()
 				if err != nil {
 					log.Error("get remote plugin identity failed: %s", err.Error())

+ 120 - 0
internal/core/plugin_manager/watcher_test.go

@@ -0,0 +1,120 @@
+package plugin_manager
+
+import (
+	"testing"
+	"time"
+
+	"github.com/google/uuid"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/positive_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
+)
+
+type fakePlugin struct {
+	plugin_entities.PluginRuntime
+	positive_manager.PositivePluginRuntime
+}
+
+func (r *fakePlugin) InitEnvironment() error {
+	return nil
+}
+
+func (r *fakePlugin) Checksum() (string, error) {
+	return "", nil
+}
+
+func (r *fakePlugin) Identity() (plugin_entities.PluginUniqueIdentifier, error) {
+	return plugin_entities.PluginUniqueIdentifier(""), nil
+}
+
+func (r *fakePlugin) StartPlugin() error {
+	return nil
+}
+
+func (r *fakePlugin) Type() plugin_entities.PluginRuntimeType {
+	return plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL
+}
+
+func (r *fakePlugin) Wait() (<-chan bool, error) {
+	return nil, nil
+}
+
+func (r *fakePlugin) Listen(string) *entities.Broadcast[plugin_entities.SessionMessage] {
+	return nil
+}
+
+func (r *fakePlugin) Write(string, []byte) {
+}
+
+func (r *fakePlugin) WaitStarted() <-chan bool {
+	c := make(chan bool)
+	close(c)
+	return c
+}
+
+func (r *fakePlugin) WaitStopped() <-chan bool {
+	c := make(chan bool)
+	return c
+}
+
+func getRandomPluginRuntime() *fakePlugin {
+	return &fakePlugin{
+		PluginRuntime: plugin_entities.PluginRuntime{
+			Config: plugin_entities.PluginDeclaration{
+				PluginDeclarationWithoutAdvancedFields: plugin_entities.PluginDeclarationWithoutAdvancedFields{
+					Name: uuid.New().String(),
+					Label: plugin_entities.I18nObject{
+						EnUS: "label",
+					},
+					Version:   "0.0.1",
+					Type:      plugin_entities.PluginType,
+					Author:    "Yeuoly",
+					CreatedAt: time.Now(),
+					Plugins: plugin_entities.PluginExtensions{
+						Tools: []string{"test"},
+					},
+				},
+			},
+		},
+	}
+}
+
+type fakeRemotePluginServer struct {
+}
+
+func (f *fakeRemotePluginServer) Launch() error {
+	return nil
+}
+
+func (f *fakeRemotePluginServer) Next() bool {
+	return false
+}
+
+func (f *fakeRemotePluginServer) Read() (plugin_entities.PluginFullDuplexLifetime, error) {
+	return nil, nil
+}
+
+func (f *fakeRemotePluginServer) Stop() error {
+	return nil
+}
+
+func (f *fakeRemotePluginServer) Wrap(fn func(plugin_entities.PluginFullDuplexLifetime)) {
+	fn(getRandomPluginRuntime())
+}
+
+func TestRemotePluginWatcherPluginStoredToManager(t *testing.T) {
+	config := &app.Config{}
+	config.SetDefault()
+	routine.InitPool(1024)
+	pm := InitGlobalManager(&app.Config{})
+	pm.remotePluginServer = &fakeRemotePluginServer{}
+	pm.startRemoteWatcher(config)
+
+	time.Sleep(1 * time.Second)
+
+	if pm.m.Len() != 1 {
+		t.Fatalf("Expected 1 plugin, got %d", pm.m.Len())
+	}
+}