浏览代码

feat: service support persistence

Yeuoly 11 月之前
父节点
当前提交
e139f38c1f

+ 26 - 0
internal/core/dify_invocation/types.go

@@ -1,8 +1,10 @@
 package dify_invocation
 
 import (
+	"github.com/go-playground/validator/v10"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/app_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
 )
 
 type BaseInvokeDifyRequest struct {
@@ -23,6 +25,7 @@ const (
 	INVOKE_TYPE_TOOL           InvokeType = "tool"
 	INVOKE_TYPE_NODE           InvokeType = "node"
 	INVOKE_TYPE_APP            InvokeType = "app"
+	INVOKE_TYPE_STORAGE        InvokeType = "storage"
 )
 
 type InvokeLLMRequest struct {
@@ -71,6 +74,29 @@ type InvokeAppSchema struct {
 	Files          []*app_entities.FileVar `json:"files" validate:"omitempty,dive"`
 }
 
+type StorageOpt string
+
+const (
+	STORAGE_OPT_GET StorageOpt = "get"
+	STORAGE_OPT_SET StorageOpt = "set"
+	STORAGE_OPT_DEL StorageOpt = "del"
+)
+
+func isStorageOpt(fl validator.FieldLevel) bool {
+	opt := StorageOpt(fl.Field().String())
+	return opt == STORAGE_OPT_GET || opt == STORAGE_OPT_SET || opt == STORAGE_OPT_DEL
+}
+
+func init() {
+	validators.GlobalEntitiesValidator.RegisterValidation("storage_opt", isStorageOpt)
+}
+
+type InvokeStorageRequest struct {
+	Opt   StorageOpt `json:"opt" validate:"required,storage_opt"`
+	Key   string     `json:"key" validate:"required"`
+	Value string     `json:"value"` // encoded in hex, optional
+}
+
 type InvokeAppRequest struct {
 	BaseInvokeDifyRequest
 	requests.BaseRequestInvokeModel

+ 10 - 4
internal/core/persistence/init.go

@@ -5,7 +5,11 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 )
 
-func InitPersistence(config *app.Config) *Persistence {
+var (
+	persistence *Persistence
+)
+
+func InitPersistence(config *app.Config) {
 	if config.PersistenceStorageType == "s3" {
 		s3, err := NewS3Wrapper(
 			config.PersistenceStorageS3Region,
@@ -17,16 +21,18 @@ func InitPersistence(config *app.Config) *Persistence {
 			log.Panic("Failed to initialize S3 wrapper: %v", err)
 		}
 
-		return &Persistence{
+		persistence = &Persistence{
 			storage: s3,
 		}
 	} else if config.PersistenceStorageType == "local" {
-		return &Persistence{
+		persistence = &Persistence{
 			storage: NewLocalWrapper(config.PersistenceStorageLocalPath),
 		}
 	} else {
 		log.Panic("Invalid persistence storage type: %s", config.PersistenceStorageType)
 	}
+}
 
-	return nil
+func GetPersistence() *Persistence {
+	return persistence
 }

+ 11 - 11
internal/core/persistence/persistence.go

@@ -16,21 +16,21 @@ const (
 	CACHE_KEY_PREFIX = "persistence:cache"
 )
 
-func (c *Persistence) getCacheKey(tenant_id string, plugin_checksum string, key string) string {
-	return fmt.Sprintf("%s:%s:%s:%s", CACHE_KEY_PREFIX, tenant_id, plugin_checksum, key)
+func (c *Persistence) getCacheKey(tenant_id string, plugin_identity string, key string) string {
+	return fmt.Sprintf("%s:%s:%s:%s", CACHE_KEY_PREFIX, tenant_id, plugin_identity, key)
 }
 
-func (c *Persistence) Save(tenant_id string, plugin_checksum string, key string, data []byte) error {
+func (c *Persistence) Save(tenant_id string, plugin_identity string, key string, data []byte) error {
 	if len(key) > 64 {
 		return fmt.Errorf("key length must be less than 64 characters")
 	}
 
-	return c.storage.Save(tenant_id, plugin_checksum, key, data)
+	return c.storage.Save(tenant_id, plugin_identity, key, data)
 }
 
-func (c *Persistence) Load(tenant_id string, plugin_checksum string, key string) ([]byte, error) {
+func (c *Persistence) Load(tenant_id string, plugin_identity string, key string) ([]byte, error) {
 	// check if the key exists in cache
-	h, err := cache.GetString(c.getCacheKey(tenant_id, plugin_checksum, key))
+	h, err := cache.GetString(c.getCacheKey(tenant_id, plugin_identity, key))
 	if err != nil && err != cache.ErrNotFound {
 		return nil, err
 	}
@@ -39,22 +39,22 @@ func (c *Persistence) Load(tenant_id string, plugin_checksum string, key string)
 	}
 
 	// load from storage
-	data, err := c.storage.Load(tenant_id, plugin_checksum, key)
+	data, err := c.storage.Load(tenant_id, plugin_identity, key)
 	if err != nil {
 		return nil, err
 	}
 
 	// add to cache
-	cache.Store(c.getCacheKey(tenant_id, plugin_checksum, key), hex.EncodeToString(data), time.Minute*5)
+	cache.Store(c.getCacheKey(tenant_id, plugin_identity, key), hex.EncodeToString(data), time.Minute*5)
 
 	return data, nil
 }
 
-func (c *Persistence) Delete(tenant_id string, plugin_checksum string, key string) error {
+func (c *Persistence) Delete(tenant_id string, plugin_identity string, key string) error {
 	// delete from cache and storage
-	err := cache.Del(c.getCacheKey(tenant_id, plugin_checksum, key))
+	err := cache.Del(c.getCacheKey(tenant_id, plugin_identity, key))
 	if err != nil {
 		return err
 	}
-	return c.storage.Delete(tenant_id, plugin_checksum, key)
+	return c.storage.Delete(tenant_id, plugin_identity, key)
 }

+ 8 - 8
internal/core/persistence/persistence_test.go

@@ -17,18 +17,18 @@ func TestPersistenceStoreAndLoad(t *testing.T) {
 	}
 	defer cache.Close()
 
-	p := InitPersistence(&app.Config{
+	InitPersistence(&app.Config{
 		PersistenceStorageType:      "local",
 		PersistenceStorageLocalPath: "./persistence_storage",
 	})
 
 	key := strings.RandomString(10)
 
-	if err := p.Save("tenant_id", "plugin_checksum", key, []byte("data")); err != nil {
+	if err := persistence.Save("tenant_id", "plugin_checksum", key, []byte("data")); err != nil {
 		t.Fatalf("Failed to save data: %v", err)
 	}
 
-	data, err := p.Load("tenant_id", "plugin_checksum", key)
+	data, err := persistence.Load("tenant_id", "plugin_checksum", key)
 	if err != nil {
 		t.Fatalf("Failed to load data: %v", err)
 	}
@@ -65,14 +65,14 @@ func TestPersistenceSaveAndLoadWithLongKey(t *testing.T) {
 	}
 	defer cache.Close()
 
-	p := InitPersistence(&app.Config{
+	InitPersistence(&app.Config{
 		PersistenceStorageType:      "local",
 		PersistenceStorageLocalPath: "./persistence_storage",
 	})
 
 	key := strings.RandomString(65)
 
-	if err := p.Save("tenant_id", "plugin_checksum", key, []byte("data")); err == nil {
+	if err := persistence.Save("tenant_id", "plugin_checksum", key, []byte("data")); err == nil {
 		t.Fatalf("Expected error, got nil")
 	}
 }
@@ -84,18 +84,18 @@ func TestPersistenceDelete(t *testing.T) {
 	}
 	defer cache.Close()
 
-	p := InitPersistence(&app.Config{
+	InitPersistence(&app.Config{
 		PersistenceStorageType:      "local",
 		PersistenceStorageLocalPath: "./persistence_storage",
 	})
 
 	key := strings.RandomString(10)
 
-	if err := p.Save("tenant_id", "plugin_checksum", key, []byte("data")); err != nil {
+	if err := persistence.Save("tenant_id", "plugin_checksum", key, []byte("data")); err != nil {
 		t.Fatalf("Failed to save data: %v", err)
 	}
 
-	if err := p.Delete("tenant_id", "plugin_checksum", key); err != nil {
+	if err := persistence.Delete("tenant_id", "plugin_checksum", key); err != nil {
 		t.Fatalf("Failed to delete data: %v", err)
 	}
 

+ 72 - 8
internal/core/plugin_daemon/backwards_invocation/task.go

@@ -1,6 +1,7 @@
 package backwards_invocation
 
 import (
+	"encoding/hex"
 	"fmt"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
@@ -169,28 +170,31 @@ func prepareDifyInvocationArguments(
 var (
 	dispatchMapping = map[dify_invocation.InvokeType]func(handle *BackwardsInvocation){
 		dify_invocation.INVOKE_TYPE_TOOL: func(handle *BackwardsInvocation) {
-			genericDispatchTask[dify_invocation.InvokeToolRequest](handle, executeDifyInvocationToolTask)
+			genericDispatchTask(handle, executeDifyInvocationToolTask)
 		},
 		dify_invocation.INVOKE_TYPE_LLM: func(handle *BackwardsInvocation) {
-			genericDispatchTask[dify_invocation.InvokeLLMRequest](handle, executeDifyInvocationLLMTask)
+			genericDispatchTask(handle, executeDifyInvocationLLMTask)
 		},
 		dify_invocation.INVOKE_TYPE_TEXT_EMBEDDING: func(handle *BackwardsInvocation) {
-			genericDispatchTask[dify_invocation.InvokeTextEmbeddingRequest](handle, executeDifyInvocationTextEmbeddingTask)
+			genericDispatchTask(handle, executeDifyInvocationTextEmbeddingTask)
 		},
 		dify_invocation.INVOKE_TYPE_RERANK: func(handle *BackwardsInvocation) {
-			genericDispatchTask[dify_invocation.InvokeRerankRequest](handle, executeDifyInvocationRerankTask)
+			genericDispatchTask(handle, executeDifyInvocationRerankTask)
 		},
 		dify_invocation.INVOKE_TYPE_TTS: func(handle *BackwardsInvocation) {
-			genericDispatchTask[dify_invocation.InvokeTTSRequest](handle, executeDifyInvocationTTSTask)
+			genericDispatchTask(handle, executeDifyInvocationTTSTask)
 		},
 		dify_invocation.INVOKE_TYPE_SPEECH2TEXT: func(handle *BackwardsInvocation) {
-			genericDispatchTask[dify_invocation.InvokeSpeech2TextRequest](handle, executeDifyInvocationSpeech2TextTask)
+			genericDispatchTask(handle, executeDifyInvocationSpeech2TextTask)
 		},
 		dify_invocation.INVOKE_TYPE_MODERATION: func(handle *BackwardsInvocation) {
-			genericDispatchTask[dify_invocation.InvokeModerationRequest](handle, executeDifyInvocationModerationTask)
+			genericDispatchTask(handle, executeDifyInvocationModerationTask)
 		},
 		dify_invocation.INVOKE_TYPE_APP: func(handle *BackwardsInvocation) {
-			genericDispatchTask[dify_invocation.InvokeAppRequest](handle, executeDifyInvocationAppTask)
+			genericDispatchTask(handle, executeDifyInvocationAppTask)
+		},
+		dify_invocation.INVOKE_TYPE_STORAGE: func(handle *BackwardsInvocation) {
+			genericDispatchTask(handle, executeDifyInvocationStorageTask)
 		},
 	}
 )
@@ -356,3 +360,63 @@ func executeDifyInvocationAppTask(
 		handle.WriteResponse("stream", t)
 	})
 }
+
+func executeDifyInvocationStorageTask(
+	handle *BackwardsInvocation,
+	request *dify_invocation.InvokeStorageRequest,
+) {
+	if handle.session == nil {
+		handle.WriteError(fmt.Errorf("session not found"))
+		return
+	}
+
+	persistence := handle.session.Storage()
+	if persistence == nil {
+		handle.WriteError(fmt.Errorf("persistence not found"))
+		return
+	}
+
+	tenant_id, err := handle.TenantID()
+	if err != nil {
+		handle.WriteError(fmt.Errorf("get tenant id failed: %s", err.Error()))
+		return
+	}
+
+	plugin_id := handle.session.PluginIdentity
+
+	if request.Opt == dify_invocation.STORAGE_OPT_GET {
+		data, err := persistence.Load(tenant_id, plugin_id, request.Key)
+		if err != nil {
+			handle.WriteError(fmt.Errorf("load data failed: %s", err.Error()))
+			return
+		}
+
+		handle.WriteResponse("struct", map[string]any{
+			"data": hex.EncodeToString(data),
+		})
+	} else if request.Opt == dify_invocation.STORAGE_OPT_SET {
+		data, err := hex.DecodeString(request.Value)
+		if err != nil {
+			handle.WriteError(fmt.Errorf("decode data failed: %s", err.Error()))
+			return
+		}
+
+		if err := persistence.Save(tenant_id, plugin_id, request.Key, data); err != nil {
+			handle.WriteError(fmt.Errorf("save data failed: %s", err.Error()))
+			return
+		}
+
+		handle.WriteResponse("struct", map[string]any{
+			"data": "ok",
+		})
+	} else if request.Opt == dify_invocation.STORAGE_OPT_DEL {
+		if err := persistence.Delete(tenant_id, plugin_id, request.Key); err != nil {
+			handle.WriteError(fmt.Errorf("delete data failed: %s", err.Error()))
+			return
+		}
+
+		handle.WriteResponse("struct", map[string]any{
+			"data": "ok",
+		})
+	}
+}

+ 10 - 2
internal/core/session_manager/session.go

@@ -7,6 +7,7 @@ import (
 	"time"
 
 	"github.com/google/uuid"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/cache"
@@ -21,8 +22,9 @@ var (
 
 // session need to implement the backwards_invocation.BackwardsInvocationWriter interface
 type Session struct {
-	ID      string                                 `json:"id"`
-	runtime plugin_entities.PluginRuntimeInterface `json:"-"`
+	ID          string                                 `json:"id"`
+	runtime     plugin_entities.PluginRuntimeInterface `json:"-"`
+	persistence *persistence.Persistence               `json:"-"`
 
 	TenantID       string                             `json:"tenant_id"`
 	UserID         string                             `json:"user_id"`
@@ -45,6 +47,7 @@ func NewSession(
 	invoke_from access_types.PluginAccessType,
 	action access_types.PluginAccessAction,
 	declaration *plugin_entities.PluginDeclaration,
+	persistence *persistence.Persistence,
 ) *Session {
 	s := &Session{
 		ID:             uuid.New().String(),
@@ -55,6 +58,7 @@ func NewSession(
 		InvokeFrom:     invoke_from,
 		Action:         action,
 		Declaration:    declaration,
+		persistence:    persistence,
 	}
 
 	session_lock.Lock()
@@ -97,6 +101,10 @@ func (s *Session) Runtime() plugin_entities.PluginRuntimeInterface {
 	return s.runtime
 }
 
+func (s *Session) Storage() *persistence.Persistence {
+	return s.persistence
+}
+
 type PLUGIN_IN_STREAM_EVENT string
 
 const (

+ 0 - 4
internal/server/app.go

@@ -2,7 +2,6 @@ package server
 
 import (
 	"github.com/langgenius/dify-plugin-daemon/internal/cluster"
-	"github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction"
 )
 
@@ -18,7 +17,4 @@ type App struct {
 	// aws transaction handler
 	// accept aws transaction request and forward to the plugin daemon
 	aws_transaction_handler *transaction.AWSTransactionHandler
-
-	// persistence
-	persistence *persistence.Persistence
 }

+ 1 - 1
internal/server/server.go

@@ -26,7 +26,7 @@ func (a *App) Run(config *app.Config) {
 	plugin_manager.InitGlobalPluginManager(a.cluster, config)
 
 	// init persistence
-	a.persistence = persistence.InitPersistence(config)
+	persistence.InitPersistence(config)
 
 	// launch cluster
 	a.cluster.Launch()

+ 13 - 1
internal/service/endpoint.go

@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
@@ -17,7 +18,11 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
 
-func Endpoint(ctx *gin.Context, endpoint *models.Endpoint, path string) {
+func Endpoint(
+	ctx *gin.Context,
+	endpoint *models.Endpoint,
+	path string,
+) {
 	req := ctx.Request.Clone(context.Background())
 	req.URL.Path = path
 
@@ -36,6 +41,12 @@ func Endpoint(ctx *gin.Context, endpoint *models.Endpoint, path string) {
 		return
 	}
 
+	persistence := persistence.GetPersistence()
+	if persistence == nil {
+		ctx.JSON(500, gin.H{"error": "persistence not found"})
+		return
+	}
+
 	session := session_manager.NewSession(
 		endpoint.TenantID,
 		"",
@@ -44,6 +55,7 @@ func Endpoint(ctx *gin.Context, endpoint *models.Endpoint, path string) {
 		access_types.PLUGIN_ACCESS_TYPE_Endpoint,
 		access_types.PLUGIN_ACCESS_ACTION_INVOKE_ENDPOINT,
 		runtime.Configuration(),
+		persistence,
 	)
 	defer session.Close()
 

+ 10 - 2
internal/service/invoke_model.go

@@ -16,12 +16,16 @@ func InvokeTool(
 	max_timeout_seconds int,
 ) {
 	// create session
-	session := createSession(
+	session, err := createSession(
 		r,
 		access_types.PLUGIN_ACCESS_TYPE_TOOL,
 		access_types.PLUGIN_ACCESS_ACTION_INVOKE_TOOL,
 		ctx.GetString("cluster_id"),
 	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
 	defer session.Close()
 
 	baseSSEService(
@@ -39,12 +43,16 @@ func ValidateToolCredentials(
 	max_timeout_seconds int,
 ) {
 	// create session
-	session := createSession(
+	session, err := createSession(
 		r,
 		access_types.PLUGIN_ACCESS_TYPE_TOOL,
 		access_types.PLUGIN_ACCESS_ACTION_VALIDATE_TOOL_CREDENTIALS,
 		ctx.GetString("cluster_id"),
 	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
 	defer session.Close()
 
 	baseSSEService(

+ 51 - 10
internal/service/invoke_tool.go

@@ -1,7 +1,10 @@
 package service
 
 import (
+	"errors"
+
 	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/persistence"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
@@ -18,7 +21,12 @@ func createSession[T any](
 	access_type access_types.PluginAccessType,
 	access_action access_types.PluginAccessAction,
 	cluster_id string,
-) *session_manager.Session {
+) (*session_manager.Session, error) {
+	persistence := persistence.GetPersistence()
+	if persistence == nil {
+		return nil, errors.New("persistence not found")
+	}
+
 	plugin_identity := parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion)
 	runtime := plugin_manager.GetGlobalPluginManager().Get(plugin_identity)
 
@@ -30,10 +38,11 @@ func createSession[T any](
 		access_type,
 		access_action,
 		runtime.Configuration(),
+		persistence,
 	)
 
 	session.BindRuntime(runtime)
-	return session
+	return session, nil
 }
 
 func InvokeLLM(
@@ -42,12 +51,16 @@ func InvokeLLM(
 	max_timeout_seconds int,
 ) {
 	// create session
-	session := createSession(
+	session, err := createSession(
 		r,
 		access_types.PLUGIN_ACCESS_TYPE_MODEL,
 		access_types.PLUGIN_ACCESS_ACTION_INVOKE_LLM,
 		ctx.GetString("cluster_id"),
 	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
 	defer session.Close()
 
 	baseSSEService(
@@ -65,11 +78,15 @@ func InvokeTextEmbedding(
 	max_timeout_seconds int,
 ) {
 	// create session
-	session := createSession(
+	session, err := createSession(
 		r,
 		access_types.PLUGIN_ACCESS_TYPE_MODEL,
 		access_types.PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
 		ctx.GetString("cluster_id"))
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
 	defer session.Close()
 
 	baseSSEService(
@@ -87,12 +104,16 @@ func InvokeRerank(
 	max_timeout_seconds int,
 ) {
 	// create session
-	session := createSession(
+	session, err := createSession(
 		r,
 		access_types.PLUGIN_ACCESS_TYPE_MODEL,
 		access_types.PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
 		ctx.GetString("cluster_id"),
 	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
 	defer session.Close()
 
 	baseSSEService(
@@ -110,12 +131,16 @@ func InvokeTTS(
 	max_timeout_seconds int,
 ) {
 	// create session
-	session := createSession(
+	session, err := createSession(
 		r,
 		access_types.PLUGIN_ACCESS_TYPE_MODEL,
 		access_types.PLUGIN_ACCESS_ACTION_INVOKE_TTS,
 		ctx.GetString("cluster_id"),
 	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
 	defer session.Close()
 
 	baseSSEService(
@@ -133,12 +158,16 @@ func InvokeSpeech2Text(
 	max_timeout_seconds int,
 ) {
 	// create session
-	session := createSession(
+	session, err := createSession(
 		r,
 		access_types.PLUGIN_ACCESS_TYPE_MODEL,
 		access_types.PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
 		ctx.GetString("cluster_id"),
 	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
 	defer session.Close()
 
 	baseSSEService(
@@ -156,12 +185,16 @@ func InvokeModeration(
 	max_timeout_seconds int,
 ) {
 	// create session
-	session := createSession(
+	session, err := createSession(
 		r,
 		access_types.PLUGIN_ACCESS_TYPE_MODEL,
 		access_types.PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
 		ctx.GetString("cluster_id"),
 	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
 	defer session.Close()
 
 	baseSSEService(
@@ -179,12 +212,16 @@ func ValidateProviderCredentials(
 	max_timeout_seconds int,
 ) {
 	// create session
-	session := createSession(
+	session, err := createSession(
 		r,
 		access_types.PLUGIN_ACCESS_TYPE_MODEL,
 		access_types.PLUGIN_ACCESS_ACTION_VALIDATE_PROVIDER_CREDENTIALS,
 		ctx.GetString("cluster_id"),
 	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
 	defer session.Close()
 
 	baseSSEService(
@@ -202,12 +239,16 @@ func ValidateModelCredentials(
 	max_timeout_seconds int,
 ) {
 	// create session
-	session := createSession(
+	session, err := createSession(
 		r,
 		access_types.PLUGIN_ACCESS_TYPE_MODEL,
 		access_types.PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS,
 		ctx.GetString("cluster_id"),
 	)
+	if err != nil {
+		ctx.JSON(500, gin.H{"error": err.Error()})
+		return
+	}
 	defer session.Close()
 
 	baseSSEService(