ソースを参照

refactor: support upgrade plugin to any verions

Yeuoly 9 ヶ月 前
コミット
eeadb3e29c

+ 0 - 1
internal/core/plugin_manager/install_to_local.go

@@ -13,7 +13,6 @@ import (
 
 // InstallToLocal installs a plugin to local
 func (p *PluginManager) InstallToLocal(
-	tenant_id string,
 	plugin_path string,
 	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
 	source string,

+ 0 - 19
internal/core/plugin_manager/install_to_serverless.go

@@ -6,16 +6,13 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/serverless"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
 	"github.com/langgenius/dify-plugin-daemon/internal/db"
-	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
-	"github.com/langgenius/dify-plugin-daemon/internal/types/models/curd"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
 // InstallToAWSFromPkg installs a plugin to AWS Lambda
 func (p *PluginManager) InstallToAWSFromPkg(
-	tenant_id string,
 	decoder decoder.PluginDecoder,
 	source string,
 	meta map[string]any,
@@ -94,22 +91,6 @@ func (p *PluginManager) InstallToAWSFromPkg(
 					return
 				}
 
-				_, _, err = curd.InstallPlugin(
-					tenant_id,
-					unique_identity,
-					plugin_entities.PLUGIN_RUNTIME_TYPE_AWS,
-					&declaration,
-					source,
-					meta,
-				)
-				if err != nil {
-					new_response.Write(PluginInstallResponse{
-						Event: PluginInstallEventError,
-						Data:  "Failed to create plugin",
-					})
-					return
-				}
-
 				new_response.Write(PluginInstallResponse{
 					Event: PluginInstallEventDone,
 					Data:  "Installed",

+ 40 - 0
internal/core/plugin_manager/manager.go

@@ -10,9 +10,13 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation/real"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/media_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/serverless"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_packager/decoder"
+	"github.com/langgenius/dify-plugin-daemon/internal/db"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache/helper"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/lock"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/mapping"
@@ -146,6 +150,36 @@ func (p *PluginManager) SavePackage(plugin_unique_identifier plugin_entities.Plu
 		return err
 	}
 
+	// try to decode the package
+	package_decoder, err := decoder.NewZipPluginDecoder(pkg)
+	if err != nil {
+		return err
+	}
+
+	// get the declaration
+	declaration, err := package_decoder.Manifest()
+	if err != nil {
+		return err
+	}
+
+	unique_identifier, err := package_decoder.UniqueIdentity()
+	if err != nil {
+		return err
+	}
+
+	// create plugin if not exists
+	if _, err := db.GetOne[models.PluginDeclaration](
+		db.Equal("plugin_unique_identifier", unique_identifier.String()),
+	); err == db.ErrDatabaseNotFound {
+		return db.Create(&models.PluginDeclaration{
+			PluginUniqueIdentifier: unique_identifier.String(),
+			PluginID:               unique_identifier.PluginID(),
+			Declaration:            declaration,
+		})
+	} else if err != nil {
+		return err
+	}
+
 	return nil
 }
 
@@ -165,3 +199,9 @@ func (p *PluginManager) GetPackage(plugin_unique_identifier plugin_entities.Plug
 func (p *PluginManager) GetPackagePath(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier) (string, error) {
 	return filepath.Join(p.packageCachePath, plugin_unique_identifier.String()), nil
 }
+
+func (p *PluginManager) GetDeclaration(plugin_unique_identifier plugin_entities.PluginUniqueIdentifier) (
+	*plugin_entities.PluginDeclaration, error,
+) {
+	return helper.CombinedGetPluginDeclaration(plugin_unique_identifier)
+}

+ 1 - 0
internal/db/init.go

@@ -80,6 +80,7 @@ func autoMigrate() error {
 	return DifyPluginDB.AutoMigrate(
 		models.Plugin{},
 		models.PluginInstallation{},
+		models.PluginDeclaration{},
 		models.Endpoint{},
 		models.ServerlessRuntime{},
 		models.ToolInstallation{},

+ 21 - 0
internal/server/controllers/plugins.go

@@ -55,6 +55,27 @@ func UploadPlugin(app *app.Config) gin.HandlerFunc {
 	}
 }
 
+func UpgradePlugin(app *app.Config) gin.HandlerFunc {
+	return func(c *gin.Context) {
+		BindRequest(c, func(request struct {
+			TenantID                       string                                 `uri:"tenant_id" validate:"required"`
+			OriginalPluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `json:"original_plugin_unique_identifier" validate:"required,plugin_unique_identifier"`
+			NewPluginUniqueIdentifier      plugin_entities.PluginUniqueIdentifier `json:"new_plugin_unique_identifier" validate:"required,plugin_unique_identifier"`
+			Source                         string                                 `json:"source" validate:"required"`
+			Meta                           map[string]any                         `json:"meta" validate:"omitempty"`
+		}) {
+			c.JSON(http.StatusOK, service.UpgradePlugin(
+				app,
+				request.TenantID,
+				request.Source,
+				request.Meta,
+				request.OriginalPluginUniqueIdentifier,
+				request.NewPluginUniqueIdentifier,
+			))
+		})
+	}
+}
+
 func InstallPluginFromIdentifiers(app *app.Config) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		BindRequest(c, func(request struct {

+ 1 - 0
internal/server/http_server.go

@@ -115,6 +115,7 @@ func (app *App) endpointManagementGroup(group *gin.RouterGroup) {
 func (app *App) pluginManagementGroup(group *gin.RouterGroup, config *app.Config) {
 	group.POST("/install/upload", controllers.UploadPlugin(config))
 	group.POST("/install/identifiers", controllers.InstallPluginFromIdentifiers(config))
+	group.POST("/install/upgrade", controllers.UpgradePlugin(config))
 	group.GET("/install/tasks/:id", controllers.FetchPluginInstallationTask)
 	group.POST("/install/tasks/:id/delete", controllers.DeletePluginInstallationTask)
 	group.POST("/install/tasks/:id/delete/*identifier", controllers.DeletePluginInstallationItemFromTask)

+ 130 - 13
internal/service/install_plugin.go

@@ -18,19 +18,26 @@ import (
 	"gorm.io/gorm"
 )
 
-func InstallPluginFromIdentifiers(
+type InstallPluginResponse struct {
+	AllInstalled bool   `json:"all_installed"`
+	TaskID       string `json:"task_id"`
+}
+
+type InstallPluginOnDoneHandler func(
+	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+	declaration *plugin_entities.PluginDeclaration,
+) error
+
+func InstallPluginRuntimeToTenant(
 	config *app.Config,
 	tenant_id string,
 	plugin_unique_identifiers []plugin_entities.PluginUniqueIdentifier,
 	source string,
 	meta map[string]any,
-) *entities.Response {
-	var response struct {
-		AllInstalled bool   `json:"all_installed"`
-		TaskID       string `json:"task_id"`
-	}
+	on_done InstallPluginOnDoneHandler, // since installing plugin is a async task, we need to call it asynchronously
+) (*InstallPluginResponse, error) {
+	response := &InstallPluginResponse{}
 
-	// TODO: create installation task and dispatch to workers
 	plugins_wait_for_installation := []plugin_entities.PluginUniqueIdentifier{}
 
 	task := &models.InstallTask{
@@ -65,7 +72,7 @@ func InstallPluginFromIdentifiers(
 				source,
 				meta,
 			); err != nil {
-				return entities.NewErrorResponse(-500, err.Error())
+				return nil, err
 			}
 
 			task.CompletedPlugins++
@@ -75,7 +82,7 @@ func InstallPluginFromIdentifiers(
 		}
 
 		if err != db.ErrDatabaseNotFound {
-			return entities.NewErrorResponse(-500, err.Error())
+			return nil, err
 		}
 
 		plugins_wait_for_installation = append(plugins_wait_for_installation, plugin_unique_identifier)
@@ -84,12 +91,12 @@ func InstallPluginFromIdentifiers(
 	if len(plugins_wait_for_installation) == 0 {
 		response.AllInstalled = true
 		response.TaskID = ""
-		return entities.NewSuccessResponse(response)
+		return response, nil
 	}
 
 	err := db.Create(task)
 	if err != nil {
-		return entities.NewErrorResponse(-500, err.Error())
+		return nil, err
 	}
 
 	response.TaskID = task.ID
@@ -97,6 +104,14 @@ func InstallPluginFromIdentifiers(
 
 	tasks := []func(){}
 	for _, plugin_unique_identifier := range plugins_wait_for_installation {
+		// copy the variable to avoid race condition
+		plugin_unique_identifier := plugin_unique_identifier
+
+		declaration, err := manager.GetDeclaration(plugin_unique_identifier)
+		if err != nil {
+			return nil, err
+		}
+
 		tasks = append(tasks, func() {
 			updateTaskStatus := func(modifier func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus)) {
 				if err := db.WithTransaction(func(tx *gorm.DB) error {
@@ -173,9 +188,9 @@ func InstallPluginFromIdentifiers(
 					})
 					return
 				}
-				stream, err = manager.InstallToAWSFromPkg(tenant_id, zip_decoder, source, meta)
+				stream, err = manager.InstallToAWSFromPkg(zip_decoder, source, meta)
 			} else if config.Platform == app.PLATFORM_LOCAL {
-				stream, err = manager.InstallToLocal(tenant_id, pkg_path, plugin_unique_identifier, source, meta)
+				stream, err = manager.InstallToLocal(pkg_path, plugin_unique_identifier, source, meta)
 			} else {
 				updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
 					task.Status = models.InstallTaskStatusFailed
@@ -213,6 +228,17 @@ func InstallPluginFromIdentifiers(
 					})
 					return
 				}
+
+				if message.Event == plugin_manager.PluginInstallEventDone {
+					if err := on_done(plugin_unique_identifier, declaration); err != nil {
+						updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
+							task.Status = models.InstallTaskStatusFailed
+							plugin.Status = models.InstallTaskStatusFailed
+							plugin.Message = "Failed to create plugin"
+						})
+						return
+					}
+				}
 			}
 
 			updateTaskStatus(func(task *models.InstallTask, plugin *models.InstallTaskPluginStatus) {
@@ -231,6 +257,97 @@ func InstallPluginFromIdentifiers(
 	// submit async tasks
 	routine.WithMaxRoutine(3, tasks)
 
+	return response, nil
+}
+
+func InstallPluginFromIdentifiers(
+	config *app.Config,
+	tenant_id string,
+	plugin_unique_identifiers []plugin_entities.PluginUniqueIdentifier,
+	source string,
+	meta map[string]any,
+) *entities.Response {
+	response, err := InstallPluginRuntimeToTenant(config, tenant_id, plugin_unique_identifiers, source, meta, func(
+		plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+		declaration *plugin_entities.PluginDeclaration,
+	) error {
+		_, _, err := curd.InstallPlugin(
+			tenant_id,
+			plugin_unique_identifier,
+			plugin_entities.PLUGIN_RUNTIME_TYPE_AWS,
+			declaration,
+			source,
+			meta,
+		)
+		return err
+	})
+	if err != nil {
+		return entities.NewErrorResponse(-500, err.Error())
+	}
+
+	return entities.NewSuccessResponse(response)
+}
+
+func UpgradePlugin(
+	config *app.Config,
+	tenant_id string,
+	source string,
+	meta map[string]any,
+	original_plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+	new_plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+) *entities.Response {
+	if original_plugin_unique_identifier == new_plugin_unique_identifier {
+		return entities.NewErrorResponse(-400, "original and new plugin unique identifier are the same")
+	}
+
+	// uninstall the original plugin
+	installation, err := db.GetOne[models.PluginInstallation](
+		db.Equal("tenant_id", tenant_id),
+		db.Equal("plugin_unique_identifier", original_plugin_unique_identifier.String()),
+		db.Equal("source", source),
+	)
+
+	if err == db.ErrDatabaseNotFound {
+		return entities.NewErrorResponse(-404, "Plugin installation not found for this tenant")
+	}
+
+	if err != nil {
+		return entities.NewErrorResponse(-500, err.Error())
+	}
+
+	// install the new plugin runtime
+	response, err := InstallPluginRuntimeToTenant(
+		config,
+		tenant_id,
+		[]plugin_entities.PluginUniqueIdentifier{new_plugin_unique_identifier},
+		source,
+		meta,
+		func(
+			plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+			declaration *plugin_entities.PluginDeclaration,
+		) error {
+			// uninstall the original plugin
+			err = curd.UpgradePlugin(
+				tenant_id,
+				original_plugin_unique_identifier,
+				new_plugin_unique_identifier,
+				declaration,
+				plugin_entities.PluginRuntimeType(installation.RuntimeType),
+				source,
+				meta,
+			)
+
+			if err != nil {
+				return err
+			}
+
+			return nil
+		},
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-500, err.Error())
+	}
+
 	return entities.NewSuccessResponse(response)
 }
 

+ 92 - 1
internal/types/models/curd/atomic.go

@@ -148,7 +148,11 @@ type DeletePluginResponse struct {
 // Delete plugin for a tenant, delete the plugin if it has never been created before
 // and uninstall it from the tenant, return the plugin and the installation
 // if the plugin has been created before, return the plugin which has been created before
-func UninstallPlugin(tenant_id string, plugin_unique_identifier plugin_entities.PluginUniqueIdentifier, installation_id string) (*DeletePluginResponse, error) {
+func UninstallPlugin(
+	tenant_id string,
+	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+	installation_id string,
+) (*DeletePluginResponse, error) {
 	var plugin_to_be_returns *models.Plugin
 	var installation_to_be_returns *models.PluginInstallation
 
@@ -253,3 +257,90 @@ func UninstallPlugin(tenant_id string, plugin_unique_identifier plugin_entities.
 		IsPluginDeleted: plugin_to_be_returns.Refers == 0,
 	}, nil
 }
+
+// Upgrade plugin for a tenant, upgrade the plugin if it has been created before
+// and uninstall the original plugin and install the new plugin, but keep the original installation information
+// like endpoint_setups, etc.
+func UpgradePlugin(
+	tenant_id string,
+	original_plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+	new_plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+	new_declaration *plugin_entities.PluginDeclaration,
+	install_type plugin_entities.PluginRuntimeType,
+	source string,
+	meta map[string]any,
+) error {
+	return db.WithTransaction(func(tx *gorm.DB) error {
+		installation, err := db.GetOne[models.PluginInstallation](
+			db.WithTransactionContext(tx),
+			db.Equal("plugin_unique_identifier", original_plugin_unique_identifier.String()),
+			db.Equal("tenant_id", tenant_id),
+			db.WLock(),
+		)
+
+		if err == db.ErrDatabaseNotFound {
+			return errors.New("plugin has not been installed")
+		} else if err != nil {
+			return err
+		}
+
+		// check if the new plugin has existed
+		plugin, err := db.GetOne[models.Plugin](
+			db.WithTransactionContext(tx),
+			db.Equal("plugin_unique_identifier", new_plugin_unique_identifier.String()),
+		)
+
+		if err == db.ErrDatabaseNotFound {
+			// create new plugin
+			plugin = models.Plugin{
+				PluginID:               new_plugin_unique_identifier.PluginID(),
+				PluginUniqueIdentifier: new_plugin_unique_identifier.String(),
+				InstallType:            install_type,
+				Refers:                 0,
+				Declaration:            *new_declaration,
+				ManifestType:           plugin_entities.PluginType,
+			}
+
+			err := db.Create(&plugin, tx)
+			if err != nil {
+				return err
+			}
+		} else if err != nil {
+			return err
+		}
+
+		// update exists installation
+		installation.PluginUniqueIdentifier = new_plugin_unique_identifier.String()
+		installation.Meta = meta
+		err = db.Update(installation, tx)
+		if err != nil {
+			return err
+		}
+
+		// decrease the refers of the original plugin
+		err = db.Run(
+			db.WithTransactionContext(tx),
+			db.Model(&models.Plugin{}),
+			db.Equal("plugin_unique_identifier", original_plugin_unique_identifier.String()),
+			db.Inc(map[string]int{"refers": -1}),
+		)
+
+		if err != nil {
+			return err
+		}
+
+		// increase the refers of the new plugin
+		err = db.Run(
+			db.WithTransactionContext(tx),
+			db.Model(&models.Plugin{}),
+			db.Equal("plugin_unique_identifier", new_plugin_unique_identifier.String()),
+			db.Inc(map[string]int{"refers": 1}),
+		)
+
+		if err != nil {
+			return err
+		}
+
+		return nil
+	})
+}

+ 7 - 0
internal/types/models/plugin.go

@@ -31,3 +31,10 @@ type ServerlessRuntime struct {
 	Declaration            plugin_entities.PluginDeclaration `json:"declaration" gorm:"serializer:json;type:text;size:65535"`
 	Checksum               string                            `json:"checksum" gorm:"size:127;index"`
 }
+
+type PluginDeclaration struct {
+	Model
+	PluginUniqueIdentifier string                            `json:"plugin_unique_identifier" gorm:"size:127;unique"`
+	PluginID               string                            `json:"plugin_id" gorm:"size:127;index"`
+	Declaration            plugin_entities.PluginDeclaration `json:"declaration" gorm:"serializer:json;type:text;size:65535"`
+}

+ 11 - 0
internal/utils/cache/helper/redis.go

@@ -11,6 +11,17 @@ func CombinedGetPluginDeclaration(plugin_unique_identifier plugin_entities.Plugi
 	return cache.AutoGetWithGetter(
 		plugin_unique_identifier.String(),
 		func() (*plugin_entities.PluginDeclaration, error) {
+			declaration, err := db.GetOne[models.PluginDeclaration](
+				db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
+			)
+			if err != nil && err != db.ErrDatabaseNotFound {
+				return nil, err
+			}
+
+			if err == nil {
+				return &declaration.Declaration, nil
+			}
+
 			model, err := db.GetOne[models.Plugin](
 				db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
 			)