|
@@ -5,6 +5,7 @@ import (
|
|
|
|
|
|
"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
|
|
|
"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
|
|
|
+ "github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
|
|
|
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
|
|
|
"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
|
|
|
"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
|
|
@@ -12,39 +13,21 @@ import (
|
|
|
"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
|
|
|
)
|
|
|
|
|
|
-func getInvokeModelMap(
|
|
|
+func genericInvokePlugin[Req any, Rsp any](
|
|
|
session *session_manager.Session,
|
|
|
+ request *Req,
|
|
|
+ response_buffer_size int,
|
|
|
+ typ PluginAccessType,
|
|
|
action PluginAccessAction,
|
|
|
- request *requests.RequestInvokeLLM,
|
|
|
-) map[string]any {
|
|
|
- req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_MODEL, action)
|
|
|
- data := req["data"].(map[string]any)
|
|
|
-
|
|
|
- data["provider"] = request.Provider
|
|
|
- data["model"] = request.Model
|
|
|
- data["model_type"] = request.ModelType
|
|
|
- data["model_parameters"] = request.ModelParameters
|
|
|
- data["prompt_messages"] = request.PromptMessages
|
|
|
- data["tools"] = request.Tools
|
|
|
- data["stop"] = request.Stop
|
|
|
- data["stream"] = request.Stream
|
|
|
- data["credentials"] = request.Credentials
|
|
|
-
|
|
|
- return req
|
|
|
-}
|
|
|
-
|
|
|
-func InvokeLLM(
|
|
|
- session *session_manager.Session,
|
|
|
- request *requests.RequestInvokeLLM,
|
|
|
) (
|
|
|
- *stream.StreamResponse[plugin_entities.InvokeModelResponseChunk], error,
|
|
|
+ *stream.StreamResponse[Rsp], error,
|
|
|
) {
|
|
|
runtime := plugin_manager.Get(session.PluginIdentity())
|
|
|
if runtime == nil {
|
|
|
return nil, errors.New("plugin not found")
|
|
|
}
|
|
|
|
|
|
- response := stream.NewStreamResponse[plugin_entities.InvokeModelResponseChunk](512)
|
|
|
+ response := stream.NewStreamResponse[Rsp](response_buffer_size)
|
|
|
|
|
|
listener := runtime.Listen(session.ID())
|
|
|
listener.AddListener(func(message []byte) {
|
|
@@ -56,7 +39,7 @@ func InvokeLLM(
|
|
|
|
|
|
switch chunk.Type {
|
|
|
case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
|
|
|
- chunk, err := parser.UnmarshalJsonBytes[plugin_entities.InvokeModelResponseChunk](chunk.Data)
|
|
|
+ chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
|
|
|
if err != nil {
|
|
|
log.Error("unmarshal json failed: %s", err.Error())
|
|
|
return
|
|
@@ -66,8 +49,15 @@ func InvokeLLM(
|
|
|
invokeDify(runtime, session, chunk.Data)
|
|
|
case plugin_entities.SESSION_MESSAGE_TYPE_END:
|
|
|
response.Close()
|
|
|
+ case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
|
|
|
+ e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
|
|
|
+ if err != nil {
|
|
|
+ break
|
|
|
+ }
|
|
|
+ response.WriteError(errors.New(e.Error))
|
|
|
+ response.Close()
|
|
|
default:
|
|
|
- log.Error("unknown stream message type: %s", chunk.Type)
|
|
|
+ response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
|
|
|
response.Close()
|
|
|
}
|
|
|
})
|
|
@@ -79,10 +69,117 @@ func InvokeLLM(
|
|
|
runtime.Write(session.ID(), []byte(parser.MarshalJson(
|
|
|
getInvokeModelMap(
|
|
|
session,
|
|
|
- PLUGIN_ACCESS_ACTION_INVOKE_LLM,
|
|
|
+ typ,
|
|
|
+ action,
|
|
|
request,
|
|
|
),
|
|
|
)))
|
|
|
|
|
|
return response, nil
|
|
|
}
|
|
|
+
|
|
|
+func getInvokeModelMap(
|
|
|
+ session *session_manager.Session,
|
|
|
+ typ PluginAccessType,
|
|
|
+ action PluginAccessAction,
|
|
|
+ request any,
|
|
|
+) map[string]any {
|
|
|
+ req := getBasicPluginAccessMap(session.ID(), session.UserID(), typ, action)
|
|
|
+ data := req["data"].(map[string]any)
|
|
|
+
|
|
|
+ for k, v := range parser.StructToMap(request) {
|
|
|
+ data[k] = v
|
|
|
+ }
|
|
|
+
|
|
|
+ return req
|
|
|
+}
|
|
|
+
|
|
|
+func InvokeLLM(
|
|
|
+ session *session_manager.Session,
|
|
|
+ request *requests.RequestInvokeLLM,
|
|
|
+) (
|
|
|
+ *stream.StreamResponse[model_entities.LLMResultChunk], error,
|
|
|
+) {
|
|
|
+ return genericInvokePlugin[requests.RequestInvokeLLM, model_entities.LLMResultChunk](
|
|
|
+ session,
|
|
|
+ request,
|
|
|
+ 512,
|
|
|
+ PLUGIN_ACCESS_TYPE_MODEL,
|
|
|
+ PLUGIN_ACCESS_ACTION_INVOKE_LLM,
|
|
|
+ )
|
|
|
+}
|
|
|
+
|
|
|
+func InvokeTextEmbedding(
|
|
|
+ session *session_manager.Session,
|
|
|
+ request *requests.RequestInvokeTextEmbedding,
|
|
|
+) (
|
|
|
+ *stream.StreamResponse[model_entities.TextEmbeddingResult], error,
|
|
|
+) {
|
|
|
+ return genericInvokePlugin[requests.RequestInvokeTextEmbedding, model_entities.TextEmbeddingResult](
|
|
|
+ session,
|
|
|
+ request,
|
|
|
+ 1,
|
|
|
+ PLUGIN_ACCESS_TYPE_MODEL,
|
|
|
+ PLUGIN_ACCESS_ACTION_INVOKE_TEXT_EMBEDDING,
|
|
|
+ )
|
|
|
+}
|
|
|
+
|
|
|
+func InvokeRerank(
|
|
|
+ session *session_manager.Session,
|
|
|
+ request *requests.RequestInvokeRerank,
|
|
|
+) (
|
|
|
+ *stream.StreamResponse[model_entities.RerankResult], error,
|
|
|
+) {
|
|
|
+ return genericInvokePlugin[requests.RequestInvokeRerank, model_entities.RerankResult](
|
|
|
+ session,
|
|
|
+ request,
|
|
|
+ 1,
|
|
|
+ PLUGIN_ACCESS_TYPE_MODEL,
|
|
|
+ PLUGIN_ACCESS_ACTION_INVOKE_RERANK,
|
|
|
+ )
|
|
|
+}
|
|
|
+
|
|
|
+func InvokeTTS(
|
|
|
+ session *session_manager.Session,
|
|
|
+ request *requests.RequestInvokeTTS,
|
|
|
+) (
|
|
|
+ *stream.StreamResponse[string], error,
|
|
|
+) {
|
|
|
+ return genericInvokePlugin[requests.RequestInvokeTTS, string](
|
|
|
+ session,
|
|
|
+ request,
|
|
|
+ 1,
|
|
|
+ PLUGIN_ACCESS_TYPE_MODEL,
|
|
|
+ PLUGIN_ACCESS_ACTION_INVOKE_TTS,
|
|
|
+ )
|
|
|
+}
|
|
|
+
|
|
|
+func InvokeSpeech2Text(
|
|
|
+ session *session_manager.Session,
|
|
|
+ request *requests.RequestInvokeSpeech2Text,
|
|
|
+) (
|
|
|
+ *stream.StreamResponse[string], error,
|
|
|
+) {
|
|
|
+ return genericInvokePlugin[requests.RequestInvokeSpeech2Text, string](
|
|
|
+ session,
|
|
|
+ request,
|
|
|
+ 1,
|
|
|
+ PLUGIN_ACCESS_TYPE_MODEL,
|
|
|
+ PLUGIN_ACCESS_ACTION_INVOKE_SPEECH2TEXT,
|
|
|
+ )
|
|
|
+}
|
|
|
+
|
|
|
+func InvokeModeration(
|
|
|
+ session *session_manager.Session,
|
|
|
+ request *requests.RequestInvokeModeration,
|
|
|
+) (
|
|
|
+ *stream.StreamResponse[bool], error,
|
|
|
+) {
|
|
|
+ return genericInvokePlugin[requests.RequestInvokeModeration, bool](
|
|
|
+ session,
|
|
|
+ request,
|
|
|
+ 1,
|
|
|
+ PLUGIN_ACCESS_TYPE_MODEL,
|
|
|
+ PLUGIN_ACCESS_ACTION_INVOKE_MODERATION,
|
|
|
+ )
|
|
|
+}
|