浏览代码

feat: support removing local plugin runtime

Yeuoly 8 月之前
父节点
当前提交
033636c83f

+ 3 - 3
internal/core/plugin_manager/local_manager/io.go

@@ -10,9 +10,9 @@ import (
 func (r *LocalPluginRuntime) Listen(session_id string) *entities.Broadcast[plugin_entities.SessionMessage] {
 func (r *LocalPluginRuntime) Listen(session_id string) *entities.Broadcast[plugin_entities.SessionMessage] {
 	listener := entities.NewBroadcast[plugin_entities.SessionMessage]()
 	listener := entities.NewBroadcast[plugin_entities.SessionMessage]()
 	listener.OnClose(func() {
 	listener.OnClose(func() {
-		RemoveStdioListener(r.io_identity, session_id)
+		removeStdioHandlerListener(r.io_identity, session_id)
 	})
 	})
-	OnStdioEvent(r.io_identity, session_id, func(b []byte) {
+	setupStdioEventListener(r.io_identity, session_id, func(b []byte) {
 		// unmarshal the session message
 		// unmarshal the session message
 		data, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](b)
 		data, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](b)
 		if err != nil {
 		if err != nil {
@@ -26,5 +26,5 @@ func (r *LocalPluginRuntime) Listen(session_id string) *entities.Broadcast[plugi
 }
 }
 
 
 func (r *LocalPluginRuntime) Write(session_id string, data []byte) {
 func (r *LocalPluginRuntime) Write(session_id string, data []byte) {
-	WriteToStdio(r.io_identity, append(data, '\n'))
+	writeToStdioHandler(r.io_identity, append(data, '\n'))
 }
 }

+ 16 - 2
internal/core/plugin_manager/local_manager/run.go

@@ -17,7 +17,7 @@ import (
 // gc performs garbage collection for the LocalPluginRuntime
 // gc performs garbage collection for the LocalPluginRuntime
 func (r *LocalPluginRuntime) gc() {
 func (r *LocalPluginRuntime) gc() {
 	if r.io_identity != "" {
 	if r.io_identity != "" {
-		RemoveStdio(r.io_identity)
+		removeStdioHandler(r.io_identity)
 	}
 	}
 
 
 	if r.wait_chan != nil {
 	if r.wait_chan != nil {
@@ -109,12 +109,14 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 
 
 		r.gc()
 		r.gc()
 	}()
 	}()
+
+	// ensure the plugin process is killed after the plugin exits
 	defer e.Process.Kill()
 	defer e.Process.Kill()
 
 
 	log.Info("plugin %s started", r.Config.Identity())
 	log.Info("plugin %s started", r.Config.Identity())
 
 
 	// setup stdio
 	// setup stdio
-	stdio := PutStdioIo(r.Config.Identity(), stdin, stdout, stderr)
+	stdio := registerStdioHandler(r.Config.Identity(), stdin, stdout, stderr)
 	r.io_identity = stdio.GetID()
 	r.io_identity = stdio.GetID()
 	defer stdio.Stop()
 	defer stdio.Stop()
 
 
@@ -180,3 +182,15 @@ func (r *LocalPluginRuntime) WaitStopped() <-chan bool {
 	r.wait_chan_lock.Unlock()
 	r.wait_chan_lock.Unlock()
 	return c
 	return c
 }
 }
+
+// Stop stops the plugin
+func (r *LocalPluginRuntime) Stop() {
+	// inherit from PluginRuntime
+	r.PluginRuntime.Stop()
+
+	// get stdio
+	stdio := getStdioHandler(r.io_identity)
+	if stdio != nil {
+		stdio.Stop()
+	}
+}

+ 35 - 18
internal/core/plugin_manager/local_manager/stdio_handle.go

@@ -30,13 +30,17 @@ type stdioHolder struct {
 	error_listener           map[string]func([]byte)
 	error_listener           map[string]func([]byte)
 	started                  bool
 	started                  bool
 
 
+	// error message container
 	err_message                 string
 	err_message                 string
 	last_err_message_updated_at time.Time
 	last_err_message_updated_at time.Time
 
 
-	health_chan        chan bool
-	health_chan_closed bool
-	health_chan_lock   *sync.Mutex
-	last_active_at     time.Time
+	// waiting controller channel to notify the exit signal to the Wait() function
+	waiting_controller_chan        chan bool
+	waiting_controller_chan_closed bool
+	wait_controller_chan_lock      *sync.Mutex
+
+	// the last time the plugin sent a heartbeat
+	last_active_at time.Time
 }
 }
 
 
 func (s *stdioHolder) Error() error {
 func (s *stdioHolder) Error() error {
@@ -49,21 +53,26 @@ func (s *stdioHolder) Error() error {
 	return nil
 	return nil
 }
 }
 
 
+// Stop stops the stdio, of course, it will shutdown the plugin asynchronously
+// by closing a channel to notify the `Wait()` function to exit
 func (s *stdioHolder) Stop() {
 func (s *stdioHolder) Stop() {
 	s.writer.Close()
 	s.writer.Close()
 	s.reader.Close()
 	s.reader.Close()
 	s.err_reader.Close()
 	s.err_reader.Close()
 
 
-	s.health_chan_lock.Lock()
-	if !s.health_chan_closed {
-		close(s.health_chan)
-		s.health_chan_closed = true
+	s.wait_controller_chan_lock.Lock()
+	if !s.waiting_controller_chan_closed {
+		close(s.waiting_controller_chan)
+		s.waiting_controller_chan_closed = true
 	}
 	}
-	s.health_chan_lock.Unlock()
+	s.wait_controller_chan_lock.Unlock()
 
 
 	stdio_holder.Delete(s.id)
 	stdio_holder.Delete(s.id)
 }
 }
 
 
+// StartStdout starts to read the stdout of the plugin
+// it will notify the heartbeat function when the plugin is active
+// and parse the stdout data to trigger corresponding listeners
 func (s *stdioHolder) StartStdout(notify_heartbeat func()) {
 func (s *stdioHolder) StartStdout(notify_heartbeat func()) {
 	s.started = true
 	s.started = true
 	s.last_active_at = time.Now()
 	s.last_active_at = time.Now()
@@ -103,6 +112,8 @@ func (s *stdioHolder) StartStdout(notify_heartbeat func()) {
 	}
 	}
 }
 }
 
 
+// WriteError writes the error message to the stdio holder
+// it will keep the last 1024 bytes of the error message
 func (s *stdioHolder) WriteError(msg string) {
 func (s *stdioHolder) WriteError(msg string) {
 	const MAX_ERR_MSG_LEN = 1024
 	const MAX_ERR_MSG_LEN = 1024
 	reduce := len(msg) + len(s.err_message) - MAX_ERR_MSG_LEN
 	reduce := len(msg) + len(s.err_message) - MAX_ERR_MSG_LEN
@@ -118,6 +129,8 @@ func (s *stdioHolder) WriteError(msg string) {
 	s.last_err_message_updated_at = time.Now()
 	s.last_err_message_updated_at = time.Now()
 }
 }
 
 
+// StartStderr starts to read the stderr of the plugin
+// it will write the error message to the stdio holder
 func (s *stdioHolder) StartStderr() {
 func (s *stdioHolder) StartStderr() {
 	for {
 	for {
 		buf := make([]byte, 1024)
 		buf := make([]byte, 1024)
@@ -135,32 +148,35 @@ func (s *stdioHolder) StartStderr() {
 	}
 	}
 }
 }
 
 
+// Wait waits for the plugin to exit
+// it will return an error if the plugin is not active
+// you can also call `Stop()` to stop the waiting process
 func (s *stdioHolder) Wait() error {
 func (s *stdioHolder) Wait() error {
-	s.health_chan_lock.Lock()
-	if s.health_chan_closed {
-		s.health_chan_lock.Unlock()
+	s.wait_controller_chan_lock.Lock()
+	if s.waiting_controller_chan_closed {
+		s.wait_controller_chan_lock.Unlock()
 		return errors.New("you need to start the health check before waiting")
 		return errors.New("you need to start the health check before waiting")
 	}
 	}
-	s.health_chan_lock.Unlock()
+	s.wait_controller_chan_lock.Unlock()
 
 
 	ticker := time.NewTicker(5 * time.Second)
 	ticker := time.NewTicker(5 * time.Second)
 	defer ticker.Stop()
 	defer ticker.Stop()
 
 
 	// check status of plugin every 5 seconds
 	// check status of plugin every 5 seconds
 	for {
 	for {
-		s.health_chan_lock.Lock()
-		if s.health_chan_closed {
-			s.health_chan_lock.Unlock()
+		s.wait_controller_chan_lock.Lock()
+		if s.waiting_controller_chan_closed {
+			s.wait_controller_chan_lock.Unlock()
 			break
 			break
 		}
 		}
-		s.health_chan_lock.Unlock()
+		s.wait_controller_chan_lock.Unlock()
 		select {
 		select {
 		case <-ticker.C:
 		case <-ticker.C:
 			// check heartbeat
 			// check heartbeat
 			if time.Since(s.last_active_at) > 60*time.Second {
 			if time.Since(s.last_active_at) > 60*time.Second {
 				return plugin_errors.ErrPluginNotActive
 				return plugin_errors.ErrPluginNotActive
 			}
 			}
-		case <-s.health_chan:
+		case <-s.waiting_controller_chan:
 			// closed
 			// closed
 			return s.Error()
 			return s.Error()
 		}
 		}
@@ -169,6 +185,7 @@ func (s *stdioHolder) Wait() error {
 	return nil
 	return nil
 }
 }
 
 
+// GetID returns the id of the stdio holder
 func (s *stdioHolder) GetID() string {
 func (s *stdioHolder) GetID() string {
 	return s.id
 	return s.id
 }
 }

+ 8 - 8
internal/core/plugin_manager/local_manager/stdio_store.go

@@ -7,7 +7,7 @@ import (
 	"github.com/google/uuid"
 	"github.com/google/uuid"
 )
 )
 
 
-func PutStdioIo(
+func registerStdioHandler(
 	plugin_unique_identifier string, writer io.WriteCloser,
 	plugin_unique_identifier string, writer io.WriteCloser,
 	reader io.ReadCloser, err_reader io.ReadCloser,
 	reader io.ReadCloser, err_reader io.ReadCloser,
 ) *stdioHolder {
 ) *stdioHolder {
@@ -21,15 +21,15 @@ func PutStdioIo(
 		id:                       id,
 		id:                       id,
 		l:                        &sync.Mutex{},
 		l:                        &sync.Mutex{},
 
 
-		health_chan_lock: &sync.Mutex{},
-		health_chan:      make(chan bool),
+		wait_controller_chan_lock: &sync.Mutex{},
+		waiting_controller_chan:   make(chan bool),
 	}
 	}
 
 
 	stdio_holder.Store(id, holder)
 	stdio_holder.Store(id, holder)
 	return holder
 	return holder
 }
 }
 
 
-func Get(id string) *stdioHolder {
+func getStdioHandler(id string) *stdioHolder {
 	if v, ok := stdio_holder.Load(id); ok {
 	if v, ok := stdio_holder.Load(id); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 			return holder
 			return holder
@@ -39,11 +39,11 @@ func Get(id string) *stdioHolder {
 	return nil
 	return nil
 }
 }
 
 
-func RemoveStdio(id string) {
+func removeStdioHandler(id string) {
 	stdio_holder.Delete(id)
 	stdio_holder.Delete(id)
 }
 }
 
 
-func OnStdioEvent(id string, session_id string, listener func([]byte)) {
+func setupStdioEventListener(id string, session_id string, listener func([]byte)) {
 	if v, ok := stdio_holder.Load(id); ok {
 	if v, ok := stdio_holder.Load(id); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 			holder.l.Lock()
 			holder.l.Lock()
@@ -71,7 +71,7 @@ func OnError(id string, session_id string, listener func([]byte)) {
 	}
 	}
 }
 }
 
 
-func RemoveStdioListener(id string, listener string) {
+func removeStdioHandlerListener(id string, listener string) {
 	if v, ok := stdio_holder.Load(id); ok {
 	if v, ok := stdio_holder.Load(id); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 			holder.l.Lock()
 			holder.l.Lock()
@@ -88,7 +88,7 @@ func OnGlobalEvent(listener func(string, []byte)) {
 	listeners[uuid.New().String()] = listener
 	listeners[uuid.New().String()] = listener
 }
 }
 
 
-func WriteToStdio(id string, data []byte) error {
+func writeToStdioHandler(id string, data []byte) error {
 	if v, ok := stdio_holder.Load(id); ok {
 	if v, ok := stdio_holder.Load(id); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 		if holder, ok := v.(*stdioHolder); ok {
 			_, err := holder.writer.Write(data)
 			_, err := holder.writer.Write(data)

+ 9 - 3
internal/core/plugin_manager/manager.go

@@ -212,7 +212,9 @@ func (p *PluginManager) SavePackage(plugin_unique_identifier plugin_entities.Plu
 	return &declaration, nil
 	return &declaration, nil
 }
 }
 
 
-func (p *PluginManager) GetPackage(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier) ([]byte, error) {
+func (p *PluginManager) GetPackage(
+	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+) ([]byte, error) {
 	file, err := os.ReadFile(filepath.Join(p.packageCachePath, plugin_unique_identifier.String()))
 	file, err := os.ReadFile(filepath.Join(p.packageCachePath, plugin_unique_identifier.String()))
 
 
 	if err != nil {
 	if err != nil {
@@ -225,11 +227,15 @@ func (p *PluginManager) GetPackage(plugin_unique_identifier plugin_entities.Plug
 	return file, nil
 	return file, nil
 }
 }
 
 
-func (p *PluginManager) GetPackagePath(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier) (string, error) {
+func (p *PluginManager) GetPackagePath(
+	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+) (string, error) {
 	return filepath.Join(p.packageCachePath, plugin_unique_identifier.String()), nil
 	return filepath.Join(p.packageCachePath, plugin_unique_identifier.String()), nil
 }
 }
 
 
-func (p *PluginManager) GetDeclaration(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier) (
+func (p *PluginManager) GetDeclaration(
+	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+) (
 	*plugin_entities.PluginDeclaration, error,
 	*plugin_entities.PluginDeclaration, error,
 ) {
 ) {
 	return helper.CombinedGetPluginDeclaration(plugin_unique_identifier)
 	return helper.CombinedGetPluginDeclaration(plugin_unique_identifier)

+ 25 - 0
internal/core/plugin_manager/uninstall.go

@@ -0,0 +1,25 @@
+package plugin_manager
+
+import (
+	"os"
+	"path/filepath"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+)
+
+// UninstallFromLocal uninstalls a plugin from local storage
+// once deleted, local runtime will automatically shutdown and exit after several time
+func (p *PluginManager) UninstallFromLocal(identity plugin_entities.PluginUniqueIdentifier) error {
+	plugin_installation_path := filepath.Join(p.pluginStoragePath, identity.String())
+	if err := os.RemoveAll(plugin_installation_path); err != nil {
+		return err
+	}
+	// send shutdown runtime
+	runtime, ok := p.m.Load(identity.String())
+	if !ok {
+		// no runtime to shutdown, already uninstalled
+		return nil
+	}
+	runtime.Shutdown()
+	return nil
+}

+ 6 - 0
internal/core/plugin_manager/watcher.go

@@ -28,6 +28,7 @@ func (p *PluginManager) startLocalWatcher() {
 		p.handleNewLocalPlugins()
 		p.handleNewLocalPlugins()
 		for range time.NewTicker(time.Second * 30).C {
 		for range time.NewTicker(time.Second * 30).C {
 			p.handleNewLocalPlugins()
 			p.handleNewLocalPlugins()
+			p.removeUninstalledLocalPlugins()
 		}
 		}
 	}()
 	}()
 }
 }
@@ -92,6 +93,11 @@ func (p *PluginManager) handleNewLocalPlugins() {
 	}
 	}
 }
 }
 
 
+// an async function to remove uninstalled local plugins
+func (p *PluginManager) removeUninstalledLocalPlugins() {
+	// TODO: implement
+}
+
 func (p *PluginManager) launchLocal(plugin_package_path string) (
 func (p *PluginManager) launchLocal(plugin_package_path string) (
 	plugin_entities.PluginFullDuplexLifetime, <-chan error, error,
 	plugin_entities.PluginFullDuplexLifetime, <-chan error, error,
 ) {
 ) {

+ 14 - 1
internal/service/install_plugin.go

@@ -480,7 +480,7 @@ func UninstallPlugin(
 	}
 	}
 
 
 	// Uninstall the plugin
 	// Uninstall the plugin
-	_, err = curd.UninstallPlugin(
+	delete_response, err := curd.UninstallPlugin(
 		tenant_id,
 		tenant_id,
 		plugin_unique_identifier,
 		plugin_unique_identifier,
 		installation.ID,
 		installation.ID,
@@ -489,5 +489,18 @@ func UninstallPlugin(
 		return entities.NewErrorResponse(-500, fmt.Sprintf("Failed to uninstall plugin: %s", err.Error()))
 		return entities.NewErrorResponse(-500, fmt.Sprintf("Failed to uninstall plugin: %s", err.Error()))
 	}
 	}
 
 
+	if delete_response.IsPluginDeleted {
+		// delete the plugin if no installation left
+		manager := plugin_manager.Manager()
+		if delete_response.Installation.RuntimeType == string(
+			plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL,
+		) {
+			err = manager.UninstallFromLocal(plugin_unique_identifier)
+			if err != nil {
+				return entities.NewErrorResponse(-500, fmt.Sprintf("Failed to uninstall plugin: %s", err.Error()))
+			}
+		}
+	}
+
 	return entities.NewSuccessResponse(true)
 	return entities.NewSuccessResponse(true)
 }
 }

+ 5 - 2
internal/types/models/curd/atomic.go

@@ -142,8 +142,11 @@ func InstallPlugin(
 }
 }
 
 
 type DeletePluginResponse struct {
 type DeletePluginResponse struct {
-	Plugin          *models.Plugin
-	Installation    *models.PluginInstallation
+	Plugin       *models.Plugin
+	Installation *models.PluginInstallation
+
+	// whether the refers of the plugin has been decreased to 0
+	// which means the whole plugin has been uninstalled, not just the installation
 	IsPluginDeleted bool
 	IsPluginDeleted bool
 }
 }