Browse Source

refactor: implment session writer

Yeuoly 11 months ago
parent
commit
54d2074118

+ 12 - 2
internal/core/plugin_daemon/backwards_invocation/request.go

@@ -10,16 +10,26 @@ import (
 type BackwardsInvocationType = dify_invocation.InvokeType
 type BackwardsInvocationType = dify_invocation.InvokeType
 
 
 type BackwardsInvocationWriter interface {
 type BackwardsInvocationWriter interface {
-	Write(event session_manager.PLUGIN_IN_STREAM_EVENT, data any)
+	Write(event session_manager.PLUGIN_IN_STREAM_EVENT, data any) error
 	Done()
 	Done()
 }
 }
 
 
+// BackwardsInvocation is a struct that represents a backwards invocation
+// For different plugin runtime type, stream handler is different
+//  1. Local and Remote: they are both full duplex, multiplexing could be implemented by different session
+//     different session share the same physical channel.
+//  2. AWS: it is half duplex, one request could have multiple channels, we need to combine them into one stream
+//
+// That's why it has a writer, for different transaction, the writer is unique
 type BackwardsInvocation struct {
 type BackwardsInvocation struct {
 	typ              BackwardsInvocationType
 	typ              BackwardsInvocationType
 	id               string
 	id               string
 	detailed_request map[string]any
 	detailed_request map[string]any
 	session          *session_manager.Session
 	session          *session_manager.Session
-	writer           BackwardsInvocationWriter
+
+	// writer is the writer that writes the data to the session
+	// NOTE: write operation will not raise errors
+	writer BackwardsInvocationWriter
 }
 }
 
 
 func NewBackwardsInvocation(
 func NewBackwardsInvocation(

+ 13 - 6
internal/core/plugin_daemon/backwards_invocation/transaction/aws_event_writer.go

@@ -1,6 +1,8 @@
 package transaction
 package transaction
 
 
 import (
 import (
+	"io"
+
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 )
 )
@@ -8,22 +10,27 @@ import (
 // AWSTransactionWriter is a writer that implements the backwards_invocation.BackwardsInvocationWriter interface
 // AWSTransactionWriter is a writer that implements the backwards_invocation.BackwardsInvocationWriter interface
 // it is used to write data to the plugin runtime
 // it is used to write data to the plugin runtime
 type AWSTransactionWriter struct {
 type AWSTransactionWriter struct {
-	event_id string
+	session     *session_manager.Session
+	writeCloser io.WriteCloser
 
 
 	backwards_invocation.BackwardsInvocationWriter
 	backwards_invocation.BackwardsInvocationWriter
 }
 }
 
 
 // NewAWSTransactionWriter creates a new transaction writer
 // NewAWSTransactionWriter creates a new transaction writer
-func NewAWSTransactionWriter(event_id string) *AWSTransactionWriter {
+func NewAWSTransactionWriter(session *session_manager.Session, writeCloser io.WriteCloser) *AWSTransactionWriter {
 	return &AWSTransactionWriter{
 	return &AWSTransactionWriter{
-		event_id: event_id,
+		session:     session,
+		writeCloser: writeCloser,
 	}
 	}
 }
 }
 
 
-func (w *AWSTransactionWriter) Write(event session_manager.PLUGIN_IN_STREAM_EVENT, data any) {
-
+// Write writes the event and data to the session
+// WARNING: write
+func (w *AWSTransactionWriter) Write(event session_manager.PLUGIN_IN_STREAM_EVENT, data any) error {
+	_, err := w.writeCloser.Write(w.session.Message(event, data))
+	return err
 }
 }
 
 
 func (w *AWSTransactionWriter) Done() {
 func (w *AWSTransactionWriter) Done() {
-
+	w.writeCloser.Close()
 }
 }

+ 2 - 2
internal/core/plugin_daemon/backwards_invocation/transaction/full_duplex_event_writer.go

@@ -19,8 +19,8 @@ func NewFullDuplexEventWriter(session *session_manager.Session) *FullDuplexTrans
 	}
 	}
 }
 }
 
 
-func (w *FullDuplexTransactionWriter) Write(event session_manager.PLUGIN_IN_STREAM_EVENT, data any) {
-	w.session.Write(event, data)
+func (w *FullDuplexTransactionWriter) Write(event session_manager.PLUGIN_IN_STREAM_EVENT, data any) error {
+	return w.session.Write(event, data)
 }
 }
 
 
 func (w *FullDuplexTransactionWriter) Done() {
 func (w *FullDuplexTransactionWriter) Done() {

+ 2 - 8
internal/core/plugin_daemon/generic.go

@@ -29,13 +29,7 @@ func genericInvokePlugin[Req any, Rsp any](
 	response := stream.NewStreamResponse[Rsp](response_buffer_size)
 	response := stream.NewStreamResponse[Rsp](response_buffer_size)
 
 
 	listener := runtime.Listen(session.ID())
 	listener := runtime.Listen(session.ID())
-	listener.Listen(func(message []byte) {
-		chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
-		if err != nil {
-			log.Error("unmarshal json failed: %s", err.Error())
-			return
-		}
-
+	listener.Listen(func(chunk plugin_entities.SessionMessage) {
 		switch chunk.Type {
 		switch chunk.Type {
 		case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
 		case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
 			chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
 			chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
@@ -49,7 +43,7 @@ func genericInvokePlugin[Req any, Rsp any](
 			// check if the request contains a aws_event_id
 			// check if the request contains a aws_event_id
 			var writer backwards_invocation.BackwardsInvocationWriter
 			var writer backwards_invocation.BackwardsInvocationWriter
 			if chunk.RuntimeType == plugin_entities.PLUGIN_RUNTIME_TYPE_AWS {
 			if chunk.RuntimeType == plugin_entities.PLUGIN_RUNTIME_TYPE_AWS {
-				writer = transaction.NewAWSTransactionWriter(chunk.ServerlessEventId)
+				writer = transaction.NewAWSTransactionWriter(session, chunk.SessionWriter)
 			} else {
 			} else {
 				writer = transaction.NewFullDuplexEventWriter(session)
 				writer = transaction.NewFullDuplexEventWriter(session)
 			}
 			}

+ 3 - 2
internal/core/plugin_manager/aws_manager/io.go

@@ -9,6 +9,7 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
 )
 
 
@@ -22,8 +23,8 @@ func (r *AWSPluginRuntime) consume() {
 	}
 	}
 }
 }
 
 
-func (r *AWSPluginRuntime) Listen(session_id string) *entities.BytesIOListener {
-	l := entities.NewIOListener[[]byte]()
+func (r *AWSPluginRuntime) Listen(session_id string) *entities.Broadcast[plugin_entities.SessionMessage] {
+	l := entities.NewBroadcast[plugin_entities.SessionMessage]()
 	l.OnClose(func() {
 	l.OnClose(func() {
 		// close the pipe writer
 		// close the pipe writer
 		writer, exists := r.session_pool.Load(session_id)
 		writer, exists := r.session_pool.Load(session_id)

+ 1 - 1
internal/core/plugin_manager/aws_manager/packager_test.go

@@ -47,7 +47,7 @@ func (r *TPluginRuntime) Wait() (<-chan bool, error) {
 	return nil, nil
 	return nil, nil
 }
 }
 
 
-func (r *TPluginRuntime) Listen(string) *entities.Broadcast[[]byte] {
+func (r *TPluginRuntime) Listen(string) *entities.Broadcast[plugin_entities.SessionMessage] {
 	return nil
 	return nil
 }
 }
 
 

+ 15 - 3
internal/core/plugin_manager/local_manager/io.go

@@ -2,15 +2,27 @@ package local_manager
 
 
 import (
 import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 )
 )
 
 
-func (r *LocalPluginRuntime) Listen(session_id string) *entities.BytesIOListener {
-	listener := entities.NewIOListener[[]byte]()
+func (r *LocalPluginRuntime) Listen(session_id string) *entities.Broadcast[plugin_entities.SessionMessage] {
+	listener := entities.NewBroadcast[plugin_entities.SessionMessage]()
 	listener.OnClose(func() {
 	listener.OnClose(func() {
 		RemoveStdioListener(r.io_identity, session_id)
 		RemoveStdioListener(r.io_identity, session_id)
 	})
 	})
 	OnStdioEvent(r.io_identity, session_id, func(b []byte) {
 	OnStdioEvent(r.io_identity, session_id, func(b []byte) {
-		listener.Send(b)
+		// unmarshal the session message
+		data, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](b)
+		if err != nil {
+			log.Error("unmarshal json failed: %s, failed to parse session message", err.Error())
+			return
+		}
+		// set the runtime type
+		data.RuntimeType = r.Type()
+
+		listener.Send(data)
 	})
 	})
 	return listener
 	return listener
 }
 }

+ 15 - 3
internal/core/plugin_manager/remote_manager/io.go

@@ -2,17 +2,29 @@ package remote_manager
 
 
 import (
 import (
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"github.com/panjf2000/gnet/v2"
 	"github.com/panjf2000/gnet/v2"
 )
 )
 
 
-func (r *RemotePluginRuntime) Listen(session_id string) *entities.BytesIOListener {
-	listener := entities.NewIOListener[[]byte]()
+func (r *RemotePluginRuntime) Listen(session_id string) *entities.Broadcast[plugin_entities.SessionMessage] {
+	listener := entities.NewBroadcast[plugin_entities.SessionMessage]()
 	listener.OnClose(func() {
 	listener.OnClose(func() {
 		r.removeCallback(session_id)
 		r.removeCallback(session_id)
 	})
 	})
 
 
 	r.addCallback(session_id, func(data []byte) {
 	r.addCallback(session_id, func(data []byte) {
-		listener.Send(data)
+		// unmarshal the session message
+		chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](data)
+		if err != nil {
+			log.Error("unmarshal json failed: %s, failed to parse session message", err.Error())
+			return
+		}
+		// set the runtime type
+		chunk.RuntimeType = r.Type()
+
+		listener.Send(chunk)
 	})
 	})
 
 
 	return listener
 	return listener

+ 9 - 5
internal/core/session_manager/session.go

@@ -88,14 +88,18 @@ const (
 	PLUGIN_IN_STREAM_EVENT_RESPONSE PLUGIN_IN_STREAM_EVENT = "backwards_response"
 	PLUGIN_IN_STREAM_EVENT_RESPONSE PLUGIN_IN_STREAM_EVENT = "backwards_response"
 )
 )
 
 
+func (s *Session) Message(event PLUGIN_IN_STREAM_EVENT, data any) []byte {
+	return parser.MarshalJsonBytes(map[string]any{
+		"session_id": s.id,
+		"event":      event,
+		"data":       data,
+	})
+}
+
 func (s *Session) Write(event PLUGIN_IN_STREAM_EVENT, data any) error {
 func (s *Session) Write(event PLUGIN_IN_STREAM_EVENT, data any) error {
 	if s.runtime == nil {
 	if s.runtime == nil {
 		return errors.New("runtime not bound")
 		return errors.New("runtime not bound")
 	}
 	}
-	s.runtime.Write(s.id, parser.MarshalJsonBytes(map[string]any{
-		"session_id": s.id,
-		"event":      event,
-		"data":       data,
-	}))
+	s.runtime.Write(s.id, s.Message(event, data))
 	return nil
 	return nil
 }
 }

+ 1 - 1
internal/types/entities/listener.go

@@ -10,7 +10,7 @@ type Broadcast[T any] struct {
 
 
 type BytesIOListener = Broadcast[[]byte]
 type BytesIOListener = Broadcast[[]byte]
 
 
-func NewIOListener[T any]() *Broadcast[T] {
+func NewBroadcast[T any]() *Broadcast[T] {
 	return &Broadcast[T]{
 	return &Broadcast[T]{
 		l: &sync.RWMutex{},
 		l: &sync.RWMutex{},
 	}
 	}

+ 2 - 1
internal/types/entities/plugin_entities/event.go

@@ -2,6 +2,7 @@ package plugin_entities
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
+	"io"
 )
 )
 
 
 type PluginUniversalEvent struct {
 type PluginUniversalEvent struct {
@@ -31,7 +32,7 @@ type SessionMessage struct {
 	RuntimeType PluginRuntimeType    `json:"runtime_type"`
 	RuntimeType PluginRuntimeType    `json:"runtime_type"`
 
 
 	// only used for aws event bridge, not used for stdio and tcp
 	// only used for aws event bridge, not used for stdio and tcp
-	ServerlessEventId string `json:"serverless_event_id"`
+	SessionWriter io.WriteCloser `json:"-"`
 }
 }
 
 
 type SESSION_MESSAGE_TYPE string
 type SESSION_MESSAGE_TYPE string

+ 1 - 1
internal/types/entities/plugin_entities/runtime.go

@@ -84,7 +84,7 @@ type (
 	}
 	}
 
 
 	PluginRuntimeSessionIOInterface interface {
 	PluginRuntimeSessionIOInterface interface {
-		Listen(session_id string) *entities.BytesIOListener
+		Listen(session_id string) *entities.Broadcast[SessionMessage]
 		Write(session_id string, data []byte)
 		Write(session_id string, data []byte)
 	}
 	}