Forráskód Böngészése

refactor: install plugin to locals

Yeuoly 9 hónapja%!(EXTRA string=óta)
szülő
commit
79b6e0450c

+ 1 - 1
cmd/commandline/plugin.go

@@ -220,7 +220,7 @@ endpoint				- allow plugin to register endpoint`,
 			config.SetDefault()
 
 			// init plugin manager
-			plugin_manager := plugin_manager.NewManager(config)
+			plugin_manager := plugin_manager.InitGlobalManager(config)
 
 			// start to schedule plugin subprocesses
 			process.Init(config)

+ 1 - 0
internal/core/plugin_manager/errors.go

@@ -0,0 +1 @@
+package plugin_manager

+ 6 - 2
internal/core/plugin_manager/install_to_local.go

@@ -26,6 +26,10 @@ func (p *PluginManager) InstallToLocal(
 	}
 	defer plugin_file.Close()
 	installed_file_path := filepath.Join(p.pluginStoragePath, plugin_unique_identifier.String())
+	dir_path := filepath.Dir(installed_file_path)
+	if err := os.MkdirAll(dir_path, 0755); err != nil {
+		return nil, err
+	}
 	installed_file, err := os.Create(installed_file_path)
 	if err != nil {
 		return nil, err
@@ -36,7 +40,7 @@ func (p *PluginManager) InstallToLocal(
 		return nil, err
 	}
 
-	runtime, err := p.launchLocal(installed_file_path)
+	runtime, launched_chan, err := p.launchLocal(installed_file_path)
 	if err != nil {
 		return nil, err
 	}
@@ -66,7 +70,7 @@ func (p *PluginManager) InstallToLocal(
 				})
 				runtime.Stop()
 				return
-			case err := <-runtime.WaitLaunched():
+			case <-launched_chan:
 				// launched
 				if err != nil {
 					response.Write(PluginInstallResponse{

+ 9 - 23
internal/core/plugin_manager/lifetime.go

@@ -11,16 +11,7 @@ func (p *PluginManager) AddPluginRegisterHandler(handler func(r plugin_entities.
 	p.pluginRegisters = append(p.pluginRegisters, handler)
 }
 
-func (p *PluginManager) fullDuplexLifetime(r plugin_entities.PluginFullDuplexLifetime) {
-	identifier, err := r.Identity()
-	if err != nil {
-		log.Error("get plugin identity failed: %s", err.Error())
-		return
-	}
-
-	p.m.Store(identifier.String(), r)
-	defer p.m.Delete(identifier.String())
-
+func (p *PluginManager) fullDuplexLifetime(r plugin_entities.PluginFullDuplexLifetime, launched_chan chan error) {
 	configuration := r.Configuration()
 
 	log.Info("new plugin logged in: %s", configuration.Identity())
@@ -41,29 +32,24 @@ func (p *PluginManager) fullDuplexLifetime(r plugin_entities.PluginFullDuplexLif
 		}
 	}
 
-	start_failed_times := 0
-
 	// remove lifetime state after plugin if it has been stopped
 	defer r.TriggerStop()
 
-	// try at most 3 times to init environment
-	for i := 0; i < 3; i++ {
+	// try to init environment until succeed
+	for {
+		log.Info("init environment for plugin %s", configuration.Identity())
 		if err := r.InitEnvironment(); err != nil {
 			log.Error("init environment failed: %s, retry in 30s", err.Error())
-			if start_failed_times == 3 {
-				log.Error(
-					"init environment failed 3 times, plugin %s has been stopped",
-					configuration.Identity(),
-				)
-				return
-			}
 			time.Sleep(30 * time.Second)
-			start_failed_times++
 			continue
 		}
+		break
 	}
 
-	// TODO: launched will only be triggered when calling StartPlugin
+	// notify launched
+	if launched_chan != nil {
+		close(launched_chan)
+	}
 
 	// init environment successfully
 	// once succeed, we consider the plugin is installed successfully

+ 0 - 4
internal/core/plugin_manager/local_manager/environment.go

@@ -10,10 +10,6 @@ import (
 )
 
 func (r *LocalPluginRuntime) InitEnvironment() error {
-	if _, err := os.Stat(path.Join(r.State.WorkingPath, ".installed")); err == nil {
-		return nil
-	}
-
 	var err error
 	if r.Config.Meta.Runner.Language == constants.Python {
 		err = r.InitPythonEnvironment()

+ 6 - 1
internal/core/plugin_manager/local_manager/environment_python.go

@@ -17,10 +17,15 @@ import (
 )
 
 func (p *LocalPluginRuntime) InitPythonEnvironment() error {
+	// check if virtual environment exists
+	if _, err := os.Stat(path.Join(p.State.WorkingPath, ".venv")); err == nil {
+		return nil
+	}
+
 	// execute init command, create a virtual environment
 	success := false
 
-	cmd := exec.Command("bash", "-c", "python3 -m venv .venv")
+	cmd := exec.Command("bash", "-c", fmt.Sprintf("%s -m venv .venv", p.default_python_interpreter_path))
 	cmd.Dir = p.State.WorkingPath
 	b := bytes.NewBuffer(nil)
 	cmd.Stdout = b

+ 8 - 58
internal/core/plugin_manager/local_manager/run.go

@@ -27,17 +27,6 @@ func (r *LocalPluginRuntime) gc() {
 	}
 }
 
-// init initializes the LocalPluginRuntime
-func (r *LocalPluginRuntime) init() {
-	// reset wait chan
-	r.wait_chan = make(chan bool)
-	// reset wait launched chan
-	r.wait_launched_chan_once = sync.Once{}
-	r.wait_launched_chan = make(chan error)
-
-	r.SetLaunching()
-}
-
 // Type returns the runtime type of the plugin
 func (r *LocalPluginRuntime) Type() plugin_entities.PluginRuntimeType {
 	return plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL
@@ -68,18 +57,13 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 		r.wait_chan_lock.Unlock()
 	}()
 
-	r.init()
+	// reset wait chan
+	r.wait_chan = make(chan bool)
+	// reset wait launched chan
 
 	// start plugin
 	e, err := r.getCmd()
 	if err != nil {
-		r.wait_launched_chan_once.Do(func() {
-			select {
-			case r.wait_launched_chan <- err:
-			default:
-			}
-			close(r.wait_launched_chan)
-		})
 		return err
 	}
 
@@ -91,24 +75,11 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 	// ensure all subprocess are killed when parent process exits, especially on Golang debugger
 	process.WrapProcess(e)
 
-	// notify launched, notify error if any
-	notify_launched := func(err error) {
-		r.wait_launched_chan_once.Do(func() {
-			select {
-			case r.wait_launched_chan <- err:
-			default:
-			}
-			close(r.wait_launched_chan)
-		})
-	}
-
 	// get writer
 	stdin, err := e.StdinPipe()
 	if err != nil {
 		r.SetRestarting()
-		err = fmt.Errorf("get stdin pipe failed: %s", err.Error())
-		notify_launched(err)
-		return err
+		return fmt.Errorf("get stdin pipe failed: %s", err.Error())
 	}
 	defer stdin.Close()
 
@@ -116,9 +87,7 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 	stdout, err := e.StdoutPipe()
 	if err != nil {
 		r.SetRestarting()
-		err = fmt.Errorf("get stdout pipe failed: %s", err.Error())
-		notify_launched(err)
-		return err
+		return fmt.Errorf("get stdout pipe failed: %s", err.Error())
 	}
 	defer stdout.Close()
 
@@ -126,17 +95,13 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 	stderr, err := e.StderrPipe()
 	if err != nil {
 		r.SetRestarting()
-		err = fmt.Errorf("get stderr pipe failed: %s", err.Error())
-		notify_launched(err)
-		return err
+		return fmt.Errorf("get stderr pipe failed: %s", err.Error())
 	}
 	defer stderr.Close()
 
 	if err := e.Start(); err != nil {
 		r.SetRestarting()
-		err = fmt.Errorf("start plugin failed: %s", err.Error())
-		notify_launched(err)
-		return err
+		return fmt.Errorf("start plugin failed: %s", err.Error())
 	}
 
 	// add to subprocess manager
@@ -151,11 +116,6 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 			log.Error("plugin %s exited with error: %s", r.Config.Identity(), err.Error())
 		}
 
-		// close wait launched chan
-		r.wait_launched_chan_once.Do(func() {
-			close(r.wait_launched_chan)
-		})
-
 		r.gc()
 	}()
 	defer e.Process.Kill()
@@ -173,12 +133,7 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 	// listen to plugin stdout
 	routine.Submit(func() {
 		defer wg.Done()
-		stdio.StartStdout(func() {
-			// get heartbeat, notify launched
-			r.wait_launched_chan_once.Do(func() {
-				close(r.wait_launched_chan)
-			})
-		})
+		stdio.StartStdout(func() {})
 	})
 
 	// listen to plugin stderr
@@ -234,8 +189,3 @@ func (r *LocalPluginRuntime) WaitStopped() <-chan bool {
 	r.wait_chan_lock.Unlock()
 	return c
 }
-
-// WaitLaunched returns a channel that will receive an error if the plugin fails to launch
-func (r *LocalPluginRuntime) WaitLaunched() <-chan error {
-	return r.wait_launched_chan
-}

+ 9 - 8
internal/core/plugin_manager/local_manager/type.go

@@ -15,17 +15,18 @@ type LocalPluginRuntime struct {
 	io_identity string
 
 	// python interpreter path, currently only support python
-	python_interpreter_path string
+	python_interpreter_path         string
+	default_python_interpreter_path string
 
-	wait_chan_lock          sync.Mutex
-	wait_started_chan       []chan bool
-	wait_stopped_chan       []chan bool
-	wait_launched_chan      chan error
-	wait_launched_chan_once sync.Once
+	wait_chan_lock    sync.Mutex
+	wait_started_chan []chan bool
+	wait_stopped_chan []chan bool
 }
 
-func NewLocalPluginRuntime() *LocalPluginRuntime {
+func NewLocalPluginRuntime(
+	python_interpreter_path string,
+) *LocalPluginRuntime {
 	return &LocalPluginRuntime{
-		wait_launched_chan: make(chan error),
+		default_python_interpreter_path: python_interpreter_path,
 	}
 }

+ 5 - 1
internal/core/plugin_manager/manager.go

@@ -48,13 +48,16 @@ type PluginManager struct {
 
 	// backwardsInvocation is a handle to invoke dify
 	backwardsInvocation dify_invocation.BackwardsInvocation
+
+	// python interpreter path
+	pythonInterpreterPath string
 }
 
 var (
 	manager *PluginManager
 )
 
-func NewManager(configuration *app.Config) *PluginManager {
+func InitGlobalManager(configuration *app.Config) *PluginManager {
 	manager = &PluginManager{
 		maxPluginPackageSize: configuration.MaxPluginPackageSize,
 		packageCachePath:     configuration.PluginPackageCachePath,
@@ -65,6 +68,7 @@ func NewManager(configuration *app.Config) *PluginManager {
 			configuration.PluginMediaCacheSize,
 		),
 		localPluginLaunchingLock: lock.NewGranularityLock(),
+		pythonInterpreterPath:    configuration.PythonInterpreterPath,
 	}
 
 	// mkdir

+ 0 - 4
internal/core/plugin_manager/remote_manager/run.go

@@ -88,7 +88,3 @@ func (r *RemotePluginRuntime) Wait() (<-chan bool, error) {
 func (r *RemotePluginRuntime) Checksum() (string, error) {
 	return r.checksum, nil
 }
-
-func (r *RemotePluginRuntime) WaitLaunched() <-chan error {
-	return r.wait_launched_chan
-}

+ 2 - 2
internal/core/plugin_manager/tester.go

@@ -35,7 +35,7 @@ func (p *PluginManager) TestPlugin(
 		return nil, errors.Join(err, errors.New("failed to get assets"))
 	}
 
-	local_plugin_runtime := local_manager.NewLocalPluginRuntime()
+	local_plugin_runtime := local_manager.NewLocalPluginRuntime(p.pythonInterpreterPath)
 	local_plugin_runtime.PluginRuntime = plugin.runtime
 	local_plugin_runtime.PositivePluginRuntime = positive_manager.PositivePluginRuntime{
 		BasicPluginRuntime: basic_manager.NewBasicPluginRuntime(p.mediaManager),
@@ -66,7 +66,7 @@ func (p *PluginManager) TestPlugin(
 			}
 		}()
 		// delete the plugin from the storage when the plugin is stopped
-		p.fullDuplexLifetime(local_plugin_runtime)
+		p.fullDuplexLifetime(local_plugin_runtime, nil)
 	})
 
 	// wait for the plugin to start

+ 50 - 40
internal/core/plugin_manager/watcher.go

@@ -4,8 +4,10 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"io/fs"
 	"os"
 	"path"
+	"path/filepath"
 	"strings"
 	"time"
 
@@ -22,6 +24,9 @@ import (
 
 func (p *PluginManager) startLocalWatcher() {
 	go func() {
+		// delete all plugins in working directory
+		os.RemoveAll(p.workingDirectory)
+
 		log.Info("start to handle new plugins in path: %s", p.pluginStoragePath)
 		p.handleNewLocalPlugins()
 		for range time.NewTicker(time.Second * 30).C {
@@ -48,7 +53,7 @@ func (p *PluginManager) startRemoteWatcher(config *app.Config) {
 							log.Error("plugin runtime error: %v", err)
 						}
 					}()
-					p.fullDuplexLifetime(rpr)
+					p.fullDuplexLifetime(rpr, nil)
 				})
 			})
 		}()
@@ -56,32 +61,37 @@ func (p *PluginManager) startRemoteWatcher(config *app.Config) {
 }
 
 func (p *PluginManager) handleNewLocalPlugins() {
-	// load local plugins firstly
-	plugins, err := os.ReadDir(p.pluginStoragePath)
-	if err != nil {
-		log.Error("no plugin found in path: %s", p.pluginStoragePath)
-	}
-
-	for _, plugin := range plugins {
-		if !plugin.IsDir() {
-			abs_path := path.Join(p.pluginStoragePath, plugin.Name())
-			_, err := p.launchLocal(abs_path)
+	// walk through all plugins
+	err := filepath.WalkDir(p.pluginStoragePath, func(path string, d fs.DirEntry, err error) error {
+		if err != nil {
+			return err
+		}
+		if !d.IsDir() {
+			_, _, err := p.launchLocal(path)
 			if err != nil {
 				log.Error("launch local plugin failed: %s", err.Error())
 			}
 		}
+
+		return nil
+	})
+
+	if err != nil {
+		log.Error("walk through plugins failed: %s", err.Error())
 	}
 }
 
-func (p *PluginManager) launchLocal(plugin_package_path string) (plugin_entities.PluginFullDuplexLifetime, error) {
+func (p *PluginManager) launchLocal(plugin_package_path string) (
+	plugin_entities.PluginFullDuplexLifetime, <-chan error, error,
+) {
 	plugin, err := p.getLocalPluginRuntime(plugin_package_path)
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 
 	identity, err := plugin.decoder.UniqueIdentity()
 	if err != nil {
-		return nil, err
+		return nil, nil, err
 	}
 
 	// lock launch process
@@ -89,22 +99,27 @@ func (p *PluginManager) launchLocal(plugin_package_path string) (plugin_entities
 	defer p.localPluginLaunchingLock.Unlock(identity.String())
 
 	// check if the plugin is already running
-	if _, ok := p.m.Load(identity.String()); ok {
-		lifetime, ok := p.Get(identity).(plugin_entities.PluginFullDuplexLifetime)
+	if lifetime, ok := p.m.Load(identity.String()); ok {
+		lifetime, ok := lifetime.(plugin_entities.PluginFullDuplexLifetime)
 		if !ok {
-			return nil, fmt.Errorf("plugin runtime not found")
+			return nil, nil, fmt.Errorf("plugin runtime not found")
 		}
-		return lifetime, nil
+
+		// returns a closed channel to indicate the plugin is already running, no more waiting is needed
+		c := make(chan error)
+		close(c)
+
+		return lifetime, c, nil
 	}
 
 	// extract plugin
 	decoder, ok := plugin.decoder.(*decoder.ZipPluginDecoder)
 	if !ok {
-		return nil, fmt.Errorf("plugin decoder is not a zip decoder")
+		return nil, nil, fmt.Errorf("plugin decoder is not a zip decoder")
 	}
 
 	if err := decoder.ExtractTo(plugin.runtime.State.WorkingPath); err != nil {
-		return nil, errors.Join(err, fmt.Errorf("extract plugin to working directory error"))
+		return nil, nil, errors.Join(err, fmt.Errorf("extract plugin to working directory error"))
 	}
 
 	success := false
@@ -118,10 +133,10 @@ func (p *PluginManager) launchLocal(plugin_package_path string) (plugin_entities
 	// get assets
 	assets, err := plugin.decoder.Assets()
 	if err != nil {
-		return nil, failed(err.Error())
+		return nil, nil, failed(err.Error())
 	}
 
-	local_plugin_runtime := local_manager.NewLocalPluginRuntime()
+	local_plugin_runtime := local_manager.NewLocalPluginRuntime(p.pythonInterpreterPath)
 	local_plugin_runtime.PluginRuntime = plugin.runtime
 	local_plugin_runtime.PositivePluginRuntime = positive_manager.PositivePluginRuntime{
 		BasicPluginRuntime: basic_manager.NewBasicPluginRuntime(p.mediaManager),
@@ -134,22 +149,27 @@ func (p *PluginManager) launchLocal(plugin_package_path string) (plugin_entities
 		&local_plugin_runtime.Config,
 		assets,
 	); err != nil {
-		return nil, failed(errors.Join(err, fmt.Errorf("remap plugin assets error")).Error())
+		return nil, nil, failed(errors.Join(err, fmt.Errorf("remap plugin assets error")).Error())
 	}
 
 	success = true
 
+	p.m.Store(identity.String(), local_plugin_runtime)
+
+	launched_chan := make(chan error)
+
 	// local plugin
 	routine.Submit(func() {
 		defer func() {
 			if r := recover(); r != nil {
 				log.Error("plugin runtime panic: %v", r)
 			}
+			p.m.Delete(identity.String())
 		}()
-		p.fullDuplexLifetime(local_plugin_runtime)
+		p.fullDuplexLifetime(local_plugin_runtime, launched_chan)
 	})
 
-	return local_plugin_runtime, nil
+	return local_plugin_runtime, launched_chan, nil
 }
 
 type pluginRuntimeWithDecoder struct {
@@ -158,7 +178,10 @@ type pluginRuntimeWithDecoder struct {
 }
 
 // extract plugin from package to working directory
-func (p *PluginManager) getLocalPluginRuntime(plugin_path string) (*pluginRuntimeWithDecoder, error) {
+func (p *PluginManager) getLocalPluginRuntime(plugin_path string) (
+	*pluginRuntimeWithDecoder,
+	error,
+) {
 	pack, err := os.Open(plugin_path)
 	if err != nil {
 		return nil, errors.Join(err, fmt.Errorf("open plugin package error"))
@@ -169,7 +192,7 @@ func (p *PluginManager) getLocalPluginRuntime(plugin_path string) (*pluginRuntim
 		return nil, errors.Join(err, fmt.Errorf("get plugin package info error"))
 	} else if info.Size() > p.maxPluginPackageSize {
 		log.Error("plugin package size is too large: %d", info.Size())
-		return nil, err
+		return nil, errors.Join(err, fmt.Errorf("plugin package size is too large"))
 	}
 
 	plugin_zip, err := io.ReadAll(pack)
@@ -188,27 +211,14 @@ func (p *PluginManager) getLocalPluginRuntime(plugin_path string) (*pluginRuntim
 		return nil, errors.Join(err, fmt.Errorf("get plugin manifest error"))
 	}
 
-	// check if already exists
-	if _, exist := p.m.Load(manifest.Identity()); exist {
-		return nil, errors.Join(fmt.Errorf("plugin already exists: %s", manifest.Identity()), err)
-	}
-
 	checksum, err := decoder.Checksum()
 	if err != nil {
 		return nil, errors.Join(err, fmt.Errorf("calculate checksum error"))
 	}
 
 	identity := manifest.Identity()
-	// replace : with -
 	identity = strings.ReplaceAll(identity, ":", "-")
-
 	plugin_working_path := path.Join(p.workingDirectory, fmt.Sprintf("%s@%s", identity, checksum))
-
-	// check if working directory exists
-	if _, err := os.Stat(plugin_working_path); err == nil {
-		return nil, errors.Join(fmt.Errorf("plugin working directory already exists: %s", plugin_working_path), err)
-	}
-
 	return &pluginRuntimeWithDecoder{
 		runtime: plugin_entities.PluginRuntime{
 			Config: manifest,

+ 1 - 1
internal/server/server.go

@@ -21,7 +21,7 @@ func (app *App) Run(config *app.Config) {
 	process.Init(config)
 
 	// create manager
-	manager := plugin_manager.NewManager(config)
+	manager := plugin_manager.InitGlobalManager(config)
 
 	// create cluster
 	app.cluster = cluster.NewCluster(config, manager)

+ 2 - 0
internal/types/app/config.go

@@ -77,6 +77,8 @@ type Config struct {
 	MaxPluginPackageSize int64 `envconfig:"MAX_PLUGIN_PACKAGE_SIZE" validate:"required"`
 
 	MaxAWSLambdaTransactionTimeout int `envconfig:"MAX_AWS_LAMBDA_TRANSACTION_TIMEOUT"`
+
+	PythonInterpreterPath string `envconfig:"PYTHON_INTERPRETER_PATH"`
 }
 
 func (c *Config) Validate() error {

+ 1 - 0
internal/types/app/default.go

@@ -24,6 +24,7 @@ func (config *Config) SetDefault() {
 	setDefaultString(&config.PersistenceStorageLocalPath, "./storage/persistence")
 	setDefaultString(&config.ProcessCachingPath, "./storage/subprocesses")
 	setDefaultString(&config.PluginPackageCachePath, "./storage/plugin_packages")
+	setDefaultString(&config.PythonInterpreterPath, "/usr/bin/python3")
 }
 
 func setDefaultInt[T constraints.Integer](value *T, defaultValue T) {

+ 0 - 2
internal/types/entities/plugin_entities/runtime.go

@@ -54,8 +54,6 @@ type (
 		WaitStarted() <-chan bool
 		// Stopped
 		WaitStopped() <-chan bool
-		// Launched
-		WaitLaunched() <-chan error
 	}
 
 	PluginServerlessLifetime interface {