Explorar el Código

feat: update endpoint

Yeuoly hace 10 meses
padre
commit
d9ed8259a3

+ 3 - 1
internal/server/controllers/endpoint.go

@@ -72,13 +72,15 @@ func UpdateEndpoint(ctx *gin.Context) {
 	BindRequest(ctx, func(request struct {
 		EndpointID string         `json:"endpoint_id" validate:"required"`
 		TenantID   string         `uri:"tenant_id" validate:"required"`
+		UserID     string         `json:"user_id" validate:"required"`
 		Settings   map[string]any `json:"settings" validate:"omitempty"`
 	}) {
 		tenant_id := request.TenantID
+		user_id := request.UserID
 		endpoint_id := request.EndpointID
 		settings := request.Settings
 
-		ctx.JSON(200, service.UpdateEndpoint(endpoint_id, tenant_id, settings))
+		ctx.JSON(200, service.UpdateEndpoint(endpoint_id, tenant_id, user_id, settings))
 	})
 }
 

+ 15 - 5
internal/service/install_service/state.go

@@ -63,17 +63,16 @@ func UninstallPlugin(
 }
 
 // setup a plugin to db,
-// returns the endpoint id
 func InstallEndpoint(
 	plugin_id plugin_entities.PluginUniqueIdentifier,
 	installation_id string,
 	tenant_id string,
 	user_id string,
 	settings map[string]any,
-) (string, error) {
+) (*models.Endpoint, error) {
 	settings_json, err := json.Marshal(settings)
 	if err != nil {
-		return "", err
+		return nil, err
 	}
 
 	installation := &models.Endpoint{
@@ -102,10 +101,10 @@ func InstallEndpoint(
 			}),
 		)
 	}); err != nil {
-		return "", err
+		return nil, err
 	}
 
-	return installation.HookID, nil
+	return installation, nil
 }
 
 func GetEndpoint(
@@ -192,3 +191,14 @@ func DisabledEndpoint(endpoint *models.Endpoint) error {
 		)
 	})
 }
+
+func UpdateEndpoint(endpoint *models.Endpoint, settings map[string]any) error {
+	settings_json, err := json.Marshal(settings)
+	if err != nil {
+		return err
+	}
+
+	endpoint.Settings = string(settings_json)
+
+	return db.Update(endpoint, nil)
+}

+ 108 - 13
internal/service/setup_endpoint.go

@@ -10,6 +10,7 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/encryption"
 )
 
 func SetupEndpoint(
@@ -40,6 +41,17 @@ func SetupEndpoint(
 		return entities.NewErrorResponse(-403, "permission denied")
 	}
 
+	endpoint, err := install_service.InstallEndpoint(
+		plugin_unique_identifier,
+		installation.ID,
+		tenant_id,
+		user_id,
+		map[string]any{},
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-500, fmt.Sprintf("failed to setup endpoint: %v", err))
+	}
+
 	if declaration.Endpoint == nil {
 		return entities.NewErrorResponse(-404, "plugin does not have an endpoint")
 	}
@@ -60,7 +72,7 @@ func SetupEndpoint(
 			InvokeEncryptSchema: dify_invocation.InvokeEncryptSchema{
 				Opt:       dify_invocation.ENCRYPT_OPT_ENCRYPT,
 				Namespace: dify_invocation.ENCRYPT_NAMESPACE_ENDPOINT,
-				Identity:  installation.ID,
+				Identity:  endpoint.ID,
 				Data:      settings,
 				Config:    declaration.Endpoint.Settings,
 			},
@@ -71,15 +83,8 @@ func SetupEndpoint(
 		return entities.NewErrorResponse(-500, fmt.Sprintf("failed to encrypt settings: %v", err))
 	}
 
-	_, err = install_service.InstallEndpoint(
-		plugin_unique_identifier,
-		installation.ID,
-		tenant_id,
-		user_id,
-		encrypted_settings,
-	)
-	if err != nil {
-		return entities.NewErrorResponse(-500, fmt.Sprintf("failed to setup endpoint: %v", err))
+	if err := install_service.UpdateEndpoint(endpoint, encrypted_settings); err != nil {
+		return entities.NewErrorResponse(-500, fmt.Sprintf("failed to update endpoint: %v", err))
 	}
 
 	return entities.NewSuccessResponse(nil)
@@ -102,7 +107,97 @@ func RemoveEndpoint(endpoint_id string, tenant_id string) *entities.Response {
 	return entities.NewSuccessResponse(nil)
 }
 
-func UpdateEndpoint(endpoint_id string, tenant_id string, settings map[string]any) *entities.Response {
-	// TODO
-	return nil
+func UpdateEndpoint(endpoint_id string, tenant_id string, user_id string, settings map[string]any) *entities.Response {
+	// get endpoint
+	endpoint, err := db.GetOne[models.Endpoint](
+		db.Equal("id", endpoint_id),
+		db.Equal("tenant_id", tenant_id),
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-404, fmt.Sprintf("failed to find endpoint: %v", err))
+	}
+
+	// get plugin installation
+	installation, err := db.GetOne[models.PluginInstallation](
+		db.Equal("plugin_id", endpoint.PluginID),
+		db.Equal("tenant_id", tenant_id),
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-404, fmt.Sprintf("failed to find plugin installation: %v", err))
+	}
+
+	// get plugin
+	plugin, err := db.GetOne[models.Plugin](
+		db.Equal("plugin_unique_identifier", installation.PluginUniqueIdentifier),
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-404, fmt.Sprintf("failed to find plugin: %v", err))
+	}
+
+	if plugin.Declaration.Endpoint == nil {
+		return entities.NewErrorResponse(-404, "plugin does not have an endpoint")
+	}
+
+	// decrypt original settings
+	manager := plugin_manager.Manager()
+	if manager == nil {
+		return entities.NewErrorResponse(-500, "failed to get plugin manager")
+	}
+
+	original_settings, err := manager.BackwardsInvocation().InvokeEncrypt(
+		&dify_invocation.InvokeEncryptRequest{
+			BaseInvokeDifyRequest: dify_invocation.BaseInvokeDifyRequest{
+				TenantId: tenant_id,
+				UserId:   user_id,
+				Type:     dify_invocation.INVOKE_TYPE_ENCRYPT,
+			},
+			InvokeEncryptSchema: dify_invocation.InvokeEncryptSchema{
+				Opt:       dify_invocation.ENCRYPT_OPT_DECRYPT,
+				Namespace: dify_invocation.ENCRYPT_NAMESPACE_ENDPOINT,
+				Identity:  installation.ID,
+				Data:      endpoint.GetSettings(),
+				Config:    plugin.Declaration.Endpoint.Settings,
+			},
+		},
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-500, fmt.Sprintf("failed to decrypt settings: %v", err))
+	}
+
+	masked_settings := encryption.MaskConfigCredentials(original_settings, plugin.Declaration.Endpoint.Settings)
+
+	// check if settings is changed, replace the value is the same as masked_settings
+	for setting_name, value := range settings {
+		if masked_settings[setting_name] != value {
+			settings[setting_name] = original_settings[setting_name]
+		}
+	}
+
+	// encrypt settings
+	encrypted_settings, err := manager.BackwardsInvocation().InvokeEncrypt(
+		&dify_invocation.InvokeEncryptRequest{
+			BaseInvokeDifyRequest: dify_invocation.BaseInvokeDifyRequest{
+				TenantId: tenant_id,
+				UserId:   user_id,
+				Type:     dify_invocation.INVOKE_TYPE_ENCRYPT,
+			},
+			InvokeEncryptSchema: dify_invocation.InvokeEncryptSchema{
+				Opt:       dify_invocation.ENCRYPT_OPT_ENCRYPT,
+				Namespace: dify_invocation.ENCRYPT_NAMESPACE_ENDPOINT,
+				Identity:  endpoint.ID,
+				Data:      settings,
+				Config:    plugin.Declaration.Endpoint.Settings,
+			},
+		},
+	)
+	if err != nil {
+		return entities.NewErrorResponse(-500, fmt.Sprintf("failed to encrypt settings: %v", err))
+	}
+
+	// update endpoint
+	if err := install_service.UpdateEndpoint(&endpoint, encrypted_settings); err != nil {
+		return entities.NewErrorResponse(-500, fmt.Sprintf("failed to update endpoint: %v", err))
+	}
+
+	return entities.NewSuccessResponse(nil)
 }

+ 40 - 0
internal/utils/encryption/mask.go

@@ -0,0 +1,40 @@
+package encryption
+
+import (
+	"strings"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+)
+
+func MaskConfigCredentials(
+	credentials map[string]any,
+	provider_config map[string]plugin_entities.ProviderConfig,
+) map[string]any {
+	/*
+		Mask credentials based on provider config
+	*/
+	copied_credentials := make(map[string]any)
+	for key, value := range credentials {
+		if config, ok := provider_config[key]; ok {
+			if config.Type == plugin_entities.CONFIG_TYPE_SECRET_INPUT {
+				if original_value, ok := value.(string); ok {
+					if len(original_value) > 6 {
+						copied_credentials[key] = original_value[:2] +
+							strings.Repeat("*", len(original_value)-4) +
+							original_value[len(original_value)-2:]
+					} else {
+						copied_credentials[key] = strings.Repeat("*", len(original_value))
+					}
+				} else {
+					copied_credentials[key] = value
+				}
+			} else {
+				copied_credentials[key] = value
+			}
+		} else {
+			copied_credentials[key] = value
+		}
+	}
+
+	return copied_credentials
+}