Sfoglia il codice sorgente

feat: add endpoints counter for installation

Yeuoly 10 mesi fa
parent
commit
ab3def1052

+ 17 - 4
internal/db/pgsql.go

@@ -2,6 +2,7 @@ package db
 
 import (
 	"fmt"
+	"strings"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"gorm.io/gorm"
@@ -150,15 +151,27 @@ func WithoutBit[T genericComparableConstraint](field string, value T) GenericQue
 	}
 }
 
-func Inc[T genericComparableConstraint](field string, value T) GenericQuery {
+func Inc[T genericComparableConstraint](updates map[string]T) GenericQuery {
 	return func(tx *gorm.DB) *gorm.DB {
-		return tx.UpdateColumn(field, gorm.Expr(fmt.Sprintf("%s + ?", field), value))
+		expressions := make([]string, 0, len(updates))
+		values := make([]interface{}, 0, len(updates))
+		for field, value := range updates {
+			expressions = append(expressions, fmt.Sprintf("%s = %s + ?", field, field))
+			values = append(values, value)
+		}
+		return tx.UpdateColumns(gorm.Expr(strings.Join(expressions, ", "), values...))
 	}
 }
 
-func Dec[T genericComparableConstraint](field string, value T) GenericQuery {
+func Dec[T genericComparableConstraint](updates map[string]T) GenericQuery {
 	return func(tx *gorm.DB) *gorm.DB {
-		return tx.UpdateColumn(field, gorm.Expr(fmt.Sprintf("%s - ?", field), value))
+		expressions := make([]string, 0, len(updates))
+		values := make([]interface{}, 0, len(updates))
+		for field, value := range updates {
+			expressions = append(expressions, fmt.Sprintf("%s = %s - ?", field, field))
+			values = append(values, value)
+		}
+		return tx.UpdateColumns(gorm.Expr(strings.Join(expressions, ", "), values...))
 	}
 }
 

+ 36 - 6
internal/server/controllers/endpoint.go

@@ -9,8 +9,8 @@ import (
 func SetupEndpoint(ctx *gin.Context) {
 	BindRequest(ctx, func(
 		request struct {
-			PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifier" validate:"required" validate:"plugin_unique_identifier"`
-			TenantID               string                                 `json:"tenant_id" validate:"required"`
+			PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifier" validate:"required,plugin_unique_identifier"`
+			TenantID               string                                 `uri:"tenant_id" validate:"required"`
 			UserID                 string                                 `json:"user_id" validate:"required"`
 			Settings               map[string]any                         `json:"settings" validate:"omitempty"`
 		},
@@ -28,7 +28,7 @@ func SetupEndpoint(ctx *gin.Context) {
 
 func ListEndpoints(ctx *gin.Context) {
 	BindRequest(ctx, func(request struct {
-		TenantID string `form:"tenant_id" validate:"required"`
+		TenantID string `uri:"tenant_id" validate:"required"`
 		Page     int    `form:"page" validate:"required"`
 		PageSize int    `form:"page_size" validate:"required,max=100"`
 	}) {
@@ -40,10 +40,26 @@ func ListEndpoints(ctx *gin.Context) {
 	})
 }
 
+func ListPluginEndpoints(ctx *gin.Context) {
+	BindRequest(ctx, func(request struct {
+		TenantID               string                                 `uri:"tenant_id" validate:"required"`
+		PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `form:"plugin_unique_identifier" validate:"required,plugin_unique_identifier"`
+		Page                   int                                    `form:"page" validate:"required"`
+		PageSize               int                                    `form:"page_size" validate:"required,max=100"`
+	}) {
+		tenant_id := request.TenantID
+		plugin_unique_identifier := request.PluginUniqueIdentifier
+		page := request.Page
+		page_size := request.PageSize
+
+		ctx.JSON(200, service.ListPluginEndpoints(tenant_id, plugin_unique_identifier, page, page_size))
+	})
+}
+
 func RemoveEndpoint(ctx *gin.Context) {
 	BindRequest(ctx, func(request struct {
 		EndpointID string `json:"endpoint_id" validate:"required"`
-		TenantID   string `json:"tenant_id" validate:"required"`
+		TenantID   string `uri:"tenant_id" validate:"required"`
 	}) {
 		endpoint_id := request.EndpointID
 		tenant_id := request.TenantID
@@ -52,10 +68,24 @@ func RemoveEndpoint(ctx *gin.Context) {
 	})
 }
 
+func UpdateEndpoint(ctx *gin.Context) {
+	BindRequest(ctx, func(request struct {
+		EndpointID string         `json:"endpoint_id" validate:"required"`
+		TenantID   string         `uri:"tenant_id" validate:"required"`
+		Settings   map[string]any `json:"settings" validate:"omitempty"`
+	}) {
+		tenant_id := request.TenantID
+		endpoint_id := request.EndpointID
+		settings := request.Settings
+
+		ctx.JSON(200, service.UpdateEndpoint(endpoint_id, tenant_id, settings))
+	})
+}
+
 func EnableEndpoint(ctx *gin.Context) {
 	BindRequest(ctx, func(request struct {
 		EndpointID string `json:"endpoint_id" validate:"required"`
-		TenantID   string `json:"tenant_id" validate:"required"`
+		TenantID   string `uri:"tenant_id" validate:"required"`
 	}) {
 		tenant_id := request.TenantID
 		endpoint_id := request.EndpointID
@@ -67,7 +97,7 @@ func EnableEndpoint(ctx *gin.Context) {
 func DisableEndpoint(ctx *gin.Context) {
 	BindRequest(ctx, func(request struct {
 		EndpointID string `json:"endpoint_id" validate:"required"`
-		TenantID   string `json:"tenant_id" validate:"required"`
+		TenantID   string `uri:"tenant_id" validate:"required"`
 	}) {
 		tenant_id := request.TenantID
 		endpoint_id := request.EndpointID

+ 3 - 1
internal/server/http_server.go

@@ -21,7 +21,6 @@ func (app *App) server(config *app.Config) func() {
 
 	app.endpointGroup(engine.Group("/e"), config)
 	app.awsLambdaTransactionGroup(engine.Group("/backwards-invocation"), config)
-	app.endpointManagementGroup(engine.Group("/endpoint"))
 	app.pluginGroup(engine.Group("/plugin/:tenant_id"), config)
 
 	srv := &http.Server{
@@ -48,6 +47,7 @@ func (app *App) pluginGroup(group *gin.RouterGroup, config *app.Config) {
 	app.remoteDebuggingGroup(group.Group("/debugging"), config)
 	app.pluginDispatchGroup(group.Group("/dispatch"), config)
 	app.pluginManagementGroup(group.Group("/management"), config)
+	app.endpointManagementGroup(group.Group("/endpoint"))
 	app.pluginAssetGroup(group.Group("/asset"))
 }
 
@@ -99,7 +99,9 @@ func (appRef *App) awsLambdaTransactionGroup(group *gin.RouterGroup, config *app
 func (app *App) endpointManagementGroup(group *gin.RouterGroup) {
 	group.POST("/setup", controllers.SetupEndpoint)
 	group.POST("/remove", controllers.RemoveEndpoint)
+	group.POST("/update", controllers.UpdateEndpoint)
 	group.GET("/list", controllers.ListEndpoints)
+	group.GET("/list/plugin", controllers.ListPluginEndpoints)
 	group.POST("/enable", controllers.EnableEndpoint)
 	group.POST("/disable", controllers.DisableEndpoint)
 }

+ 5 - 0
internal/service/endpoint.go

@@ -253,3 +253,8 @@ func ListEndpoints(tenant_id string, page int, page_size int) *entities.Response
 
 	return entities.NewSuccessResponse(endpoints)
 }
+
+func ListPluginEndpoints(tenant_id string, plugin_unique_identifier plugin_entities.PluginUniqueIdentifier, page int, page_size int) *entities.Response {
+	// TODO:
+	return nil
+}

+ 77 - 7
internal/service/install_service/state.go

@@ -9,6 +9,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models/curd"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/strings"
+	"gorm.io/gorm"
 )
 
 func InstallPlugin(
@@ -81,11 +82,26 @@ func InstallEndpoint(
 		TenantID:  tenant_id,
 		UserID:    user_id,
 		Enabled:   true,
-		ExpiredAt: time.Now().Add(time.Hour * 24 * 365 * 10),
+		ExpiredAt: time.Date(2050, 1, 1, 0, 0, 0, 0, time.UTC),
 		Settings:  string(settings_json),
 	}
 
-	if err := db.Create(&installation); err != nil {
+	if err := db.WithTransaction(func(tx *gorm.DB) error {
+		if err := db.Create(&installation, tx); err != nil {
+			return err
+		}
+
+		return db.Run(
+			db.WithTransactionContext(tx),
+			db.Model(models.PluginInstallation{}),
+			db.Equal("plugin_id", installation.PluginID),
+			db.Equal("tenant_id", installation.TenantID),
+			db.Inc(map[string]int{
+				"endpoints_setups": 1,
+				"endpoints_active": 1,
+			}),
+		)
+	}); err != nil {
 		return "", err
 	}
 
@@ -110,15 +126,69 @@ func GetEndpoint(
 
 // uninstalls a plugin from db
 func UninstallEndpoint(endpoint *models.Endpoint) error {
-	return db.Delete(endpoint)
+	return db.WithTransaction(func(tx *gorm.DB) error {
+		if err := db.Delete(endpoint, tx); err != nil {
+			return err
+		}
+
+		// update the plugin installation
+		return db.Run(
+			db.WithTransactionContext(tx),
+			db.Model(models.PluginInstallation{}),
+			db.Equal("plugin_id", endpoint.PluginID),
+			db.Equal("tenant_id", endpoint.TenantID),
+			db.Dec(map[string]int{
+				"endpoints_active": 1,
+				"endpoints_setups": 1,
+			}),
+		)
+	})
 }
 
 func EnabledEndpoint(endpoint *models.Endpoint) error {
-	endpoint.Enabled = true
-	return db.Update(endpoint)
+	if endpoint.Enabled {
+		return nil
+	}
+
+	return db.WithTransaction(func(tx *gorm.DB) error {
+		endpoint.Enabled = true
+		if err := db.Update(endpoint, tx); err != nil {
+			return err
+		}
+
+		// update the plugin installation
+		return db.Run(
+			db.WithTransactionContext(tx),
+			db.Model(models.PluginInstallation{}),
+			db.Equal("plugin_id", endpoint.PluginID),
+			db.Equal("tenant_id", endpoint.TenantID),
+			db.Inc(map[string]int{
+				"endpoints_active": 1,
+			}),
+		)
+	})
 }
 
 func DisabledEndpoint(endpoint *models.Endpoint) error {
-	endpoint.Enabled = false
-	return db.Update(endpoint)
+	if !endpoint.Enabled {
+		return nil
+	}
+
+	return db.WithTransaction(func(tx *gorm.DB) error {
+		endpoint.Enabled = false
+		if err := db.Update(endpoint, tx); err != nil {
+			return err
+		}
+
+		// update the plugin installation
+		return db.Run(
+			db.WithTransactionContext(tx),
+			db.Model(models.PluginInstallation{}),
+			db.Equal("plugin_id", endpoint.PluginID),
+			db.Equal("tenant_id", endpoint.TenantID),
+			db.Dec(map[string]int{
+				"endpoints_active": 1,
+			}),
+		)
+	})
 }

+ 5 - 0
internal/service/setup_endpoint.go

@@ -101,3 +101,8 @@ func RemoveEndpoint(endpoint_id string, tenant_id string) *entities.Response {
 
 	return entities.NewSuccessResponse(nil)
 }
+
+func UpdateEndpoint(endpoint_id string, tenant_id string, settings map[string]any) *entities.Response {
+	// TODO
+	return nil
+}

+ 3 - 3
internal/types/models/base.go

@@ -5,7 +5,7 @@ import (
 )
 
 type Model struct {
-	ID        string `gorm:"column:id;primaryKey;type:uuid;default:uuid_generate_v4()"`
-	CreatedAt time.Time
-	UpdatedAt time.Time
+	ID        string    `gorm:"column:id;primaryKey;type:uuid;default:uuid_generate_v4()" json:"id"`
+	CreatedAt time.Time `json:"created_at"`
+	UpdatedAt time.Time `json:"updated_at"`
 }

+ 2 - 0
internal/types/models/installation.go

@@ -8,4 +8,6 @@ type PluginInstallation struct {
 	PluginID               string `json:"plugin_id" gorm:"index;size:127"`
 	PluginUniqueIdentifier string `json:"plugin_unique_identifier" gorm:"index;size:127"`
 	RuntimeType            string `json:"runtime_type" gorm:"size:127"`
+	EndpointsSetups        int    `json:"endpoints_setups"`
+	EndpointsActive        int    `json:"endpoints_active"`
 }