Pārlūkot izejas kodu

refactor: local plugin cwd

Yeuoly 11 mēneši atpakaļ
vecāks
revīzija
277ce2081f

+ 8 - 0
internal/core/plugin_manager/manager.go

@@ -9,7 +9,9 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/lock"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/mapping"
 )
 
 type PluginManager struct {
@@ -19,6 +21,11 @@ type PluginManager struct {
 
 	maxPluginPackageSize int64
 	workingDirectory     string
+
+	// running plugin in storage contains relations between plugin packages and their running instances
+	runningPluginInStorage mapping.Map[string, string]
+	// start process lock
+	startProcessLock *lock.HighGranularityLock
 }
 
 var (
@@ -30,6 +37,7 @@ func InitGlobalPluginManager(cluster *cluster.Cluster, configuration *app.Config
 		cluster:              cluster,
 		maxPluginPackageSize: configuration.MaxPluginPackageSize,
 		workingDirectory:     configuration.PluginWorkingPath,
+		startProcessLock:     lock.NewHighGranularityLock(),
 	}
 	manager.Init(configuration)
 }

+ 0 - 1
internal/core/plugin_manager/remote_manager/server.go

@@ -60,7 +60,6 @@ func (r *RemotePluginServer) Stop() error {
 // Launch starts the server
 func (r *RemotePluginServer) Launch() error {
 	// kill the process if port is already in use
-	// TODO: switch to optional
 	exec.Command("fuser", "-k", "tcp", fmt.Sprintf("%d", r.server.port)).Run()
 
 	time.Sleep(time.Millisecond * 100)

+ 40 - 20
internal/core/plugin_manager/watcher.go

@@ -1,6 +1,7 @@
 package plugin_manager
 
 import (
+	"errors"
 	"fmt"
 	"io"
 	"os"
@@ -11,6 +12,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/local_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/positive_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/remote_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/checksum"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/verifier"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
@@ -75,7 +77,23 @@ func (p *PluginManager) handleNewPlugins(config *app.Config) {
 			continue
 		}
 
+		identity, err := plugin_interface.Identity()
+		if err != nil {
+			log.Error("get plugin identity error: %v", err)
+			continue
+		}
+
+		// store the plugin in the storage, avoid duplicate loading
+		p.runningPluginInStorage.Store(plugin.Runtime.State.AbsolutePath, identity.String())
+
 		routine.Submit(func() {
+			defer func() {
+				if r := recover(); r != nil {
+					log.Error("plugin runtime error: %v", r)
+				}
+			}()
+			// delete the plugin from the storage when the plugin is stopped
+			defer p.runningPluginInStorage.Delete(plugin.Runtime.State.AbsolutePath)
 			p.lifetime(plugin_interface)
 		})
 	}
@@ -100,16 +118,20 @@ func (p *PluginManager) loadNewPlugins(root_path string) <-chan *pluginRuntimeWi
 	routine.Submit(func() {
 		for _, plugin := range plugins {
 			if !plugin.IsDir() {
-				plugin, err := p.loadPlugin(path.Join(root_path, plugin.Name()))
+				abs_path := path.Join(root_path, plugin.Name())
+				if _, ok := p.runningPluginInStorage.Load(abs_path); ok {
+					// if the plugin is already running, skip it
+					continue
+				}
+
+				plugin, err := p.loadPlugin(abs_path)
 				if err != nil {
 					log.Error("load plugin error: %v", err)
 					continue
 				}
-
 				ch <- plugin
 			}
 		}
-
 		close(ch)
 	})
 
@@ -119,14 +141,12 @@ func (p *PluginManager) loadNewPlugins(root_path string) <-chan *pluginRuntimeWi
 func (p *PluginManager) loadPlugin(plugin_path string) (*pluginRuntimeWithDecoder, error) {
 	pack, err := os.Open(plugin_path)
 	if err != nil {
-		log.Error("open plugin package error: %v", err)
-		return nil, err
+		return nil, errors.Join(err, fmt.Errorf("open plugin package error"))
 	}
 	defer pack.Close()
 
 	if info, err := pack.Stat(); err != nil {
-		log.Error("get plugin package info error: %v", err)
-		return nil, err
+		return nil, errors.Join(err, fmt.Errorf("get plugin package info error"))
 	} else if info.Size() > p.maxPluginPackageSize {
 		log.Error("plugin package size is too large: %d", info.Size())
 		return nil, err
@@ -134,35 +154,36 @@ func (p *PluginManager) loadPlugin(plugin_path string) (*pluginRuntimeWithDecode
 
 	plugin_zip, err := io.ReadAll(pack)
 	if err != nil {
-		log.Error("read plugin package error: %v", err)
-		return nil, err
+		return nil, errors.Join(err, fmt.Errorf("read plugin package error"))
 	}
 
 	decoder, err := decoder.NewZipPluginDecoder(plugin_zip)
 	if err != nil {
-		log.Error("create plugin decoder error: %v", err)
-		return nil, err
+		return nil, errors.Join(err, fmt.Errorf("create plugin decoder error"))
 	}
 
 	// get manifest
 	manifest, err := decoder.Manifest()
 	if err != nil {
-		log.Error("get plugin manifest error: %v", err)
-		return nil, err
+		return nil, errors.Join(err, fmt.Errorf("get plugin manifest error"))
 	}
 
 	// check if already exists
 	if _, exist := p.m.Load(manifest.Identity()); exist {
-		log.Warn("plugin already exists: %s", manifest.Identity())
-		return nil, fmt.Errorf("plugin already exists: %s", manifest.Identity())
+		return nil, errors.Join(fmt.Errorf("plugin already exists: %s", manifest.Identity()), err)
 	}
 
-	plugin_working_path := path.Join(p.workingDirectory, manifest.Identity())
+	// TODO: use plugin unique id as the working directory
+	checksum, err := checksum.CalculateChecksum(decoder)
+	if err != nil {
+		return nil, errors.Join(err, fmt.Errorf("calculate checksum error"))
+	}
+
+	plugin_working_path := path.Join(p.workingDirectory, fmt.Sprintf("%s@%s", manifest.Identity(), checksum))
 
 	// check if working directory exists
 	if _, err := os.Stat(plugin_working_path); err == nil {
-		log.Warn("plugin working directory already exists: %s", plugin_working_path)
-		return nil, fmt.Errorf("plugin working directory already exists: %s", plugin_working_path)
+		return nil, errors.Join(fmt.Errorf("plugin working directory already exists: %s", plugin_working_path), err)
 	}
 
 	// copy to working directory
@@ -187,8 +208,7 @@ func (p *PluginManager) loadPlugin(plugin_path string) (*pluginRuntimeWithDecode
 
 		return nil
 	}); err != nil {
-		log.Error("copy plugin to working directory error: %v", err)
-		return nil, err
+		return nil, errors.Join(fmt.Errorf("copy plugin to working directory error: %v", err), err)
 	}
 
 	return &pluginRuntimeWithDecoder{

+ 1 - 1
internal/types/entities/plugin_entities/plugin_declaration.go

@@ -162,7 +162,7 @@ func isPluginName(fl validator.FieldLevel) bool {
 }
 
 func (p *PluginDeclaration) Identity() string {
-	return parser.MarshalPluginUniqueIdentifier(p.Name, p.Version)
+	return parser.MarshalPluginID(p.Name, p.Version)
 }
 
 func (p *PluginDeclaration) ManifestValidate() error {

+ 0 - 58
internal/utils/http_requests/http_warpper.go

@@ -131,64 +131,6 @@ func RequestAndParseStream[T any](client *http.Client, url string, method string
 	return ch, nil
 }
 
-// TODO: improve this, deduplicate code
-func RequestAndParseStreamMap(client *http.Client, url string, method string, options ...HttpOptions) (*stream.StreamResponse[map[string]any], error) {
-	resp, err := Request(client, url, method, options...)
-	if err != nil {
-		return nil, err
-	}
-
-	if resp.StatusCode != http.StatusOK {
-		defer resp.Body.Close()
-		error_text, _ := io.ReadAll(resp.Body)
-		return nil, fmt.Errorf("request failed with status code: %d and respond with: %s", resp.StatusCode, error_text)
-	}
-
-	ch := stream.NewStreamResponse[map[string]any](1024)
-
-	// get read timeout
-	read_timeout := int64(60000)
-	for _, option := range options {
-		if option.Type == "read_timeout" {
-			read_timeout = option.Value.(int64)
-			break
-		}
-	}
-	time.AfterFunc(time.Millisecond*time.Duration(read_timeout), func() {
-		// close the response body if timeout
-		resp.Body.Close()
-	})
-
-	routine.Submit(func() {
-		scanner := bufio.NewScanner(resp.Body)
-		defer resp.Body.Close()
-
-		for scanner.Scan() {
-			data := scanner.Bytes()
-			if len(data) == 0 {
-				continue
-			}
-
-			if bytes.HasPrefix(data, []byte("data: ")) {
-				// split
-				data = data[6:]
-			}
-
-			// unmarshal
-			t, err := parser.UnmarshalJsonBytes2Map(data)
-			if err != nil {
-				continue
-			}
-
-			ch.Write(t)
-		}
-
-		ch.Close()
-	})
-
-	return ch, nil
-}
-
 func GetAndParseStream[T any](client *http.Client, url string, options ...HttpOptions) (*stream.StreamResponse[T], error) {
 	return RequestAndParseStream[T](client, url, "GET", options...)
 }

+ 51 - 0
internal/utils/lock/lock.go

@@ -0,0 +1,51 @@
+package lock
+
+import (
+	"sync"
+	"sync/atomic"
+)
+
+type mutex struct {
+	*sync.Mutex
+	count int32
+}
+
+type HighGranularityLock struct {
+	m map[string]*mutex
+	l sync.Mutex
+}
+
+func NewHighGranularityLock() *HighGranularityLock {
+	return &HighGranularityLock{
+		m: make(map[string]*mutex),
+	}
+}
+
+func (l *HighGranularityLock) Lock(key string) {
+	l.l.Lock()
+	var m *mutex
+	var ok bool
+	if m, ok = l.m[key]; !ok {
+		m = &mutex{Mutex: &sync.Mutex{}, count: 1}
+		l.m[key] = m
+	} else {
+		atomic.AddInt32(&m.count, 1)
+	}
+	l.l.Unlock()
+
+	m.Lock()
+}
+
+func (l *HighGranularityLock) Unlock(key string) {
+	l.l.Lock()
+	m, ok := l.m[key]
+	if !ok {
+		return
+	}
+	atomic.AddInt32(&m.count, -1)
+	if atomic.LoadInt32(&m.count) == 0 {
+		delete(l.m, key)
+	}
+	l.l.Unlock()
+	m.Unlock()
+}

+ 40 - 0
internal/utils/lock/lock_test.go

@@ -0,0 +1,40 @@
+package lock
+
+import (
+	"fmt"
+	"sync"
+	"testing"
+)
+
+func TestHighGranularityLock(t *testing.T) {
+	l := NewHighGranularityLock()
+
+	data := []int{}
+	add := func(key int) {
+		l.Lock(fmt.Sprintf("%d", key))
+		data[key]++
+		l.Unlock(fmt.Sprintf("%d", key))
+	}
+
+	for i := 0; i < 1000; i++ {
+		data = append(data, 0)
+	}
+
+	wg := sync.WaitGroup{}
+	for i := 0; i < 1000; i++ {
+		wg.Add(1)
+		go func() {
+			for j := 0; j < 1000; j++ {
+				add(j)
+			}
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+
+	for _, v := range data {
+		if v != 1000 {
+			t.Fatal("data not equal")
+		}
+	}
+}

+ 1 - 1
internal/utils/parser/identity.go

@@ -2,6 +2,6 @@ package parser
 
 import "fmt"
 
-func MarshalPluginUniqueIdentifier(name string, version string) string {
+func MarshalPluginID(name string, version string) string {
 	return fmt.Sprintf("%s:%s", name, version)
 }