Browse Source

feat: webhook

Yeuoly 11 months ago
parent
commit
8a1c72f1cd

+ 3 - 2
internal/core/plugin_daemon/generic.go

@@ -40,9 +40,10 @@ func genericInvokePlugin[Req any, Rsp any](
 			chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
 			if err != nil {
 				log.Error("unmarshal json failed: %s", err.Error())
-				return
+				response.WriteError(err)
+			} else {
+				response.Write(chunk)
 			}
-			response.Write(chunk)
 		case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
 			if err := backwards_invocation.InvokeDify(runtime, typ, session, chunk.Data); err != nil {
 				log.Error("invoke dify failed: %s", err.Error())

+ 2 - 0
internal/core/plugin_daemon/webhook_service.go

@@ -61,8 +61,10 @@ func InvokeWebhook(
 				resp.Close()
 				return http.StatusInternalServerError, nil, nil, err
 			}
+
 			response.Write(dehexed)
 			routine.Submit(func() {
+				defer response.Close()
 				for resp.Next() {
 					chunk, err := resp.Read()
 					if err != nil {

+ 2 - 8
internal/core/plugin_manager/init.go

@@ -1,17 +1,11 @@
 package plugin_manager
 
 import (
-	"sync"
-
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 )
 
-var (
-	m sync.Map
-)
-
-func checkPluginExist(identity string) (entities.PluginRuntimeInterface, bool) {
-	if v, ok := m.Load(identity); ok {
+func (m *PluginManager) checkPluginExist(identity string) (entities.PluginRuntimeInterface, bool) {
+	if v, ok := m.m.Load(identity); ok {
 		return v.(entities.PluginRuntimeInterface), true
 	}
 

+ 7 - 0
internal/core/plugin_manager/lifetime.go

@@ -24,6 +24,13 @@ func (p *PluginManager) lifetime(config *app.Config, r entities.PluginRuntimeInt
 		return
 	}
 
+	// add plugin to manager
+	err = p.Add(r)
+	if err != nil {
+		log.Error("add plugin to manager failed: %s", err.Error())
+		return
+	}
+
 	start_failed_times := 0
 
 	// remove lifetime state after plugin if it has been stopped

+ 14 - 10
internal/core/plugin_manager/manager.go

@@ -2,6 +2,7 @@ package plugin_manager
 
 import (
 	"fmt"
+	"sync"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/cluster"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
@@ -12,6 +13,8 @@ import (
 )
 
 type PluginManager struct {
+	m sync.Map
+
 	cluster *cluster.Cluster
 }
 
@@ -30,9 +33,18 @@ func GetGlobalPluginManager() *PluginManager {
 	return manager
 }
 
+func (p *PluginManager) Add(plugin entities.PluginRuntimeInterface) error {
+	identity, err := plugin.Identity()
+	if err != nil {
+		return err
+	}
+	p.m.Store(identity, plugin)
+	return nil
+}
+
 func (p *PluginManager) List() []entities.PluginRuntimeInterface {
 	var runtimes []entities.PluginRuntimeInterface
-	m.Range(func(key, value interface{}) bool {
+	p.m.Range(func(key, value interface{}) bool {
 		if v, ok := value.(entities.PluginRuntimeInterface); ok {
 			runtimes = append(runtimes, v)
 		}
@@ -42,7 +54,7 @@ func (p *PluginManager) List() []entities.PluginRuntimeInterface {
 }
 
 func (p *PluginManager) Get(identity string) entities.PluginRuntimeInterface {
-	if v, ok := m.Load(identity); ok {
+	if v, ok := p.m.Load(identity); ok {
 		if r, ok := v.(entities.PluginRuntimeInterface); ok {
 			return r
 		}
@@ -50,14 +62,6 @@ func (p *PluginManager) Get(identity string) entities.PluginRuntimeInterface {
 	return nil
 }
 
-func (p *PluginManager) Put(path string, binary []byte) {
-	//TODO: put binary into
-}
-
-func (p *PluginManager) Delete(identity string) {
-	//TODO: delete binary from
-}
-
 func (p *PluginManager) Init(configuration *app.Config) {
 	// TODO: init plugin manager
 	log.Info("start plugin manager daemon...")

+ 8 - 1
internal/core/plugin_manager/remote_manager/hooks.go

@@ -87,6 +87,13 @@ func (s *DifyServer) OnClose(c gnet.Conn, err error) (action gnet.Action) {
 	// close plugin
 	plugin.onDisconnected()
 
+	// uninstall plugin
+	if plugin.handshake && plugin.registration_transferred {
+		if err := plugin.Unregister(); err != nil {
+			log.Error("unregister plugin failed, error: %v", err)
+		}
+	}
+
 	return gnet.None
 }
 
@@ -167,7 +174,7 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 		// trigger registration event
 		if err := runtime.Register(); err != nil {
 			runtime.conn.Write([]byte("register failed\n"))
-			log.Error("register failed", "error", err)
+			log.Error("register failed, error: %v", err)
 			runtime.conn.Close()
 			return
 		}

+ 10 - 1
internal/core/plugin_manager/remote_manager/register.go

@@ -3,5 +3,14 @@ package remote_manager
 import "github.com/langgenius/dify-plugin-daemon/internal/service/install_service"
 
 func (plugin *RemotePluginRuntime) Register() error {
-	return install_service.InstallPlugin(plugin.tenant_id, "", plugin, map[string]any{})
+	installation_id, err := install_service.InstallPlugin(plugin.tenant_id, "", plugin, map[string]any{})
+	if err != nil {
+		return err
+	}
+	plugin.installation_id = installation_id
+	return nil
+}
+
+func (plugin *RemotePluginRuntime) Unregister() error {
+	return install_service.UninstallPlugin(plugin.tenant_id, plugin.installation_id, plugin)
 }

+ 4 - 0
internal/core/plugin_manager/remote_manager/type.go

@@ -41,7 +41,11 @@ type RemotePluginRuntime struct {
 
 	alive bool
 
+	// checksum
 	checksum string
+
+	// installation id
+	installation_id string
 }
 
 func (r *RemotePluginRuntime) Identity() (string, error) {

+ 5 - 5
internal/core/plugin_manager/watcher.go

@@ -45,7 +45,7 @@ func (p *PluginManager) startRemoteWatcher(config *app.Config) {
 
 func (p *PluginManager) handleNewPlugins(config *app.Config) {
 	// load local plugins firstly
-	for plugin := range loadNewPlugins(config.StoragePath) {
+	for plugin := range p.loadNewPlugins(config.StoragePath) {
 		var plugin_interface entities.PluginRuntimeInterface
 
 		if config.Platform == app.PLATFORM_AWS_LAMBDA {
@@ -69,7 +69,7 @@ func (p *PluginManager) handleNewPlugins(config *app.Config) {
 }
 
 // chan should be closed after using that
-func loadNewPlugins(root_path string) <-chan entities.PluginRuntime {
+func (p *PluginManager) loadNewPlugins(root_path string) <-chan entities.PluginRuntime {
 	ch := make(chan entities.PluginRuntime)
 
 	plugins, err := os.ReadDir(root_path)
@@ -89,7 +89,7 @@ func loadNewPlugins(root_path string) <-chan entities.PluginRuntime {
 					continue
 				}
 
-				status := verifyPluginStatus(configuration)
+				status := p.verifyPluginStatus(configuration)
 				if status.exist {
 					continue
 				}
@@ -135,8 +135,8 @@ type pluginStatusResult struct {
 	exist bool
 }
 
-func verifyPluginStatus(config *plugin_entities.PluginDeclaration) pluginStatusResult {
-	_, exist := checkPluginExist(config.Identity())
+func (p *PluginManager) verifyPluginStatus(config *plugin_entities.PluginDeclaration) pluginStatusResult {
+	_, exist := p.checkPluginExist(config.Identity())
 	if exist {
 		return pluginStatusResult{
 			exist: true,

+ 0 - 1
internal/server/http_server.go

@@ -81,7 +81,6 @@ func (app *App) server(config *app.Config) func() {
 		engine.POST(
 			"/plugin/debugging/key",
 			CheckingKey(config.PluginInnerApiKey),
-			app.RedirectPluginInvoke(),
 			controllers.GetRemoteDebuggingKey,
 		)
 	}

+ 48 - 3
internal/service/install_service/state.go

@@ -15,10 +15,10 @@ func InstallPlugin(
 	user_id string,
 	runtime entities.PluginRuntimeInterface,
 	configuration map[string]any,
-) error {
+) (string, error) {
 	identity, err := runtime.Identity()
 	if err != nil {
-		return err
+		return "", err
 	}
 
 	plugin := &models.Plugin{
@@ -31,13 +31,44 @@ func InstallPlugin(
 
 	plugin, installation, err := curd.CreatePlugin(tenant_id, user_id, plugin, configuration)
 	if err != nil {
-		return err
+		return "", err
 	}
 
 	// check if there is a webhook for the plugin
 	if runtime.Configuration().Resource.Permission.AllowRegistryWebhook() {
 		_, err := InstallWebhook(plugin.PluginID, installation.ID, tenant_id, user_id)
 		if err != nil {
+			return "", err
+		}
+	}
+
+	return installation.ID, nil
+}
+
+func UninstallPlugin(tenant_id string, installation_id string, runtime entities.PluginRuntimeInterface) error {
+	identity, err := runtime.Identity()
+	if err != nil {
+		return err
+	}
+
+	// delete the plugin from db
+	resp, err := curd.DeletePlugin(tenant_id, identity, installation_id)
+	if err != nil {
+		return err
+	}
+
+	// delete the webhook from db
+	if runtime.Configuration().Resource.Permission.AllowRegistryWebhook() {
+		// get the webhook from db
+		webhook, err := GetWebhook(tenant_id, identity, resp.Installation.ID)
+		if err != nil && err != db.ErrDatabaseNotFound {
+			return err
+		} else if err == db.ErrDatabaseNotFound {
+			return nil
+		}
+
+		err = UninstallWebhook(webhook)
+		if err != nil {
 			return err
 		}
 	}
@@ -64,6 +95,20 @@ func InstallWebhook(plugin_id string, installation_id string, tenant_id string,
 	return installation.HookID, nil
 }
 
+func GetWebhook(tenant_id string, plugin_id string, installation_id string) (*models.Webhook, error) {
+	webhook, err := db.GetOne[models.Webhook](
+		db.Equal("tenant_id", tenant_id),
+		db.Equal("plugin_id", plugin_id),
+		db.Equal("plugin_installation_id", installation_id),
+	)
+
+	if err != nil {
+		return nil, err
+	}
+
+	return &webhook, nil
+}
+
 // uninstalls a plugin from db
 func UninstallWebhook(webhook *models.Webhook) error {
 	return db.Delete(webhook)

+ 4 - 1
internal/service/webhook.go

@@ -2,6 +2,7 @@ package service
 
 import (
 	"bytes"
+	"context"
 	"encoding/hex"
 	"sync/atomic"
 	"time"
@@ -16,7 +17,8 @@ import (
 )
 
 func Webhook(ctx *gin.Context, webhook *models.Webhook, path string) {
-	req := ctx.Request
+	req := ctx.Request.Clone(context.Background())
+	req.URL.Path = path
 
 	var buffer bytes.Buffer
 	err := req.Write(&buffer)
@@ -73,6 +75,7 @@ func Webhook(ctx *gin.Context, webhook *models.Webhook, path string) {
 				return
 			}
 			ctx.Writer.Write(chunk)
+			ctx.Writer.Flush()
 		}
 	})
 

+ 0 - 5
internal/types/entities/from_map.go

@@ -1,5 +0,0 @@
-package entities
-
-type FromMapper interface {
-	FromMap(map[string]any) error
-}

+ 5 - 15
internal/types/models/curd/atomic.go

@@ -18,18 +18,7 @@ func CreatePlugin(tenant_id string, user_id string, plugin *models.Plugin, confi
 	var plugin_to_be_returns *models.Plugin
 	var installation_to_be_returns *models.PluginInstallation
 
-	_, err := db.GetOne[models.PluginInstallation](
-		db.Equal("plugin_id", plugin_to_be_returns.PluginID),
-		db.Equal("tenant_id", tenant_id),
-	)
-
-	if err != nil && err != db.ErrDatabaseNotFound {
-		return nil, nil, err
-	} else if err != nil {
-		return nil, nil, errors.New("plugin has been installed already")
-	}
-
-	err = db.WithTransaction(func(tx *gorm.DB) error {
+	err := db.WithTransaction(func(tx *gorm.DB) error {
 		p, err := db.GetOne[models.Plugin](
 			db.WithTransactionContext(tx),
 			db.Equal("plugin_id", plugin.PluginID),
@@ -89,12 +78,13 @@ type DeletePluginResponse struct {
 // Delete plugin for a tenant, delete the plugin if it has never been created before
 // and uninstall it from the tenant, return the plugin and the installation
 // if the plugin has been created before, return the plugin which has been created before
-func DeletePlugin(tenant_id string, plugin_id string) (*DeletePluginResponse, error) {
+func DeletePlugin(tenant_id string, plugin_id string, installation_id string) (*DeletePluginResponse, error) {
 	var plugin_to_be_returns *models.Plugin
 	var installation_to_be_returns *models.PluginInstallation
 
 	_, err := db.GetOne[models.PluginInstallation](
-		db.Equal("plugin_id", plugin_to_be_returns.PluginID),
+		db.Equal("id", installation_id),
+		db.Equal("plugin_id", plugin_id),
 		db.Equal("tenant_id", tenant_id),
 	)
 
@@ -109,7 +99,7 @@ func DeletePlugin(tenant_id string, plugin_id string) (*DeletePluginResponse, er
 	err = db.WithTransaction(func(tx *gorm.DB) error {
 		p, err := db.GetOne[models.Plugin](
 			db.WithTransactionContext(tx),
-			db.Equal("plugin_id", plugin_to_be_returns.PluginID),
+			db.Equal("plugin_id", plugin_id),
 			db.WLock(),
 		)