Browse Source

feat: add incoming session information into requests

Yeuoly 9 months ago
parent
commit
1a78a5a903

+ 11 - 3
internal/core/plugin_daemon/basic.go

@@ -6,10 +6,18 @@ func getBasicPluginAccessMap(
 	user_id string,
 	access_type access_types.PluginAccessType,
 	action access_types.PluginAccessAction,
+	conversation_id *string,
+	message_id *string,
+	app_id *string,
+	endpoint_id *string,
 ) map[string]any {
 	return map[string]any{
-		"user_id": user_id,
-		"type":    access_type,
-		"action":  action,
+		"user_id":         user_id,
+		"type":            access_type,
+		"action":          action,
+		"conversation_id": conversation_id,
+		"message_id":      message_id,
+		"app_id":          app_id,
+		"endpoint_id":     endpoint_id,
 	}
 }

+ 9 - 1
internal/core/plugin_daemon/generic.go

@@ -102,7 +102,15 @@ func getInvokePluginMap(
 	session *session_manager.Session,
 	request any,
 ) map[string]any {
-	req := getBasicPluginAccessMap(session.UserID, session.InvokeFrom, session.Action)
+	req := getBasicPluginAccessMap(
+		session.UserID,
+		session.InvokeFrom,
+		session.Action,
+		session.ConversationID,
+		session.MessageID,
+		session.AppID,
+		session.EndpointID,
+	)
 	for k, v := range parser.StructToMap(request) {
 		req[k] = v
 	}

+ 14 - 0
internal/core/session_manager/session.go

@@ -33,6 +33,12 @@ type Session struct {
 	InvokeFrom             access_types.PluginAccessType          `json:"invoke_from"`
 	Action                 access_types.PluginAccessAction        `json:"action"`
 	Declaration            *plugin_entities.PluginDeclaration     `json:"declaration"`
+
+	// information about incoming request
+	ConversationID *string `json:"conversation_id"`
+	MessageID      *string `json:"message_id"`
+	AppID          *string `json:"app_id"`
+	EndpointID     *string `json:"endpoint_id"`
 }
 
 func sessionKey(id string) string {
@@ -49,6 +55,10 @@ type NewSessionPayload struct {
 	Declaration            *plugin_entities.PluginDeclaration     `json:"declaration"`
 	BackwardsInvocation    dify_invocation.BackwardsInvocation    `json:"backwards_invocation"`
 	IgnoreCache            bool                                   `json:"ignore_cache"`
+	ConversationID         *string                                `json:"conversation_id"`
+	MessageID              *string                                `json:"message_id"`
+	AppID                  *string                                `json:"app_id"`
+	EndpointID             *string                                `json:"endpoint_id"`
 }
 
 func NewSession(payload NewSessionPayload) *Session {
@@ -62,6 +72,10 @@ func NewSession(payload NewSessionPayload) *Session {
 		Action:                 payload.Action,
 		Declaration:            payload.Declaration,
 		backwardsInvocation:    payload.BackwardsInvocation,
+		ConversationID:         payload.ConversationID,
+		MessageID:              payload.MessageID,
+		AppID:                  payload.AppID,
+		EndpointID:             payload.EndpointID,
 	}
 
 	session_lock.Lock()

+ 1 - 0
internal/service/endpoint.go

@@ -99,6 +99,7 @@ func Endpoint(
 			Declaration:            runtime.Configuration(),
 			BackwardsInvocation:    manager.BackwardsInvocation(),
 			IgnoreCache:            false,
+			EndpointID:             &endpoint.ID,
 		},
 	)
 	defer session.Close(session_manager.CloseSessionPayload{

+ 0 - 39
internal/service/invoke_tool.go

@@ -1,12 +1,9 @@
 package service
 
 import (
-	"errors"
-
 	"github.com/gin-gonic/gin"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
-	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
@@ -14,42 +11,6 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
-func createSession[T any](
-	r *plugin_entities.InvokePluginRequest[T],
-	access_type access_types.PluginAccessType,
-	access_action access_types.PluginAccessAction,
-	cluster_id string,
-) (*session_manager.Session, error) {
-	manager := plugin_manager.Manager()
-	if manager == nil {
-		return nil, errors.New("failed to get plugin manager")
-	}
-
-	// try fetch plugin identifier from plugin id
-
-	runtime := manager.Get(r.UniqueIdentifier)
-	if runtime == nil {
-		return nil, errors.New("failed to get plugin runtime")
-	}
-
-	session := session_manager.NewSession(
-		session_manager.NewSessionPayload{
-			TenantID:               r.TenantId,
-			UserID:                 r.UserId,
-			PluginUniqueIdentifier: r.UniqueIdentifier,
-			ClusterID:              cluster_id,
-			InvokeFrom:             access_type,
-			Action:                 access_action,
-			Declaration:            runtime.Configuration(),
-			BackwardsInvocation:    manager.BackwardsInvocation(),
-			IgnoreCache:            false,
-		},
-	)
-
-	session.BindRuntime(runtime)
-	return session, nil
-}
-
 func InvokeTool(
 	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTool],
 	ctx *gin.Context,

+ 50 - 0
internal/service/session.go

@@ -0,0 +1,50 @@
+package service
+
+import (
+	"errors"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+)
+
+func createSession[T any](
+	r *plugin_entities.InvokePluginRequest[T],
+	access_type access_types.PluginAccessType,
+	access_action access_types.PluginAccessAction,
+	cluster_id string,
+) (*session_manager.Session, error) {
+	manager := plugin_manager.Manager()
+	if manager == nil {
+		return nil, errors.New("failed to get plugin manager")
+	}
+
+	// try fetch plugin identifier from plugin id
+
+	runtime := manager.Get(r.UniqueIdentifier)
+	if runtime == nil {
+		return nil, errors.New("failed to get plugin runtime")
+	}
+
+	session := session_manager.NewSession(
+		session_manager.NewSessionPayload{
+			TenantID:               r.TenantId,
+			UserID:                 r.UserId,
+			PluginUniqueIdentifier: r.UniqueIdentifier,
+			ClusterID:              cluster_id,
+			InvokeFrom:             access_type,
+			Action:                 access_action,
+			Declaration:            runtime.Configuration(),
+			BackwardsInvocation:    manager.BackwardsInvocation(),
+			IgnoreCache:            false,
+			ConversationID:         r.ConversationID,
+			MessageID:              r.MessageID,
+			AppID:                  r.AppID,
+			EndpointID:             r.EndpointID,
+		},
+	)
+
+	session.BindRuntime(runtime)
+	return session, nil
+}

+ 6 - 1
internal/types/entities/plugin_entities/request.go

@@ -14,5 +14,10 @@ type InvokePluginRequest[T any] struct {
 	BasePluginIdentifier
 
 	UniqueIdentifier PluginUniqueIdentifier `json:"unique_identifier"`
-	Data             T                      `json:"data" validate:"required"`
+	ConversationID   *string                `json:"conversation_id"`
+	MessageID        *string                `json:"message_id"`
+	AppID            *string                `json:"app_id"`
+	EndpointID       *string                `json:"endpoint_id"`
+
+	Data T `json:"data" validate:"required"`
 }