Yeuoly преди 1 година
родител
ревизия
f85bd95840

+ 10 - 10
internal/core/plugin_daemon/backwards_invocation/entities.go

@@ -3,20 +3,20 @@ package backwards_invocation
 type RequestEvent string
 
 const (
-	REQUEST_EVENT_RESPONSE RequestEvent = "backward_invocation_response"
-	REQUEST_EVENT_ERROR    RequestEvent = "backward_invocation_error"
-	REQUEST_EVENT_END      RequestEvent = "backward_invocation_end"
+	REQUEST_EVENT_RESPONSE RequestEvent = "response"
+	REQUEST_EVENT_ERROR    RequestEvent = "error"
+	REQUEST_EVENT_END      RequestEvent = "end"
 )
 
-type BaseRequestEvent struct {
+type BackwardsInvocationResponseEvent struct {
 	BackwardsRequestId string         `json:"backwards_request_id"`
 	Event              RequestEvent   `json:"event"`
 	Message            string         `json:"message"`
 	Data               map[string]any `json:"data"`
 }
 
-func NewResponseEvent(request_id string, message string, data map[string]any) *BaseRequestEvent {
-	return &BaseRequestEvent{
+func NewResponseEvent(request_id string, message string, data map[string]any) *BackwardsInvocationResponseEvent {
+	return &BackwardsInvocationResponseEvent{
 		BackwardsRequestId: request_id,
 		Event:              REQUEST_EVENT_RESPONSE,
 		Message:            message,
@@ -24,8 +24,8 @@ func NewResponseEvent(request_id string, message string, data map[string]any) *B
 	}
 }
 
-func NewErrorEvent(request_id string, message string) *BaseRequestEvent {
-	return &BaseRequestEvent{
+func NewErrorEvent(request_id string, message string) *BackwardsInvocationResponseEvent {
+	return &BackwardsInvocationResponseEvent{
 		BackwardsRequestId: request_id,
 		Event:              REQUEST_EVENT_ERROR,
 		Message:            message,
@@ -33,8 +33,8 @@ func NewErrorEvent(request_id string, message string) *BaseRequestEvent {
 	}
 }
 
-func NewEndEvent(request_id string) *BaseRequestEvent {
-	return &BaseRequestEvent{
+func NewEndEvent(request_id string) *BackwardsInvocationResponseEvent {
+	return &BackwardsInvocationResponseEvent{
 		BackwardsRequestId: request_id,
 		Event:              REQUEST_EVENT_END,
 		Message:            "",

+ 12 - 3
internal/core/plugin_daemon/backwards_invocation/request.go

@@ -34,15 +34,24 @@ func (bi *BackwardsInvocation) GetID() string {
 }
 
 func (bi *BackwardsInvocation) WriteError(err error) {
-	bi.session.Write(parser.MarshalJsonBytes(NewErrorEvent(bi.id, err.Error())))
+	bi.session.Write(
+		session_manager.PLUGIN_IN_STREAM_EVENT_RESPONSE,
+		NewErrorEvent(bi.id, err.Error()),
+	)
 }
 
 func (bi *BackwardsInvocation) Write(message string, data any) {
-	bi.session.Write(parser.MarshalJsonBytes(NewResponseEvent(bi.id, message, parser.StructToMap(data))))
+	bi.session.Write(
+		session_manager.PLUGIN_IN_STREAM_EVENT_RESPONSE,
+		NewResponseEvent(bi.id, message, parser.StructToMap(data)),
+	)
 }
 
 func (bi *BackwardsInvocation) End() {
-	bi.session.Write(parser.MarshalJsonBytes(NewEndEvent(bi.id)))
+	bi.session.Write(
+		session_manager.PLUGIN_IN_STREAM_EVENT_RESPONSE,
+		NewEndEvent(bi.id),
+	)
 }
 
 func (bi *BackwardsInvocation) Type() BackwardsInvocationType {

+ 4 - 12
internal/core/plugin_daemon/basic.go

@@ -22,18 +22,10 @@ const (
 	PLUGIN_ACCESS_ACTION_VALIDATE_MODEL_CREDENTIALS    PluginAccessAction = "validate_model_credentials"
 )
 
-const (
-	PLUGIN_IN_STREAM_EVENT = "request"
-)
-
-func getBasicPluginAccessMap(session_id string, user_id string, access_type PluginAccessType, action PluginAccessAction) map[string]any {
+func getBasicPluginAccessMap(user_id string, access_type PluginAccessType, action PluginAccessAction) map[string]any {
 	return map[string]any{
-		"session_id": session_id,
-		"event":      PLUGIN_IN_STREAM_EVENT,
-		"data": map[string]any{
-			"user_id": user_id,
-			"type":    access_type,
-			"action":  action,
-		},
+		"user_id": user_id,
+		"type":    access_type,
+		"action":  action,
 	}
 }

+ 1 - 1
internal/core/plugin_daemon/invoke_dify.go

@@ -18,7 +18,7 @@ func invokeDify(
 	session *session_manager.Session, data []byte,
 ) error {
 	// unmarshal invoke data
-	request, err := parser.UnmarshalJsonBytes[map[string]any](data)
+	request, err := parser.UnmarshalJsonBytes2Map(data)
 	if err != nil {
 		return fmt.Errorf("unmarshal invoke request failed: %s", err.Error())
 	}

+ 5 - 7
internal/core/plugin_daemon/model_service.go

@@ -66,14 +66,15 @@ func genericInvokePlugin[Req any, Rsp any](
 		listener.Close()
 	})
 
-	runtime.Write(session.ID(), []byte(parser.MarshalJson(
+	session.Write(
+		session_manager.PLUGIN_IN_STREAM_EVENT_REQUEST,
 		getInvokeModelMap(
 			session,
 			typ,
 			action,
 			request,
 		),
-	)))
+	)
 
 	return response, nil
 }
@@ -84,13 +85,10 @@ func getInvokeModelMap(
 	action PluginAccessAction,
 	request any,
 ) map[string]any {
-	req := getBasicPluginAccessMap(session.ID(), session.UserID(), typ, action)
-	data := req["data"].(map[string]any)
-
+	req := getBasicPluginAccessMap(session.UserID(), typ, action)
 	for k, v := range parser.StructToMap(request) {
-		data[k] = v
+		req[k] = v
 	}
-
 	return req
 }
 

+ 2 - 0
internal/core/plugin_manager/local_manager/environment.go

@@ -7,6 +7,7 @@ import (
 	"path"
 	"strings"
 	"sync"
+	"syscall"
 	"time"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
@@ -21,6 +22,7 @@ func (r *LocalPluginRuntime) InitEnvironment() error {
 	// execute init command
 	handle := exec.Command("bash", r.Config.Execution.Install)
 	handle.Dir = r.State.RelativePath
+	handle.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
 
 	// get stdout and stderr
 	stdout, err := handle.StdoutPipe()

+ 10 - 0
internal/core/plugin_manager/local_manager/run.go

@@ -7,6 +7,7 @@ import (
 	"sync"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager/stdio_holder"
+	"github.com/langgenius/dify-plugin-daemon/internal/process"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
@@ -35,6 +36,7 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 	// start plugin
 	e := exec.Command("bash", r.Config.Execution.Launch)
 	e.Dir = r.State.RelativePath
+	process.WrapProcess(e)
 
 	// get writer
 	stdin, err := e.StdinPipe()
@@ -42,23 +44,31 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 		r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING
 		return fmt.Errorf("get stdin pipe failed: %s", err.Error())
 	}
+	defer stdin.Close()
 
 	stdout, err := e.StdoutPipe()
 	if err != nil {
 		r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING
 		return fmt.Errorf("get stdout pipe failed: %s", err.Error())
 	}
+	defer stdout.Close()
 
 	stderr, err := e.StderrPipe()
 	if err != nil {
 		r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING
 		return fmt.Errorf("get stderr pipe failed: %s", err.Error())
 	}
+	defer stderr.Close()
 
 	if err := e.Start(); err != nil {
 		r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING
 		return err
 	}
+
+	// add to subprocess manager
+	process.NewProcess(e)
+	defer process.RemoveProcess(e)
+
 	defer func() {
 		// wait for plugin to exit
 		err = e.Wait()

+ 1 - 0
internal/core/plugin_manager/stdio_holder/io.go

@@ -66,6 +66,7 @@ func (s *stdioHolder) Stop() {
 
 func (s *stdioHolder) StartStdout() {
 	s.started = true
+	s.last_active_at = time.Now()
 	defer s.Stop()
 
 	scanner := bufio.NewScanner(s.reader)

+ 14 - 2
internal/core/session_manager/session.go

@@ -6,6 +6,7 @@ import (
 
 	"github.com/google/uuid"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 )
 
 var (
@@ -79,10 +80,21 @@ func (s *Session) BindRuntime(runtime entities.PluginRuntimeSessionIOInterface)
 	s.runtime = runtime
 }
 
-func (s *Session) Write(data []byte) error {
+type PLUGIN_IN_STREAM_EVENT string
+
+const (
+	PLUGIN_IN_STREAM_EVENT_REQUEST  PLUGIN_IN_STREAM_EVENT = "request"
+	PLUGIN_IN_STREAM_EVENT_RESPONSE PLUGIN_IN_STREAM_EVENT = "backwards_response"
+)
+
+func (s *Session) Write(event PLUGIN_IN_STREAM_EVENT, data any) error {
 	if s.runtime == nil {
 		return errors.New("runtime not bound")
 	}
-	s.runtime.Write(s.id, data)
+	s.runtime.Write(s.id, parser.MarshalJsonBytes(map[string]any{
+		"session_id": s.id,
+		"event":      event,
+		"data":       data,
+	}))
 	return nil
 }

+ 56 - 0
internal/process/manager.go

@@ -0,0 +1,56 @@
+package process
+
+import (
+	"os"
+	"os/exec"
+	"os/signal"
+	"sync"
+	"syscall"
+)
+
+var (
+	subprocesses map[int]*exec.Cmd
+	l            *sync.Mutex
+)
+
+func Init() {
+	subprocesses = make(map[int]*exec.Cmd)
+	l = &sync.Mutex{}
+
+	sig := make(chan os.Signal, 1)
+	signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
+
+	go func() {
+		<-sig
+		TerminateAll()
+		os.Exit(0)
+	}()
+}
+
+func WrapProcess(cmd *exec.Cmd) {
+	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+}
+
+func NewProcess(cmd *exec.Cmd) {
+	l.Lock()
+	defer l.Unlock()
+	subprocesses[cmd.Process.Pid] = cmd
+}
+
+func RemoveProcess(cmd *exec.Cmd) {
+	l.Lock()
+	defer l.Unlock()
+
+	delete(subprocesses, cmd.Process.Pid)
+}
+
+func TerminateAll() {
+	l.Lock()
+	defer l.Unlock()
+
+	for _, cmd := range subprocesses {
+		if cmd.Process != nil {
+			syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
+		}
+	}
+}

+ 4 - 0
internal/server/server.go

@@ -2,6 +2,7 @@ package server
 
 import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/process"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
@@ -10,6 +11,9 @@ func Run(config *app.Config) {
 	// init routine pool
 	routine.InitPool(config.RoutinePoolSize)
 
+	// init process lifetime
+	process.Init()
+
 	// init plugin daemon
 	plugin_manager.Init(config)
 

+ 8 - 2
internal/utils/parser/json.go

@@ -34,6 +34,12 @@ func MarshalJsonBytes[T any](data T) []byte {
 	return b
 }
 
-func UnmarshalJson2Map(json []byte) (map[string]any, error) {
-	return UnmarshalJsonBytes[map[string]any](json)
+func UnmarshalJsonBytes2Map(data []byte) (map[string]any, error) {
+	result := map[string]any{}
+	err := json.Unmarshal(data, &result)
+	return result, err
+}
+
+func UnmarshalJson2Map(json string) (map[string]any, error) {
+	return UnmarshalJsonBytes2Map([]byte(json))
 }