소스 검색

feat: generic invocation

Yeuoly 1 년 전
부모
커밋
5b96e61a9d

+ 7 - 2
internal/core/plugin_daemon/basic.go

@@ -10,8 +10,13 @@ const (
 type PluginAccessAction string
 
 const (
-	PLUGIN_ACCESS_ACTION_INVOKE_TOOL PluginAccessAction = "invoke_tool"
-	PLUGIN_ACCESS_ACTION_INVOKE_LLM  PluginAccessAction = "invoke_llm"
+	PLUGIN_ACCESS_ACTION_INVOKE_TOOL           PluginAccessAction = "invoke_tool"
+	PLUGIN_ACCESS_ACTION_INVOKE_LLM            PluginAccessAction = "invoke_llm"
+	PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING PluginAccessAction = "invoke_text_embedding"
+	PLUGIN_ACCESS_ACTION_INVOKE_RERANK         PluginAccessAction = "invoke_rerank"
+	PLUGIN_ACCESS_ACTION_INVOKE_TTS            PluginAccessAction = "invoke_tts"
+	PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT    PluginAccessAction = "invoke_speech2text"
+	PLUGIN_ACCESS_ACTION_INVOKE_MODERATION     PluginAccessAction = "invoke_moderation"
 )
 
 const (

+ 124 - 27
internal/core/plugin_daemon/model_service.go

@@ -5,6 +5,7 @@ import (
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
@@ -12,39 +13,21 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
-func getInvokeModelMap(
+func genericInvokePlugin[Req any, Rsp any](
 	session *session_manager.Session,
+	request *Req,
+	response_buffer_size int,
+	typ PluginAccessType,
 	action PluginAccessAction,
-	request *requests.RequestInvokeLLM,
-) map[string]any {
-	req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_MODEL, action)
-	data := req["data"].(map[string]any)
-
-	data["provider"] = request.Provider
-	data["model"] = request.Model
-	data["model_type"] = request.ModelType
-	data["model_parameters"] = request.ModelParameters
-	data["prompt_messages"] = request.PromptMessages
-	data["tools"] = request.Tools
-	data["stop"] = request.Stop
-	data["stream"] = request.Stream
-	data["credentials"] = request.Credentials
-
-	return req
-}
-
-func InvokeLLM(
-	session *session_manager.Session,
-	request *requests.RequestInvokeLLM,
 ) (
-	*stream.StreamResponse[plugin_entities.InvokeModelResponseChunk], error,
+	*stream.StreamResponse[Rsp], error,
 ) {
 	runtime := plugin_manager.Get(session.PluginIdentity())
 	if runtime == nil {
 		return nil, errors.New("plugin not found")
 	}
 
-	response := stream.NewStreamResponse[plugin_entities.InvokeModelResponseChunk](512)
+	response := stream.NewStreamResponse[Rsp](response_buffer_size)
 
 	listener := runtime.Listen(session.ID())
 	listener.AddListener(func(message []byte) {
@@ -56,7 +39,7 @@ func InvokeLLM(
 
 		switch chunk.Type {
 		case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
-			chunk, err := parser.UnmarshalJsonBytes[plugin_entities.InvokeModelResponseChunk](chunk.Data)
+			chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
 			if err != nil {
 				log.Error("unmarshal json failed: %s", err.Error())
 				return
@@ -66,8 +49,15 @@ func InvokeLLM(
 			invokeDify(runtime, session, chunk.Data)
 		case plugin_entities.SESSION_MESSAGE_TYPE_END:
 			response.Close()
+		case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
+			e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
+			if err != nil {
+				break
+			}
+			response.WriteError(errors.New(e.Error))
+			response.Close()
 		default:
-			log.Error("unknown stream message type: %s", chunk.Type)
+			response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
 			response.Close()
 		}
 	})
@@ -79,10 +69,117 @@ func InvokeLLM(
 	runtime.Write(session.ID(), []byte(parser.MarshalJson(
 		getInvokeModelMap(
 			session,
-			PLUGIN_ACCESS_ACTION_INVOKE_LLM,
+			typ,
+			action,
 			request,
 		),
 	)))
 
 	return response, nil
 }
+
+func getInvokeModelMap(
+	session *session_manager.Session,
+	typ PluginAccessType,
+	action PluginAccessAction,
+	request any,
+) map[string]any {
+	req := getBasicPluginAccessMap(session.ID(), session.UserID(), typ, action)
+	data := req["data"].(map[string]any)
+
+	for k, v := range parser.StructToMap(request) {
+		data[k] = v
+	}
+
+	return req
+}
+
+func InvokeLLM(
+	session *session_manager.Session,
+	request *requests.RequestInvokeLLM,
+) (
+	*stream.StreamResponse[model_entities.LLMResultChunk], error,
+) {
+	return genericInvokePlugin[requests.RequestInvokeLLM, model_entities.LLMResultChunk](
+		session,
+		request,
+		512,
+		PLUGIN_ACCESS_TYPE_MODEL,
+		PLUGIN_ACCESS_ACTION_INVOKE_LLM,
+	)
+}
+
+func InvokeTextEmbedding(
+	session *session_manager.Session,
+	request *requests.RequestInvokeTextEmbedding,
+) (
+	*stream.StreamResponse[model_entities.TextEmbeddingResult], error,
+) {
+	return genericInvokePlugin[requests.RequestInvokeTextEmbedding, model_entities.TextEmbeddingResult](
+		session,
+		request,
+		1,
+		PLUGIN_ACCESS_TYPE_MODEL,
+		PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
+	)
+}
+
+func InvokeRerank(
+	session *session_manager.Session,
+	request *requests.RequestInvokeRerank,
+) (
+	*stream.StreamResponse[model_entities.RerankResult], error,
+) {
+	return genericInvokePlugin[requests.RequestInvokeRerank, model_entities.RerankResult](
+		session,
+		request,
+		1,
+		PLUGIN_ACCESS_TYPE_MODEL,
+		PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
+	)
+}
+
+func InvokeTTS(
+	session *session_manager.Session,
+	request *requests.RequestInvokeTTS,
+) (
+	*stream.StreamResponse[string], error,
+) {
+	return genericInvokePlugin[requests.RequestInvokeTTS, string](
+		session,
+		request,
+		1,
+		PLUGIN_ACCESS_TYPE_MODEL,
+		PLUGIN_ACCESS_ACTION_INVOKE_TTS,
+	)
+}
+
+func InvokeSpeech2Text(
+	session *session_manager.Session,
+	request *requests.RequestInvokeSpeech2Text,
+) (
+	*stream.StreamResponse[string], error,
+) {
+	return genericInvokePlugin[requests.RequestInvokeSpeech2Text, string](
+		session,
+		request,
+		1,
+		PLUGIN_ACCESS_TYPE_MODEL,
+		PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
+	)
+}
+
+func InvokeModeration(
+	session *session_manager.Session,
+	request *requests.RequestInvokeModeration,
+) (
+	*stream.StreamResponse[bool], error,
+) {
+	return genericInvokePlugin[requests.RequestInvokeModeration, bool](
+		session,
+		request,
+		1,
+		PLUGIN_ACCESS_TYPE_MODEL,
+		PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
+	)
+}

+ 7 - 70
internal/core/plugin_daemon/tool_service.go

@@ -1,86 +1,23 @@
 package plugin_daemon
 
 import (
-	"errors"
-
-	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
-	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
-	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
-func getInvokeToolMap(
-	session *session_manager.Session,
-	action PluginAccessAction,
-	request *requests.RequestInvokeTool,
-) map[string]any {
-	req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_TOOL, action)
-	data := req["data"].(map[string]any)
-
-	data["provider"] = request.Provider
-	data["tool"] = request.Tool
-	data["parameters"] = request.ToolParameters
-	data["credentials"] = request.Credentials
-
-	return req
-}
-
 func InvokeTool(
 	session *session_manager.Session,
 	request *requests.RequestInvokeTool,
 ) (
 	*stream.StreamResponse[plugin_entities.ToolResponseChunk], error,
 ) {
-	runtime := plugin_manager.Get(session.PluginIdentity())
-	if runtime == nil {
-		return nil, errors.New("plugin not found")
-	}
-
-	response := stream.NewStreamResponse[plugin_entities.ToolResponseChunk](512)
-
-	listener := runtime.Listen(session.ID())
-	listener.AddListener(func(message []byte) {
-		chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
-		if err != nil {
-			log.Error("unmarshal json failed: %s", err.Error())
-			return
-		}
-
-		switch chunk.Type {
-		case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
-			chunk, err := parser.UnmarshalJsonBytes[plugin_entities.ToolResponseChunk](chunk.Data)
-			if err != nil {
-				log.Error("unmarshal json failed: %s", err.Error())
-				return
-			}
-			response.Write(chunk)
-		case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
-			invokeDify(runtime, session, chunk.Data)
-		case plugin_entities.SESSION_MESSAGE_TYPE_END:
-			response.Close()
-		case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
-			e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
-			if err != nil {
-				break
-			}
-			response.WriteError(errors.New(e.Error))
-			response.Close()
-		default:
-			response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
-			response.Close()
-		}
-	})
-
-	response.OnClose(func() {
-		listener.Close()
-	})
-
-	runtime.Write(session.ID(), []byte(parser.MarshalJson(
-		getInvokeToolMap(session, PLUGIN_ACCESS_ACTION_INVOKE_TOOL, request)),
-	))
-
-	return response, nil
+	return genericInvokePlugin[requests.RequestInvokeTool, plugin_entities.ToolResponseChunk](
+		session,
+		request,
+		128,
+		PLUGIN_ACCESS_TYPE_TOOL,
+		PLUGIN_ACCESS_ACTION_INVOKE_TOOL,
+	)
 }

+ 5 - 5
internal/core/plugin_manager/stdio_holder/io.go

@@ -71,6 +71,10 @@ func (s *stdioHolder) StartStdout() {
 	scanner := bufio.NewScanner(s.reader)
 	for scanner.Scan() {
 		data := scanner.Bytes()
+		if len(data) == 0 {
+			continue
+		}
+
 		event, err := parser.UnmarshalJsonBytes[plugin_entities.PluginUniversalEvent](data)
 		if err != nil {
 			// log.Error("unmarshal json failed: %s", err.Error())
@@ -101,11 +105,7 @@ func (s *stdioHolder) StartStdout() {
 				}
 			}
 		case plugin_entities.PLUGIN_EVENT_ERROR:
-			for listener_session_id, listener := range s.error_listener {
-				if listener_session_id == session_id {
-					listener(event.Data)
-				}
-			}
+			log.Error("plugin %s: %s", s.plugin_identity, event.Data)
 		case plugin_entities.PLUGIN_EVENT_HEARTBEAT:
 			s.last_active_at = time.Now()
 		}

+ 2 - 1
internal/service/invoke.go

@@ -4,6 +4,7 @@ import (
 	"github.com/gin-gonic/gin"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
@@ -25,7 +26,7 @@ func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM]
 	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
 	defer session.Close()
 
-	baseSSEService(r, func() (*stream.StreamResponse[plugin_entities.InvokeModelResponseChunk], error) {
+	baseSSEService(r, func() (*stream.StreamResponse[model_entities.LLMResultChunk], error) {
 		return plugin_daemon.InvokeLLM(session, &r.Data)
 	}, ctx)
 }

+ 0 - 4
internal/types/entities/plugin_entities/event.go

@@ -2,8 +2,6 @@ package plugin_entities
 
 import (
 	"encoding/json"
-
-	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
 )
 
 type PluginUniversalEvent struct {
@@ -51,8 +49,6 @@ type PluginResponseChunk struct {
 	Data json.RawMessage `json:"data"`
 }
 
-type InvokeModelResponseChunk = model_entities.LLMResultChunk
-
 type ErrorResponse struct {
 	Error string `json:"error"`
 }

+ 24 - 9
internal/types/entities/requests/model.go

@@ -5,20 +5,29 @@ import (
 )
 
 type BaseRequestInvokeModel struct {
-	Provider    string                   `json:"provider"`
-	ModelType   model_entities.ModelType `json:"model_type" validate:"required,model_type"`
-	Model       string                   `json:"model"`
+	Provider    string                   `json:"provider" validate:"required"`
+	ModelType   model_entities.ModelType `json:"model_type" mapstructure:"model_type" validate:"required,model_type"`
+	Model       string                   `json:"model" validate:"required"`
 	Credentials map[string]any           `json:"credentials" validate:"omitempty,dive,is_basic_type"`
 }
 
+func (r *BaseRequestInvokeModel) ToCallerArguments() map[string]any {
+	return map[string]any{
+		"provider":    r.Provider,
+		"model":       r.Model,
+		"model_type":  r.ModelType,
+		"credentials": r.Credentials,
+	}
+}
+
 type RequestInvokeLLM struct {
 	BaseRequestInvokeModel
 
-	ModelParameters map[string]any                     `json:"model_parameters" validate:"omitempty,dive,is_basic_type"`
-	PromptMessages  []model_entities.PromptMessage     `json:"prompt_messages" validate:"omitempty,dive"`
+	ModelParameters map[string]any                     `json:"model_parameters" mapstructure:"model_parameters" validate:"omitempty,dive,is_basic_type"`
+	PromptMessages  []model_entities.PromptMessage     `json:"prompt_messages" mapstructure:"prompt_messages" validate:"omitempty,dive"`
 	Tools           []model_entities.PromptMessageTool `json:"tools" validate:"omitempty,dive"`
 	Stop            []string                           `json:"stop" validate:"omitempty"`
-	Stream          bool                               `json:"stream"`
+	Stream          bool                               `json:"stream" mapstructure:"stream"`
 }
 
 type RequestInvokeTextEmbedding struct {
@@ -32,14 +41,14 @@ type RequestInvokeRerank struct {
 
 	Query          string   `json:"query" validate:"required"`
 	Docs           []string `json:"docs" validate:"required,dive"`
-	ScoreThreshold float64  `json:"score_threshold"`
-	TopN           int      `json:"top_n"`
+	ScoreThreshold float64  `json:"score_threshold" mapstructure:"score_threshold"`
+	TopN           int      `json:"top_n" mapstructure:"top_n"`
 }
 
 type RequestInvokeTTS struct {
 	BaseRequestInvokeModel
 
-	ContentText string `json:"content_text" validate:"required"`
+	ContentText string `json:"content_text" mapstructure:"content_text" validate:"required"`
 	Voice       string `json:"voice" validate:"required"`
 }
 
@@ -48,3 +57,9 @@ type RequestInvokeSpeech2Text struct {
 
 	File string `json:"file" validate:"required"` // base64 encoded voice file
 }
+
+type RequestInvokeModeration struct {
+	BaseRequestInvokeModel
+
+	Text string `json:"text" validate:"required"`
+}

+ 48 - 0
internal/utils/parser/struct2map.go

@@ -0,0 +1,48 @@
+package parser
+
+import (
+	"reflect"
+	"unicode"
+)
+
+func StructToMap(data interface{}) map[string]interface{} {
+	result := make(map[string]interface{})
+	val := reflect.ValueOf(data)
+	if val.Kind() == reflect.Ptr {
+		val = val.Elem()
+	}
+	for i := 0; i < val.NumField(); i++ {
+		field := val.Field(i)
+		typeField := val.Type().Field(i)
+		fieldName := toSnakeCase(typeField.Name)
+
+		if typeField.Anonymous {
+			embeddedFields := StructToMap(field.Interface())
+			for k, v := range embeddedFields {
+				result[k] = v
+			}
+		} else {
+			result[fieldName] = field.Interface()
+		}
+	}
+	return result
+}
+
+func toSnakeCase(str string) string {
+	runes := []rune(str)
+	length := len(runes)
+	var out []rune
+
+	for i := 0; i < length; i++ {
+		if unicode.IsUpper(runes[i]) {
+			if i > 0 {
+				out = append(out, '_')
+			}
+			out = append(out, unicode.ToLower(runes[i]))
+		} else {
+			out = append(out, runes[i])
+		}
+	}
+
+	return string(out)
+}