Browse Source

feat: backwards invocation

Yeuoly 1 year ago
parent
commit
483af3c5ba

+ 1 - 0
internal/core/dify_invocation/types.go

@@ -74,6 +74,7 @@ type InvokeModerationRequest struct {
 type InvokeToolRequest struct {
 type InvokeToolRequest struct {
 	BaseInvokeDifyRequest
 	BaseInvokeDifyRequest
 	Data struct {
 	Data struct {
+		ToolType requests.ToolType `json:"tool_type" validate:"required,tool_type"`
 		requests.RequestInvokeTool
 		requests.RequestInvokeTool
 	} `json:"data" validate:"required"`
 	} `json:"data" validate:"required"`
 }
 }

+ 3 - 3
internal/core/plugin_daemon/backwards_invocation/entities.go

@@ -3,9 +3,9 @@ package backwards_invocation
 type RequestEvent string
 type RequestEvent string
 
 
 const (
 const (
-	REQUEST_EVENT_RESPONSE RequestEvent = "response"
-	REQUEST_EVENT_ERROR    RequestEvent = "error"
-	REQUEST_EVENT_END      RequestEvent = "end"
+	REQUEST_EVENT_RESPONSE RequestEvent = "backward_invocation_response"
+	REQUEST_EVENT_ERROR    RequestEvent = "backward_invocation_error"
+	REQUEST_EVENT_END      RequestEvent = "backward_invocation_end"
 )
 )
 
 
 type BaseRequestEvent struct {
 type BaseRequestEvent struct {

+ 18 - 2
internal/core/plugin_daemon/backwards_invocation/request.go

@@ -1,6 +1,8 @@
 package backwards_invocation
 package backwards_invocation
 
 
 import (
 import (
+	"fmt"
+
 	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
@@ -35,8 +37,8 @@ func (bi *BackwardsInvocation) WriteError(err error) {
 	bi.session.Write(parser.MarshalJsonBytes(NewErrorEvent(bi.id, err.Error())))
 	bi.session.Write(parser.MarshalJsonBytes(NewErrorEvent(bi.id, err.Error())))
 }
 }
 
 
-func (bi *BackwardsInvocation) Write(message string, data map[string]any) {
-	bi.session.Write(parser.MarshalJsonBytes(NewResponseEvent(bi.id, message, data)))
+func (bi *BackwardsInvocation) Write(message string, data any) {
+	bi.session.Write(parser.MarshalJsonBytes(NewResponseEvent(bi.id, message, parser.StructToMap(data))))
 }
 }
 
 
 func (bi *BackwardsInvocation) End() {
 func (bi *BackwardsInvocation) End() {
@@ -50,3 +52,17 @@ func (bi *BackwardsInvocation) Type() BackwardsInvocationType {
 func (bi *BackwardsInvocation) RequestData() map[string]any {
 func (bi *BackwardsInvocation) RequestData() map[string]any {
 	return bi.detailed_request
 	return bi.detailed_request
 }
 }
+
+func (bi *BackwardsInvocation) TenantID() (string, error) {
+	if bi.session == nil {
+		return "", fmt.Errorf("session is nil")
+	}
+	return bi.session.TenantID(), nil
+}
+
+func (bi *BackwardsInvocation) UserID() (string, error) {
+	if bi.session == nil {
+		return "", fmt.Errorf("session is nil")
+	}
+	return bi.session.UserID(), nil
+}

+ 33 - 7
internal/core/plugin_daemon/invoke_dify.go

@@ -7,7 +7,9 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
 )
 )
 
 
 func invokeDify(
 func invokeDify(
@@ -17,11 +19,14 @@ func invokeDify(
 ) error {
 ) error {
 	// unmarshal invoke data
 	// unmarshal invoke data
 	request, err := parser.UnmarshalJsonBytes[map[string]any](data)
 	request, err := parser.UnmarshalJsonBytes[map[string]any](data)
-
 	if err != nil {
 	if err != nil {
 		return fmt.Errorf("unmarshal invoke request failed: %s", err.Error())
 		return fmt.Errorf("unmarshal invoke request failed: %s", err.Error())
 	}
 	}
 
 
+	if request == nil {
+		return fmt.Errorf("invoke request is empty")
+	}
+
 	// prepare invocation arguments
 	// prepare invocation arguments
 	request_handle, err := prepareDifyInvocationArguments(session, request)
 	request_handle, err := prepareDifyInvocationArguments(session, request)
 	if err != nil {
 	if err != nil {
@@ -35,7 +40,9 @@ func invokeDify(
 	}
 	}
 
 
 	// dispatch invocation task
 	// dispatch invocation task
-	dispatchDifyInvocationTask(request_handle)
+	routine.Submit(func() {
+		dispatchDifyInvocationTask(request_handle)
+	})
 
 
 	return nil
 	return nil
 }
 }
@@ -65,20 +72,39 @@ func prepareDifyInvocationArguments(session *session_manager.Session, request ma
 }
 }
 
 
 func dispatchDifyInvocationTask(handle *backwards_invocation.BackwardsInvocation) {
 func dispatchDifyInvocationTask(handle *backwards_invocation.BackwardsInvocation) {
+	request_data := handle.RequestData()
+	tenant_id, err := handle.TenantID()
+	if err != nil {
+		handle.WriteError(fmt.Errorf("get tenant id failed: %s", err.Error()))
+		return
+	}
+	request_data["tenant_id"] = tenant_id
+	user_id, err := handle.UserID()
+	if err != nil {
+		handle.WriteError(fmt.Errorf("get user id failed: %s", err.Error()))
+		return
+	}
+	request_data["user_id"] = user_id
+
 	switch handle.Type() {
 	switch handle.Type() {
 	case dify_invocation.INVOKE_TYPE_TOOL:
 	case dify_invocation.INVOKE_TYPE_TOOL:
-		_, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData())
+		r, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData())
 		if err != nil {
 		if err != nil {
 			handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error()))
 			handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error()))
 			return
 			return
 		}
 		}
-
+		executeDifyInvocationToolTask(handle, r)
 	default:
 	default:
 		handle.WriteError(fmt.Errorf("unsupported invoke type: %s", handle.Type()))
 		handle.WriteError(fmt.Errorf("unsupported invoke type: %s", handle.Type()))
 	}
 	}
 }
 }
 
 
-func setTaskContext(session *session_manager.Session, r *dify_invocation.BaseInvokeDifyRequest) {
-	r.TenantId = session.TenantID()
-	r.UserId = session.UserID()
+func executeDifyInvocationToolTask(handle *backwards_invocation.BackwardsInvocation, request *dify_invocation.InvokeToolRequest) {
+	handle.Write("stream", tool_entities.ToolResponseChunk{
+		Type: "text",
+		Message: map[string]any{
+			"text": "hello world",
+		},
+	})
+	handle.End()
 }
 }

+ 1 - 3
internal/core/plugin_manager/local_manager/run.go

@@ -59,7 +59,6 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 		r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING
 		r.State.Status = entities.PLUGIN_RUNTIME_STATUS_RESTARTING
 		return err
 		return err
 	}
 	}
-
 	defer func() {
 	defer func() {
 		// wait for plugin to exit
 		// wait for plugin to exit
 		err = e.Wait()
 		err = e.Wait()
@@ -70,6 +69,7 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 
 
 		r.gc()
 		r.gc()
 	}()
 	}()
+	defer e.Process.Kill()
 
 
 	log.Info("plugin %s started", r.Config.Identity())
 	log.Info("plugin %s started", r.Config.Identity())
 
 
@@ -99,8 +99,6 @@ func (r *LocalPluginRuntime) StartPlugin() error {
 		return err
 		return err
 	}
 	}
 
 
-	e.Process.Kill()
-
 	wg.Wait()
 	wg.Wait()
 
 
 	// plugin has exited
 	// plugin has exited

+ 6 - 2
internal/core/plugin_manager/stdio_holder/io.go

@@ -116,7 +116,11 @@ func (s *stdioHolder) WriteError(msg string) {
 	const MAX_ERR_MSG_LEN = 1024
 	const MAX_ERR_MSG_LEN = 1024
 	reduce := len(msg) + len(s.err_message) - MAX_ERR_MSG_LEN
 	reduce := len(msg) + len(s.err_message) - MAX_ERR_MSG_LEN
 	if reduce > 0 {
 	if reduce > 0 {
-		s.err_message = s.err_message[reduce:]
+		if reduce > len(s.err_message) {
+			s.err_message = ""
+		} else {
+			s.err_message = s.err_message[reduce:]
+		}
 	}
 	}
 
 
 	s.err_message += msg
 	s.err_message += msg
@@ -163,7 +167,7 @@ func (s *stdioHolder) Wait() error {
 		case <-ticker.C:
 		case <-ticker.C:
 			// check heartbeat
 			// check heartbeat
 			if time.Since(s.last_active_at) > 20*time.Second {
 			if time.Since(s.last_active_at) > 20*time.Second {
-				return errors.New("plugin is not active")
+				return errors.New("plugin is not active, does not respond to heartbeat in 20 seconds")
 			}
 			}
 		case <-s.health_chan:
 		case <-s.health_chan:
 			// closed
 			// closed

+ 23 - 0
internal/types/entities/requests/tool.go

@@ -1,5 +1,28 @@
 package requests
 package requests
 
 
+import (
+	"github.com/go-playground/validator/v10"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
+)
+
+type ToolType string
+
+const (
+	TOOL_TYPE_BUILTIN  ToolType = "builtin"
+	TOOL_TYPE_WORKFLOW ToolType = "workflow"
+	TOOL_TYPE_API      ToolType = "api"
+)
+
+func init() {
+	validators.GlobalEntitiesValidator.RegisterValidation("tool_type", func(fl validator.FieldLevel) bool {
+		switch fl.Field().String() {
+		case string(TOOL_TYPE_BUILTIN), string(TOOL_TYPE_WORKFLOW), string(TOOL_TYPE_API):
+			return true
+		}
+		return false
+	})
+}
+
 type InvokeToolSchema struct {
 type InvokeToolSchema struct {
 	Provider       string         `json:"provider" validate:"required"`
 	Provider       string         `json:"provider" validate:"required"`
 	Tool           string         `json:"tool" validate:"required"`
 	Tool           string         `json:"tool" validate:"required"`

+ 2 - 2
internal/utils/parser/struct2map.go

@@ -4,8 +4,8 @@ import (
 	"github.com/mitchellh/mapstructure"
 	"github.com/mitchellh/mapstructure"
 )
 )
 
 
-func StructToMap(data interface{}) map[string]interface{} {
-	result := make(map[string]interface{})
+func StructToMap(data any) map[string]any {
+	result := make(map[string]any)
 
 
 	decoder := &mapstructure.DecoderConfig{
 	decoder := &mapstructure.DecoderConfig{
 		Metadata: nil,
 		Metadata: nil,