Quellcode durchsuchen

refactor: installing plugin

Yeuoly vor 10 Monaten
Ursprung
Commit
3de55a81ce

+ 12 - 0
internal/core/plugin_manager/manager.go

@@ -24,6 +24,7 @@ type PluginManager struct {
 
 	maxPluginPackageSize int64
 	workingDirectory     string
+	packageCachePath     string
 
 	// mediaManager is used to manage media files like plugin icons, images, etc.
 	mediaManager *media_manager.MediaManager
@@ -55,6 +56,7 @@ var (
 func NewManager(configuration *app.Config) *PluginManager {
 	manager = &PluginManager{
 		maxPluginPackageSize: configuration.MaxPluginPackageSize,
+		packageCachePath:     configuration.PluginPackageCachePath,
 		workingDirectory:     configuration.PluginWorkingPath,
 		mediaManager: media_manager.NewMediaManager(
 			configuration.PluginMediaCachePath,
@@ -74,6 +76,7 @@ func NewManager(configuration *app.Config) *PluginManager {
 	os.MkdirAll(configuration.PluginWorkingPath, 0755)
 	os.MkdirAll(configuration.PluginStoragePath, 0755)
 	os.MkdirAll(configuration.PluginMediaCachePath, 0755)
+	os.MkdirAll(configuration.PluginPackageCachePath, 0755)
 	os.MkdirAll(filepath.Dir(configuration.ProcessCachingPath), 0755)
 
 	return manager
@@ -146,3 +149,12 @@ func (p *PluginManager) Init(configuration *app.Config) {
 func (p *PluginManager) BackwardsInvocation() dify_invocation.BackwardsInvocation {
 	return p.backwardsInvocation
 }
+
+func (p *PluginManager) SavePackage(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier, pkg []byte) error {
+	// save to storage
+	return os.WriteFile(filepath.Join(p.packageCachePath, plugin_unique_identifier.String()), pkg, 0644)
+}
+
+func (p *PluginManager) GetPackage(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier) ([]byte, error) {
+	return os.ReadFile(filepath.Join(p.packageCachePath, plugin_unique_identifier.String()))
+}

+ 1 - 0
internal/db/init.go

@@ -84,6 +84,7 @@ func autoMigrate() error {
 		models.ServerlessRuntime{},
 		models.ToolInstallation{},
 		models.AIModelInstallation{},
+		models.InstallTask{},
 	)
 }
 

+ 1 - 1
internal/server/controllers/plugins.go

@@ -59,7 +59,7 @@ func InstallPluginFromIdentifiers(app *app.Config) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		BindRequest(c, func(request struct {
 			TenantID                string                                   `uri:"tenant_id" validate:"required"`
-			PluginUniqueIdentifiers []plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifiers" validate:"required,dive,plugin_unique_identifier"`
+			PluginUniqueIdentifiers []plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifiers" validate:"required,max=64,dive,plugin_unique_identifier"`
 			Source                  string                                   `json:"source" validate:"required"`
 			Meta                    map[string]any                           `json:"meta" validate:"omitempty"`
 		}) {

+ 207 - 36
internal/service/install_plugin.go

@@ -5,8 +5,10 @@ import (
 	"fmt"
 	"io"
 	"mime/multipart"
+	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/verifier"
 	"github.com/langgenius/dify-plugin-daemon/internal/db"
@@ -15,6 +17,9 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models/curd"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
+	"gorm.io/gorm"
 )
 
 func UploadPluginFromPkg(
@@ -57,14 +62,194 @@ func InstallPluginFromIdentifiers(
 	source string,
 	meta map[string]any,
 ) *entities.Response {
+	var response struct {
+		AllInstalled bool   `json:"all_installed"`
+		TaskID       string `json:"task_id"`
+	}
+
 	// TODO: create installation task and dispatch to workers
-	for _, plugin_unique_identifier := range plugin_unique_identifiers {
-		if err := InstallPluginFromIdentifier(tenant_id, plugin_unique_identifier, source, meta); err != nil {
+	plugins_wait_for_installation := []plugin_entities.PluginUniqueIdentifier{}
+
+	task := &models.InstallTask{
+		Status:           models.InstallTaskStatusRunning,
+		TotalPlugins:     len(plugins_wait_for_installation),
+		CompletedPlugins: 0,
+		Plugins:          []models.InstallTaskPluginStatus{},
+	}
+
+	for i, plugin_unique_identifier := range plugin_unique_identifiers {
+		// check if plugin is already installed
+		plugin, err := db.GetOne[models.Plugin](
+			db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
+		)
+
+		task.Plugins = append(task.Plugins, models.InstallTaskPluginStatus{
+			PluginUniqueIdentifier: plugin_unique_identifier,
+			PluginID:               plugin_unique_identifier.PluginID(),
+			Status:                 models.InstallTaskStatusPending,
+			Message:                "",
+		})
+
+		task.TotalPlugins++
+
+		if err == nil {
+			// already installed by other tenant
+			declaration := plugin.Declaration
+			if _, _, err := curd.InstallPlugin(
+				tenant_id,
+				plugin_unique_identifier,
+				plugin.InstallType,
+				&declaration,
+				source,
+				meta,
+			); err != nil {
+				return entities.NewErrorResponse(-500, err.Error())
+			}
+
+			task.CompletedPlugins++
+			task.Plugins[i].Status = models.InstallTaskStatusSuccess
+			task.Plugins[i].Message = "Installed"
+			continue
+		}
+
+		if err != db.ErrDatabaseNotFound {
 			return entities.NewErrorResponse(-500, err.Error())
 		}
+
+		plugins_wait_for_installation = append(plugins_wait_for_installation, plugin_unique_identifier)
 	}
 
-	return entities.NewSuccessResponse(true)
+	if len(plugins_wait_for_installation) == 0 {
+		response.AllInstalled = true
+		response.TaskID = ""
+		return entities.NewSuccessResponse(response)
+	}
+
+	err := db.Create(task)
+	if err != nil {
+		return entities.NewErrorResponse(-500, err.Error())
+	}
+
+	response.TaskID = task.ID
+
+	manager := plugin_manager.Manager()
+
+	tasks := []func(){}
+	for _, plugin_unique_identifier := range plugins_wait_for_installation {
+		tasks = append(tasks, func() {
+			updateTaskStatus := func(modifier func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus)) {
+				if err := db.WithTransaction(func(tx *gorm.DB) error {
+					task, err := db.GetOne[models.InstallTask](
+						db.WithTransactionContext(tx),
+						db.Equal("id", task.ID),
+						db.WLock(), // write lock, multiple tasks can't update the same task
+					)
+					if err != nil {
+						return err
+					}
+
+					task_pointer := &task
+					var plugin_status *models.InstallTaskPluginStatus
+					for _, plugin := range task.Plugins {
+						if plugin.PluginUniqueIdentifier == plugin_unique_identifier {
+							plugin_status = &plugin
+						}
+					}
+					modifier(task_pointer, plugin_status)
+					return db.Update(task_pointer, tx)
+				}); err != nil {
+					log.Error("failed to update install task status %s", err.Error())
+				}
+			}
+
+			pkg, err := manager.GetPackage(plugin_unique_identifier)
+			if err != nil {
+				updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
+					task.Status = models.InstallTaskStatusFailed
+					plugin.Status = models.InstallTaskStatusFailed
+					plugin.Message = err.Error()
+				})
+				return
+			}
+
+			decoder, err := decoder.NewZipPluginDecoder(pkg)
+			if err != nil {
+				updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
+					task.Status = models.InstallTaskStatusFailed
+					plugin.Status = models.InstallTaskStatusFailed
+					plugin.Message = err.Error()
+				})
+				return
+			}
+
+			updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
+				plugin.Status = models.InstallTaskStatusRunning
+				plugin.Message = "Installing"
+			})
+
+			stream, err := manager.Install(tenant_id, decoder, source, meta)
+			if err != nil {
+				updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
+					task.Status = models.InstallTaskStatusFailed
+					plugin.Status = models.InstallTaskStatusFailed
+					plugin.Message = err.Error()
+				})
+				return
+			}
+
+			for stream.Next() {
+				message, err := stream.Read()
+				if err != nil {
+					updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
+						task.Status = models.InstallTaskStatusFailed
+						plugin.Status = models.InstallTaskStatusFailed
+						plugin.Message = err.Error()
+					})
+					return
+				}
+
+				if message.Event == plugin_manager.PluginInstallEventError {
+					updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
+						task.Status = models.InstallTaskStatusFailed
+						plugin.Status = models.InstallTaskStatusFailed
+						plugin.Message = message.Data
+					})
+					return
+				}
+			}
+
+			updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
+				plugin.Status = models.InstallTaskStatusSuccess
+				plugin.Message = "Installed"
+				task.CompletedPlugins++
+
+				// check if all plugins are installed
+				if task.CompletedPlugins == task.TotalPlugins {
+					task.Status = models.InstallTaskStatusSuccess
+				}
+			})
+		})
+	}
+
+	// submit async tasks
+	routine.WithMaxRoutine(3, tasks, func() {
+		time.AfterFunc(time.Second*5, func() {
+			// get task
+			task, err := db.GetOne[models.InstallTask](
+				db.Equal("id", task.ID),
+			)
+			if err != nil {
+				return
+			}
+
+			if task.CompletedPlugins == task.TotalPlugins {
+				// delete task if all plugins are installed successfully
+				db.Delete(&task)
+			}
+		})
+	})
+
+	return entities.NewSuccessResponse(response)
 }
 
 func FetchPluginInstallationTasks(
@@ -72,14 +257,31 @@ func FetchPluginInstallationTasks(
 	page int,
 	page_size int,
 ) *entities.Response {
-	return nil
+	tasks, err := db.GetAll[models.InstallTask](
+		db.Equal("tenant_id", tenant_id),
+		db.OrderBy("created_at", true),
+		db.Page(page, page_size),
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-500, err.Error())
+	}
+
+	return entities.NewSuccessResponse(tasks)
 }
 
 func FetchPluginInstallationTask(
 	tenant_id string,
 	task_id string,
 ) *entities.Response {
-	return nil
+	task, err := db.GetOne[models.InstallTask](
+		db.Equal("id", task_id),
+		db.Equal("tenant_id", tenant_id),
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-500, err.Error())
+	}
+
+	return entities.NewSuccessResponse(task)
 }
 
 func FetchPluginManifest(
@@ -89,37 +291,6 @@ func FetchPluginManifest(
 	return nil
 }
 
-func InstallPluginFromIdentifier(
-	tenant_id string,
-	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
-	source string,
-	meta map[string]any,
-) error {
-	// TODO: refactor
-	// check if identifier exists
-	plugin, err := db.GetOne[models.Plugin](
-		db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
-	)
-	if err == db.ErrDatabaseNotFound {
-		return errors.New("plugin not found")
-	}
-	if err != nil {
-		return err
-	}
-
-	if plugin.InstallType == plugin_entities.PLUGIN_RUNTIME_TYPE_REMOTE {
-		return errors.New("remote plugin not supported")
-	}
-
-	declaration := plugin.Declaration
-	// install to this workspace
-	if _, _, err := curd.InstallPlugin(tenant_id, plugin_unique_identifier, plugin.InstallType, &declaration, source, meta); err != nil {
-		return err
-	}
-
-	return nil
-}
-
 func FetchPluginFromIdentifier(
 	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
 ) *entities.Response {

+ 11 - 6
internal/types/app/config.go

@@ -24,11 +24,12 @@ type Config struct {
 
 	PluginEndpointEnabled bool `envconfig:"PLUGIN_ENDPOINT_ENABLED"`
 
-	PluginStoragePath    string `envconfig:"STORAGE_PLUGIN_PATH" validate:"required"`
-	PluginWorkingPath    string `envconfig:"PLUGIN_WORKING_PATH"`
-	PluginMediaCacheSize uint16 `envconfig:"PLUGIN_MEDIA_CACHE_SIZE"`
-	PluginMediaCachePath string `envconfig:"PLUGIN_MEDIA_CACHE_PATH"`
-	ProcessCachingPath   string `envconfig:"PROCESS_CACHING_PATH"`
+	PluginStoragePath      string `envconfig:"STORAGE_PLUGIN_PATH" validate:"required"`
+	PluginPackageCachePath string `envconfig:"PLUGIN_PACKAGE_CACHE_PATH"`
+	PluginWorkingPath      string `envconfig:"PLUGIN_WORKING_PATH"`
+	PluginMediaCacheSize   uint16 `envconfig:"PLUGIN_MEDIA_CACHE_SIZE"`
+	PluginMediaCachePath   string `envconfig:"PLUGIN_MEDIA_CACHE_PATH"`
+	ProcessCachingPath     string `envconfig:"PROCESS_CACHING_PATH"`
 
 	PluginMaxExecutionTimeout int `envconfig:"PLUGIN_MAX_EXECUTION_TIMEOUT" validate:"required"`
 
@@ -128,10 +129,14 @@ func (c *Config) Validate() error {
 			c.PersistenceStorageS3AccessKey == "" ||
 			c.PersistenceStorageS3SecretKey == "" ||
 			c.PersistenceStorageS3Bucket == "" {
-			return fmt.Errorf("s3 region, access key, secret key, bucket is empty")
+			return fmt.Errorf("s3 region, access key, secret key or bucket is empty")
 		}
 	}
 
+	if c.PluginPackageCachePath == "" {
+		return fmt.Errorf("plugin package cache path is empty")
+	}
+
 	return nil
 }
 

+ 1 - 0
internal/types/app/default.go

@@ -21,6 +21,7 @@ func (config *Config) SetDefault() {
 	setDefaultString(&config.PluginMediaCachePath, "./storage/assets")
 	setDefaultString(&config.PersistenceStorageLocalPath, "./storage/persistence")
 	setDefaultString(&config.ProcessCachingPath, "./storage/subprocesses")
+	setDefaultString(&config.PluginPackageCachePath, "./storage/plugin_packages")
 }
 
 func setDefaultInt[T constraints.Integer](value *T, defaultValue T) {

+ 27 - 0
internal/types/models/task.go

@@ -0,0 +1,27 @@
+package models
+
+import "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+
+type InstallTaskStatus string
+
+const (
+	InstallTaskStatusPending InstallTaskStatus = "pending"
+	InstallTaskStatusRunning InstallTaskStatus = "running"
+	InstallTaskStatusSuccess InstallTaskStatus = "success"
+	InstallTaskStatusFailed  InstallTaskStatus = "failed"
+)
+
+type InstallTaskPluginStatus struct {
+	PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifier"`
+	PluginID               string                                 `json:"plugin_id"`
+	Status                 InstallTaskStatus                      `json:"status"`
+	Message                string                                 `json:"message"`
+}
+
+type InstallTask struct {
+	Model
+	Status           InstallTaskStatus         `json:"status" gorm:"not null"`
+	TotalPlugins     int                       `json:"total_plugins" gorm:"not null"`
+	CompletedPlugins int                       `json:"completed_plugins" gorm:"not null"`
+	Plugins          []InstallTaskPluginStatus `json:"plugins" gorm:"serializer:json"`
+}

+ 40 - 0
internal/utils/routine/pool.go

@@ -2,6 +2,7 @@ package routine
 
 import (
 	"sync"
+	"sync/atomic"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"github.com/panjf2000/ants"
@@ -31,3 +32,42 @@ func InitPool(size int) {
 func Submit(f func()) {
 	p.Submit(f)
 }
+
+func WithMaxRoutine(max_routine int, tasks []func(), on_finish ...func()) {
+	if max_routine <= 0 {
+		max_routine = 1
+	}
+
+	if max_routine > len(tasks) {
+		max_routine = len(tasks)
+	}
+
+	Submit(func() {
+		wg := sync.WaitGroup{}
+		task_index := int32(0)
+
+		for i := 0; i < max_routine; i++ {
+			wg.Add(1)
+			Submit(func() {
+				defer wg.Done()
+				current_index := atomic.AddInt32(&task_index, 1)
+
+				if current_index >= int32(len(tasks)) {
+					return
+				}
+
+				for current_index < int32(len(tasks)) {
+					task := tasks[current_index]
+					task()
+					current_index++
+				}
+			})
+		}
+
+		wg.Wait()
+
+		if len(on_finish) > 0 {
+			on_finish[0]()
+		}
+	})
+}