浏览代码

refactor: session

Yeuoly 1 年之前
父节点
当前提交
137fb37118

+ 5 - 4
internal/core/plugin_daemon/daemon.go

@@ -4,6 +4,7 @@ import (
 	"errors"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
@@ -12,17 +13,17 @@ import (
 
 type ToolResponseChunk = plugin_entities.InvokeToolResponseChunk
 
-func InvokeTool(session *entities.InvocationSession, provider_name string, tool_name string, tool_parameters map[string]any) (
+func InvokeTool(session *session_manager.Session, provider_name string, tool_name string, tool_parameters map[string]any) (
 	*entities.InvocationResponse[ToolResponseChunk], error,
 ) {
-	runtime := plugin_manager.Get(session.PluginIdentity)
+	runtime := plugin_manager.Get(session.PluginIdentity())
 	if runtime == nil {
 		return nil, errors.New("plugin not found")
 	}
 
 	response := entities.NewInvocationResponse[ToolResponseChunk](512)
 
-	listener := runtime.Listen(session.ID)
+	listener := runtime.Listen(session.ID())
 	listener.AddListener(func(message []byte) {
 		chunk, err := parser.UnmarshalJsonBytes[plugin_entities.StreamMessage](message)
 		if err != nil {
@@ -50,7 +51,7 @@ func InvokeTool(session *entities.InvocationSession, provider_name string, tool_
 		listener.Close()
 	})
 
-	runtime.Write(session.ID, []byte(parser.MarshalJson(
+	runtime.Write(session.ID(), []byte(parser.MarshalJson(
 		map[string]any{
 			"provider":   provider_name,
 			"tool":       tool_name,

+ 2 - 0
internal/core/plugin_manager/stdio_holder/io.go

@@ -78,6 +78,8 @@ func (s *stdioHolder) StartStdout() {
 				}
 			case plugin_entities.PLUGIN_EVENT_ERROR:
 				log.Error("plugin %s: %s", s.pluginIdentity, event.Data)
+			case plugin_entities.PLUGIN_EVENT_INVOKE:
+				// invoke dify
 			}
 		}
 	}

+ 73 - 0
internal/core/session_manager/session.go

@@ -0,0 +1,73 @@
+package session_manager
+
+import (
+	"sync"
+
+	"github.com/google/uuid"
+)
+
+var (
+	sessions     map[string]*Session = map[string]*Session{}
+	session_lock sync.RWMutex
+)
+
+type Session struct {
+	id string
+
+	tenant_id       string
+	user_id         string
+	plugin_identity string
+}
+
+func NewSession(tenant_id string, user_id string, plugin_identity string) *Session {
+	s := &Session{
+		id:              uuid.New().String(),
+		tenant_id:       tenant_id,
+		user_id:         user_id,
+		plugin_identity: plugin_identity,
+	}
+
+	session_lock.Lock()
+	defer session_lock.Unlock()
+
+	sessions[s.id] = s
+
+	return s
+}
+
+func GetSession(id string) *Session {
+	session_lock.RLock()
+	defer session_lock.RUnlock()
+
+	return sessions[id]
+}
+
+func DeleteSession(id string) {
+	session_lock.Lock()
+	defer session_lock.Unlock()
+
+	delete(sessions, id)
+}
+
+func (s *Session) Close() {
+	session_lock.Lock()
+	defer session_lock.Unlock()
+
+	delete(sessions, s.id)
+}
+
+func (s *Session) ID() string {
+	return s.id
+}
+
+func (s *Session) TenantID() string {
+	return s.tenant_id
+}
+
+func (s *Session) UserID() string {
+	return s.user_id
+}
+
+func (s *Session) PluginIdentity() string {
+	return s.plugin_identity
+}

+ 3 - 6
internal/service/invoke.go

@@ -2,8 +2,8 @@ package service
 
 import (
 	"github.com/gin-gonic/gin"
-	"github.com/google/uuid"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
@@ -12,11 +12,8 @@ import (
 
 func InvokeTool(r *plugin_entities.InvokePluginRequest[plugin_entities.InvokeToolRequest], ctx *gin.Context) {
 	// create session
-	session_id := uuid.New().String()
-	session := &entities.InvocationSession{
-		ID:             session_id,
-		PluginIdentity: parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion),
-	}
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	defer session.Close()
 
 	writer := ctx.Writer
 	writer.WriteHeader(200)

+ 1 - 0
internal/types/entities/plugin_entities/request.go

@@ -11,6 +11,7 @@ type InvokePluginRequest[T InvokePluginRequestData] struct {
 	PluginName    string `json:"plugin_name" binding:"required"`
 	PluginVersion string `json:"plugin_version" binding:"required"`
 	TenantId      string `json:"tenant_id" binding:"required"`
+	UserId        string `json:"user_id" binding:"required"`
 	Data          T      `json:"data" binding:"required"`
 }