Sfoglia il codice sorgente

feat: enable disable endpoints

Yeuoly 10 mesi fa
parent
commit
b3d619bc03

+ 6 - 0
internal/server/constants/constants.go

@@ -0,0 +1,6 @@
+package constants
+
+const (
+	X_PLUGIN_IDENTIFIER = "X-Plugin-Identifier"
+	X_API_KEY           = "X-Api-Key"
+)

+ 15 - 0
internal/server/controllers/base.go

@@ -2,7 +2,9 @@ package controllers
 
 import (
 	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/server/constants"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
 )
 
@@ -32,3 +34,16 @@ func BindRequest[T any](r *gin.Context, success func(T)) {
 
 	success(request)
 }
+
+func BindRequestWithPluginUniqueIdentifier[T any](r *gin.Context, success func(T, plugin_entities.PluginUniqueIdentifier)) {
+	BindRequest(r, func(req T) {
+		plugin_unique_identifier := r.GetHeader(constants.X_PLUGIN_IDENTIFIER)
+		if plugin_unique_identifier == "" {
+			resp := entities.NewErrorResponse(-400, "Plugin unique identifier is required")
+			r.JSON(400, resp)
+			return
+		}
+
+		success(req, plugin_entities.PluginUniqueIdentifier(plugin_unique_identifier))
+	})
+}

+ 37 - 12
internal/server/controllers/endpoint.go

@@ -7,19 +7,20 @@ import (
 )
 
 func SetupEndpoint(ctx *gin.Context) {
-	BindRequest(ctx, func(request struct {
-		PluginUniqueIdentifier string         `json:"plugin_unique_identifier" binding:"required"`
-		TenantID               string         `json:"tenant_id" binding:"required"`
-		UserID                 string         `json:"user_id" binding:"required"`
-		Settings               map[string]any `json:"settings" binding:"omitempty"`
-	}) {
-		plugin_unique_identifier := request.PluginUniqueIdentifier
+	BindRequestWithPluginUniqueIdentifier(ctx, func(
+		request struct {
+			TenantID string         `json:"tenant_id" binding:"required"`
+			UserID   string         `json:"user_id" binding:"required"`
+			Settings map[string]any `json:"settings" binding:"omitempty"`
+		},
+		plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+	) {
 		tenant_id := request.TenantID
 		user_id := request.UserID
 		settings := request.Settings
 
 		ctx.JSON(200, service.SetupEndpoint(
-			tenant_id, user_id, plugin_entities.PluginUniqueIdentifier(plugin_unique_identifier), settings,
+			tenant_id, user_id, plugin_unique_identifier, settings,
 		))
 	})
 }
@@ -40,12 +41,36 @@ func ListEndpoints(ctx *gin.Context) {
 
 func RemoveEndpoint(ctx *gin.Context) {
 	BindRequest(ctx, func(request struct {
-		PluginUniqueIdentifier string `json:"plugin_unique_identifier"`
-		TenantID               string `json:"tenant_id"`
+		EndpointID string `json:"endpoint_id" binding:"required"`
+		TenantID   string `json:"tenant_id" binding:"required"`
+	}) {
+		endpoint_id := request.EndpointID
+		tenant_id := request.TenantID
+
+		ctx.JSON(200, service.RemoveEndpoint(endpoint_id, tenant_id))
+	})
+}
+
+func EnableEndpoint(ctx *gin.Context) {
+	BindRequest(ctx, func(request struct {
+		EndpointID string `json:"endpoint_id" binding:"required"`
+		TenantID   string `json:"tenant_id" binding:"required"`
+	}) {
+		tenant_id := request.TenantID
+		endpoint_id := request.EndpointID
+
+		ctx.JSON(200, service.EnableEndpoint(endpoint_id, tenant_id))
+	})
+}
+
+func DisableEndpoint(ctx *gin.Context) {
+	BindRequest(ctx, func(request struct {
+		EndpointID string `json:"endpoint_id" binding:"required"`
+		TenantID   string `json:"tenant_id" binding:"required"`
 	}) {
-		plugin_unique_identifier := request.PluginUniqueIdentifier
 		tenant_id := request.TenantID
+		endpoint_id := request.EndpointID
 
-		ctx.JSON(200, service.RemoveEndpoint(plugin_unique_identifier, tenant_id))
+		ctx.JSON(200, service.DisableEndpoint(endpoint_id, tenant_id))
 	})
 }

+ 2 - 0
internal/server/http_server.go

@@ -93,6 +93,8 @@ func (app *App) endpointManagementGroup(group *gin.RouterGroup) {
 	group.POST("/setup", controllers.SetupEndpoint)
 	group.POST("/remove", controllers.RemoveEndpoint)
 	group.GET("/list", controllers.ListEndpoints)
+	group.POST("/enable", controllers.EnableEndpoint)
+	group.POST("/disable", controllers.DisableEndpoint)
 }
 
 func (app *App) pluginGroup(group *gin.RouterGroup, config *app.Config) {

+ 3 - 7
internal/server/middleware.go

@@ -5,19 +5,15 @@ import (
 	"io"
 
 	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/server/constants"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 )
 
-const (
-	X_PLUGIN_IDENTIFIER = "X-Plugin-Identifier"
-	X_API_KEY           = "X-Api-Key"
-)
-
 func CheckingKey(key string) gin.HandlerFunc {
 	return func(c *gin.Context) {
 		// get header X-Api-Key
-		if c.GetHeader(X_API_KEY) != key {
+		if c.GetHeader(constants.X_API_KEY) != key {
 			c.JSON(401, gin.H{"error": "Unauthorized"})
 			c.Abort()
 			return
@@ -53,7 +49,7 @@ func (app *App) RedirectPluginInvoke() gin.HandlerFunc {
 			reader: bytes.NewReader(raw),
 		}
 
-		identity := plugin_entities.PluginUniqueIdentifier(ctx.Request.Header.Get(X_PLUGIN_IDENTIFIER))
+		identity := plugin_entities.PluginUniqueIdentifier(ctx.Request.Header.Get(constants.X_PLUGIN_IDENTIFIER))
 		if identity == "" {
 			ctx.AbortWithStatusJSON(400, gin.H{"error": "Invalid request"})
 			return

+ 39 - 0
internal/service/endpoint.go

@@ -13,6 +13,9 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/db"
+	"github.com/langgenius/dify-plugin-daemon/internal/service/install_service"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
@@ -135,3 +138,39 @@ func Endpoint(
 		ctx.JSON(500, gin.H{"error": "killed by timeout"})
 	}
 }
+
+func EnableEndpoint(endpoint_id string, tenant_id string) *entities.Response {
+	endpoint, err := db.GetOne[models.Endpoint](
+		db.Equal("id", endpoint_id),
+		db.Equal("tenant_id", tenant_id),
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-404, "Endpoint not found")
+	}
+
+	endpoint.Enabled = true
+
+	if err := install_service.EnabledEndpoint(&endpoint); err != nil {
+		return entities.NewErrorResponse(-500, "Failed to enable endpoint")
+	}
+
+	return entities.NewSuccessResponse("success")
+}
+
+func DisableEndpoint(endpoint_id string, tenant_id string) *entities.Response {
+	endpoint, err := db.GetOne[models.Endpoint](
+		db.Equal("id", endpoint_id),
+		db.Equal("tenant_id", tenant_id),
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-404, "Endpoint not found")
+	}
+
+	endpoint.Enabled = false
+
+	if err := install_service.DisabledEndpoint(&endpoint); err != nil {
+		return entities.NewErrorResponse(-500, "Failed to disable endpoint")
+	}
+
+	return entities.NewSuccessResponse("success")
+}

+ 11 - 0
internal/service/install_service/state.go

@@ -81,6 +81,7 @@ func InstallEndpoint(
 		PluginID:  plugin_id.PluginID(),
 		TenantID:  tenant_id,
 		UserID:    user_id,
+		Enabled:   true,
 		ExpiredAt: time.Now().Add(time.Hour * 24 * 365 * 10),
 		Settings:  string(settings_json),
 	}
@@ -112,3 +113,13 @@ func GetEndpoint(
 func UninstallEndpoint(endpoint *models.Endpoint) error {
 	return db.Delete(endpoint)
 }
+
+func EnabledEndpoint(endpoint *models.Endpoint) error {
+	endpoint.Enabled = true
+	return db.Update(endpoint)
+}
+
+func DisabledEndpoint(endpoint *models.Endpoint) error {
+	endpoint.Enabled = false
+	return db.Update(endpoint)
+}

+ 1 - 0
internal/types/models/endpoint.go

@@ -14,6 +14,7 @@ type Endpoint struct {
 	UserID    string    `json:"user_id" orm:"index;size:64;column:user_id"`
 	PluginID  string    `json:"plugin_id" orm:"index;size:64;column:plugin_id"`
 	ExpiredAt time.Time `json:"expired_at" orm:"column:expired_at"`
+	Enabled   bool      `json:"enabled" orm:"column:enabled"`
 	Settings  string    `json:"settings" orm:"column:settings;size:2048"`
 }