Przeglądaj źródła

refactor: using finer granularity to cache PluginDeclaration

Yeuoly 8 miesięcy temu
rodzic
commit
eb1456b0d0

+ 5 - 1
internal/core/plugin_manager/manager.go

@@ -227,8 +227,12 @@ func (p *PluginManager) GetPackage(
 
 func (p *PluginManager) GetDeclaration(
 	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+	tenant_id string,
+	runtime_type plugin_entities.PluginRuntimeType,
 ) (
 	*plugin_entities.PluginDeclaration, error,
 ) {
-	return helper.CombinedGetPluginDeclaration(plugin_unique_identifier)
+	return helper.CombinedGetPluginDeclaration(
+		plugin_unique_identifier, tenant_id, runtime_type,
+	)
 }

+ 10 - 3
internal/service/endpoint.go

@@ -20,7 +20,6 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
-	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache/helper"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/encryption"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
@@ -234,7 +233,11 @@ func ListEndpoints(tenant_id string, page int, page_size int) *entities.Response
 			return entities.NewErrorResponse(-500, fmt.Sprintf("failed to parse plugin unique identifier: %v", err))
 		}
 
-		pluginDeclaration, err := helper.CombinedGetPluginDeclaration(pluginUniqueIdentifier)
+		pluginDeclaration, err := manager.GetDeclaration(
+			pluginUniqueIdentifier,
+			tenant_id,
+			plugin_entities.PluginRuntimeType(pluginInstallation.RuntimeType),
+		)
 		if err != nil {
 			return entities.NewErrorResponse(-500, fmt.Sprintf("failed to get plugin declaration: %v", err))
 		}
@@ -308,7 +311,11 @@ func ListPluginEndpoints(tenant_id string, plugin_id string, page int, page_size
 			return entities.NewErrorResponse(-500, fmt.Sprintf("failed to parse plugin unique identifier: %v", err))
 		}
 
-		pluginDeclaration, err := helper.CombinedGetPluginDeclaration(pluginUniqueIdentifier)
+		pluginDeclaration, err := manager.GetDeclaration(
+			pluginUniqueIdentifier,
+			tenant_id,
+			plugin_entities.PluginRuntimeType(pluginInstallation.RuntimeType),
+		)
 		if err != nil {
 			return entities.NewErrorResponse(-500, fmt.Sprintf("failed to get plugin declaration: %v", err))
 		}

+ 19 - 3
internal/service/install_plugin.go

@@ -37,9 +37,17 @@ func InstallPluginRuntimeToTenant(
 	onDone InstallPluginOnDoneHandler, // since installing plugin is a async task, we need to call it asynchronously
 ) (*InstallPluginResponse, error) {
 	response := &InstallPluginResponse{}
-
 	pluginsWaitForInstallation := []plugin_entities.PluginUniqueIdentifier{}
 
+	runtimeType := plugin_entities.PluginRuntimeType("")
+	if config.Platform == app.PLATFORM_AWS_LAMBDA {
+		runtimeType = plugin_entities.PLUGIN_RUNTIME_TYPE_AWS
+	} else if config.Platform == app.PLATFORM_LOCAL {
+		runtimeType = plugin_entities.PLUGIN_RUNTIME_TYPE_LOCAL
+	} else {
+		return nil, fmt.Errorf("unsupported platform: %s", config.Platform)
+	}
+
 	task := &models.InstallTask{
 		Status:           models.InstallTaskStatusRunning,
 		TenantID:         tenant_id,
@@ -50,7 +58,11 @@ func InstallPluginRuntimeToTenant(
 
 	for i, pluginUniqueIdentifier := range plugin_unique_identifiers {
 		// fetch plugin declaration first, before installing, we need to ensure pkg is uploaded
-		pluginDeclaration, err := helper.CombinedGetPluginDeclaration(pluginUniqueIdentifier)
+		pluginDeclaration, err := helper.CombinedGetPluginDeclaration(
+			pluginUniqueIdentifier,
+			tenant_id,
+			runtimeType,
+		)
 		if err != nil {
 			return nil, err
 		}
@@ -103,7 +115,11 @@ func InstallPluginRuntimeToTenant(
 		// copy the variable to avoid race condition
 		pluginUniqueIdentifier := pluginUniqueIdentifier
 
-		declaration, err := manager.GetDeclaration(pluginUniqueIdentifier)
+		declaration, err := helper.CombinedGetPluginDeclaration(
+			pluginUniqueIdentifier,
+			tenant_id,
+			runtimeType,
+		)
 		if err != nil {
 			return nil, err
 		}

+ 5 - 1
internal/service/manage_plugin.go

@@ -49,7 +49,11 @@ func ListPlugins(tenant_id string, page int, page_size int) *entities.Response {
 			return entities.NewErrorResponse(-500, err.Error())
 		}
 
-		pluginDeclaration, err := helper.CombinedGetPluginDeclaration(pluginUniqueIdentifier)
+		pluginDeclaration, err := helper.CombinedGetPluginDeclaration(
+			pluginUniqueIdentifier,
+			tenant_id,
+			plugin_entities.PluginRuntimeType(plugin_installation.RuntimeType),
+		)
 		if err != nil {
 			return entities.NewErrorResponse(-500, err.Error())
 		}

+ 10 - 2
internal/service/setup_endpoint.go

@@ -31,7 +31,11 @@ func SetupEndpoint(
 	}
 
 	// try get plugin
-	pluginDeclaration, err := helper.CombinedGetPluginDeclaration(pluginUniqueIdentifier)
+	pluginDeclaration, err := helper.CombinedGetPluginDeclaration(
+		pluginUniqueIdentifier,
+		tenant_id,
+		plugin_entities.PluginRuntimeType(installation.RuntimeType),
+	)
 	if err != nil {
 		return entities.NewErrorResponse(-404, fmt.Sprintf("failed to find plugin: %v", err))
 	}
@@ -160,7 +164,11 @@ func UpdateEndpoint(endpoint_id string, tenant_id string, user_id string, name s
 	}
 
 	// get plugin
-	pluginDeclaration, err := helper.CombinedGetPluginDeclaration(pluginUniqueIdentifier)
+	pluginDeclaration, err := helper.CombinedGetPluginDeclaration(
+		pluginUniqueIdentifier,
+		tenant_id,
+		plugin_entities.PluginRuntimeType(installation.RuntimeType),
+	)
 	if err != nil {
 		return entities.NewErrorResponse(-404, fmt.Sprintf("failed to find plugin: %v", err))
 	}

+ 40 - 11
internal/utils/cache/helper/redis.go

@@ -1,6 +1,8 @@
 package helper
 
 import (
+	"strings"
+
 	"github.com/langgenius/dify-plugin-daemon/internal/db"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/models"
@@ -9,22 +11,49 @@ import (
 
 func CombinedGetPluginDeclaration(
 	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+	tenant_id string,
+	runtime_type plugin_entities.PluginRuntimeType,
 ) (*plugin_entities.PluginDeclaration, error) {
 	return cache.AutoGetWithGetter(
-		plugin_unique_identifier.String(),
+		strings.Join(
+			[]string{
+				string(runtime_type),
+				tenant_id,
+				plugin_unique_identifier.String(),
+			},
+			":",
+		),
 		func() (*plugin_entities.PluginDeclaration, error) {
-			declaration, err := db.GetOne[models.PluginDeclaration](
-				db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
-			)
-			if err != nil && err != db.ErrDatabaseNotFound {
-				return nil, err
-			}
+			if runtime_type != plugin_entities.PLUGIN_RUNTIME_TYPE_REMOTE {
+				declaration, err := db.GetOne[models.PluginDeclaration](
+					db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
+				)
+				if err != nil && err != db.ErrDatabaseNotFound {
+					return nil, err
+				}
 
-			if err != nil {
-				return nil, err
-			}
+				if err != nil {
+					return nil, err
+				}
 
-			return &declaration.Declaration, nil
+				return &declaration.Declaration, nil
+			} else {
+				// try to fetch the declaration from plugin if it's remote
+				plugin, err := db.GetOne[models.Plugin](
+					db.Equal("unique_identifier", plugin_unique_identifier.String()),
+					db.Equal("install_type", string(plugin_entities.PLUGIN_RUNTIME_TYPE_REMOTE)),
+					db.Equal("tenant_id", tenant_id),
+				)
+				if err != nil && err != db.ErrDatabaseNotFound {
+					return nil, err
+				}
+
+				if err != nil {
+					return nil, err
+				}
+
+				return &plugin.Declaration, nil
+			}
 		},
 	)
 }