浏览代码

feat: migrate to mapstruct

Yeuoly 1 年之前
父节点
当前提交
f62f3af5d1

+ 0 - 69
internal/core/dify_invocation/types.go

@@ -2,7 +2,6 @@ package dify_invocation
 
 import (
 	"encoding/json"
-	"fmt"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
 )
@@ -13,23 +12,6 @@ type BaseInvokeDifyRequest struct {
 	Type     InvokeType `json:"type"`
 }
 
-func (r *BaseInvokeDifyRequest) FromMap(data map[string]any) error {
-	var ok bool
-	if r.TenantId, ok = data["tenant_id"].(string); !ok {
-		return fmt.Errorf("tenant_id is not a string")
-	}
-
-	if r.UserId, ok = data["user_id"].(string); !ok {
-		return fmt.Errorf("user_id is not a string")
-	}
-
-	if r.Type, ok = data["type"].(InvokeType); !ok {
-		return fmt.Errorf("type is not a string")
-	}
-
-	return nil
-}
-
 type InvokeType string
 
 const (
@@ -46,27 +28,6 @@ type InvokeModelRequest struct {
 	Parameters map[string]any           `json:"parameters"`
 }
 
-func (r *InvokeModelRequest) FromMap(base map[string]any, data map[string]any) error {
-	var ok bool
-	if r.Provider, ok = data["provider"].(string); !ok {
-		return fmt.Errorf("provider is not a string")
-	}
-
-	if r.Model, ok = data["model"].(string); !ok {
-		return fmt.Errorf("model is not a string")
-	}
-
-	if r.ModelType, ok = data["model_type"].(model_entities.ModelType); !ok {
-		return fmt.Errorf("model_type is not a string")
-	}
-
-	if r.Parameters, ok = data["parameters"].(map[string]any); !ok {
-		return fmt.Errorf("parameters is not a map")
-	}
-
-	return nil
-}
-
 func (r InvokeModelRequest) MarshalJSON() ([]byte, error) {
 	flattened := make(map[string]any)
 	flattened["tenant_id"] = r.TenantId
@@ -87,23 +48,6 @@ type InvokeToolRequest struct {
 	Parameters map[string]any `json:"parameters"`
 }
 
-func (r *InvokeToolRequest) FromMap(base map[string]any, data map[string]any) error {
-	var ok bool
-	if r.Provider, ok = data["provider"].(string); !ok {
-		return fmt.Errorf("provider is not a string")
-	}
-
-	if r.Tool, ok = data["tool"].(string); !ok {
-		return fmt.Errorf("tool is not a string")
-	}
-
-	if r.Parameters, ok = data["parameters"].(map[string]any); !ok {
-		return fmt.Errorf("parameters is not a map")
-	}
-
-	return nil
-}
-
 func (r InvokeToolRequest) MarshalJSON() ([]byte, error) {
 	flattened := make(map[string]any)
 	flattened["tenant_id"] = r.TenantId
@@ -123,19 +67,6 @@ type InvokeNodeRequest[T WorkflowNodeData] struct {
 	NodeData T        `json:"node_data"`
 }
 
-func (r *InvokeNodeRequest[T]) FromMap(data map[string]any) error {
-	var ok bool
-	if r.NodeType, ok = data["node_type"].(NodeType); !ok {
-		return fmt.Errorf("node_type is not a string")
-	}
-
-	if err := r.NodeData.FromMap(data["node_data"].(map[string]any)); err != nil {
-		return err
-	}
-
-	return nil
-}
-
 func (r InvokeNodeRequest[T]) MarshalJSON() ([]byte, error) {
 	flattened := make(map[string]any)
 	flattened["tenant_id"] = r.TenantId

+ 1 - 15
internal/core/dify_invocation/workflow_node_data.go

@@ -1,9 +1,7 @@
 package dify_invocation
 
 type WorkflowNodeData interface {
-	FromMap(map[string]any) error
-
-	*KnowledgeRetrievalNodeData | *QuestionClassifierNodeData | *ParameterExtractorNodeData
+	KnowledgeRetrievalNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData
 }
 
 type NodeType string
@@ -18,20 +16,8 @@ const (
 type KnowledgeRetrievalNodeData struct {
 }
 
-func (r *KnowledgeRetrievalNodeData) FromMap(data map[string]any) error {
-	return nil
-}
-
 type QuestionClassifierNodeData struct {
 }
 
-func (r *QuestionClassifierNodeData) FromMap(data map[string]any) error {
-	return nil
-}
-
 type ParameterExtractorNodeData struct {
 }
-
-func (r *ParameterExtractorNodeData) FromMap(data map[string]any) error {
-	return nil
-}

+ 43 - 0
internal/core/plugin_daemon/backwards_invocation/entities.go

@@ -0,0 +1,43 @@
+package backwards_invocation
+
+type RequestEvent string
+
+const (
+	REQUEST_EVENT_RESPONSE RequestEvent = "response"
+	REQUEST_EVENT_ERROR    RequestEvent = "error"
+	REQUEST_EVENT_END      RequestEvent = "end"
+)
+
+type BaseRequestEvent struct {
+	BackwardsRequestId string         `json:"backwards_request_id"`
+	Event              RequestEvent   `json:"event"`
+	Message            string         `json:"message"`
+	Data               map[string]any `json:"data"`
+}
+
+func NewResponseEvent(request_id string, message string, data map[string]any) *BaseRequestEvent {
+	return &BaseRequestEvent{
+		BackwardsRequestId: request_id,
+		Event:              REQUEST_EVENT_RESPONSE,
+		Message:            message,
+		Data:               data,
+	}
+}
+
+func NewErrorEvent(request_id string, message string) *BaseRequestEvent {
+	return &BaseRequestEvent{
+		BackwardsRequestId: request_id,
+		Event:              REQUEST_EVENT_ERROR,
+		Message:            message,
+		Data:               nil,
+	}
+}
+
+func NewEndEvent(request_id string) *BaseRequestEvent {
+	return &BaseRequestEvent{
+		BackwardsRequestId: request_id,
+		Event:              REQUEST_EVENT_END,
+		Message:            "",
+		Data:               nil,
+	}
+}

+ 52 - 0
internal/core/plugin_daemon/backwards_invocation/request.go

@@ -0,0 +1,52 @@
+package backwards_invocation
+
+import (
+	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+)
+
+type BackwardsInvocationType = dify_invocation.InvokeType
+
+type BackwardsInvocation struct {
+	typ              BackwardsInvocationType
+	id               string
+	detailed_request map[string]any
+	session          *session_manager.Session
+}
+
+func NewBackwardsInvocation(
+	typ BackwardsInvocationType,
+	id string, session *session_manager.Session, detailed_request map[string]any,
+) *BackwardsInvocation {
+	return &BackwardsInvocation{
+		typ:              typ,
+		id:               id,
+		detailed_request: detailed_request,
+		session:          session,
+	}
+}
+
+func (bi *BackwardsInvocation) GetID() string {
+	return bi.id
+}
+
+func (bi *BackwardsInvocation) WriteError(err error) {
+	bi.session.Write(parser.MarshalJsonBytes(NewErrorEvent(bi.id, err.Error())))
+}
+
+func (bi *BackwardsInvocation) Write(message string, data map[string]any) {
+	bi.session.Write(parser.MarshalJsonBytes(NewResponseEvent(bi.id, message, data)))
+}
+
+func (bi *BackwardsInvocation) End() {
+	bi.session.Write(parser.MarshalJsonBytes(NewEndEvent(bi.id)))
+}
+
+func (bi *BackwardsInvocation) Type() BackwardsInvocationType {
+	return bi.typ
+}
+
+func (bi *BackwardsInvocation) RequestData() map[string]any {
+	return bi.detailed_request
+}

+ 54 - 27
internal/core/plugin_daemon/invoke_dify.go

@@ -4,6 +4,7 @@ import (
 	"fmt"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/core/dify_invocation"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/log"
@@ -13,6 +14,7 @@ import (
 
 func invokeDify(
 	runtime entities.PluginRuntimeInterface,
+	invoke_from PluginAccessType,
 	session *session_manager.Session, data []byte,
 ) error {
 	// unmarshal invoke data
@@ -22,37 +24,67 @@ func invokeDify(
 		return fmt.Errorf("unmarshal invoke request failed: %s", err.Error())
 	}
 
+	// prepare invocation arguments
+	request_handle, err := prepareDifyInvocationArguments(session, request)
+	if err != nil {
+		return err
+	}
+	defer request_handle.End()
+
+	if invoke_from == PLUGIN_ACCESS_TYPE_MODEL {
+		request_handle.WriteError(fmt.Errorf("you can not invoke dify from %s", invoke_from))
+		return nil
+	}
+
+	// dispatch invocation task
+	dispatchDifyInvocationTask(request_handle)
+
+	return nil
+}
+
+func prepareDifyInvocationArguments(session *session_manager.Session, request map[string]any) (*backwards_invocation.BackwardsInvocation, error) {
 	typ, ok := request["type"].(string)
 	if !ok {
-		return fmt.Errorf("invoke request missing type: %s", data)
+		return nil, fmt.Errorf("invoke request missing type: %s", request)
 	}
 
 	// get request id
-	request_id, ok := request["request_id"].(string)
+	backwards_request_id, ok := request["backwards_request_id"].(string)
 	if !ok {
-		return fmt.Errorf("invoke request missing request_id: %s", data)
+		return nil, fmt.Errorf("invoke request missing request_id: %s", request)
 	}
 
 	// get request
 	detailed_request, ok := request["request"].(map[string]any)
 	if !ok {
-		return fmt.Errorf("invoke request missing request: %s", data)
+		return nil, fmt.Errorf("invoke request missing request: %s", request)
 	}
 
-	switch typ {
-	case "tool":
-		r := dify_invocation.InvokeToolRequest{}
-		if err := r.FromMap(request, detailed_request); err != nil {
-			return fmt.Errorf("unmarshal tool invoke request failed: %s", err.Error())
+	return backwards_invocation.NewBackwardsInvocation(
+		backwards_invocation.BackwardsInvocationType(typ),
+		backwards_request_id, session, detailed_request,
+	), nil
+}
+
+func dispatchDifyInvocationTask(handle *backwards_invocation.BackwardsInvocation) {
+	switch handle.Type() {
+	case dify_invocation.INVOKE_TYPE_TOOL:
+		r, err := parser.MapToStruct[dify_invocation.InvokeToolRequest](handle.RequestData())
+		if err != nil {
+			handle.WriteError(fmt.Errorf("unmarshal invoke tool request failed: %s", err.Error()))
+			return
 		}
-		submitToolTask(runtime, session, request_id, &r)
-	case "model":
-		r := dify_invocation.InvokeModelRequest{}
-		if err := r.FromMap(request, detailed_request); err != nil {
-			return fmt.Errorf("unmarshal model invoke request failed: %s", err.Error())
+
+		submitToolTask(runtime, session, backwards_request_id, &r)
+	case dify_invocation.INVOKE_TYPE_MODEL:
+		r, err := parser.MapToStruct[dify_invocation.InvokeModelRequest](handle.RequestData())
+		if err != nil {
+			handle.WriteError(fmt.Errorf("unmarshal invoke model request failed: %s", err.Error()))
+			return
 		}
-		submitModelTask(runtime, session, request_id, &r)
-	case "node":
+
+		submitModelTask(runtime, session, backwards_request_id, &r)
+	case dify_invocation.INVOKE_TYPE_NODE:
 		node_type, ok := detailed_request["node_type"].(dify_invocation.NodeType)
 		if !ok {
 			return fmt.Errorf("invoke request missing node_type: %s", data)
@@ -63,40 +95,35 @@ func invokeDify(
 		}
 		switch node_type {
 		case dify_invocation.QUESTION_CLASSIFIER:
-			d := dify_invocation.InvokeNodeRequest[*dify_invocation.QuestionClassifierNodeData]{
+			d := dify_invocation.InvokeNodeRequest[dify_invocation.QuestionClassifierNodeData]{
 				NodeType: dify_invocation.QUESTION_CLASSIFIER,
-				NodeData: &dify_invocation.QuestionClassifierNodeData{},
 			}
 			if err := d.FromMap(node_data); err != nil {
 				return fmt.Errorf("unmarshal question classifier node data failed: %s", err.Error())
 			}
-			submitNodeInvocationRequestTask(runtime, session, request_id, &d)
+			submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d)
 		case dify_invocation.KNOWLEDGE_RETRIEVAL:
-			d := dify_invocation.InvokeNodeRequest[*dify_invocation.KnowledgeRetrievalNodeData]{
+			d := dify_invocation.InvokeNodeRequest[dify_invocation.KnowledgeRetrievalNodeData]{
 				NodeType: dify_invocation.KNOWLEDGE_RETRIEVAL,
-				NodeData: &dify_invocation.KnowledgeRetrievalNodeData{},
 			}
 			if err := d.FromMap(node_data); err != nil {
 				return fmt.Errorf("unmarshal knowledge retrieval node data failed: %s", err.Error())
 			}
-			submitNodeInvocationRequestTask(runtime, session, request_id, &d)
+			submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d)
 		case dify_invocation.PARAMETER_EXTRACTOR:
-			d := dify_invocation.InvokeNodeRequest[*dify_invocation.ParameterExtractorNodeData]{
+			d := dify_invocation.InvokeNodeRequest[dify_invocation.ParameterExtractorNodeData]{
 				NodeType: dify_invocation.PARAMETER_EXTRACTOR,
-				NodeData: &dify_invocation.ParameterExtractorNodeData{},
 			}
 			if err := d.FromMap(node_data); err != nil {
 				return fmt.Errorf("unmarshal parameter extractor node data failed: %s", err.Error())
 			}
-			submitNodeInvocationRequestTask(runtime, session, request_id, &d)
+			submitNodeInvocationRequestTask(runtime, session, backwards_request_id, &d)
 		default:
 			return fmt.Errorf("unknown node type: %s", node_type)
 		}
 	default:
 		return fmt.Errorf("unknown invoke type: %s", typ)
 	}
-
-	return nil
 }
 
 func setTaskContext(session *session_manager.Session, r *dify_invocation.BaseInvokeDifyRequest) {

+ 1 - 1
internal/core/plugin_daemon/model_service.go

@@ -46,7 +46,7 @@ func genericInvokePlugin[Req any, Rsp any](
 			}
 			response.Write(chunk)
 		case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
-			invokeDify(runtime, session, chunk.Data)
+			invokeDify(runtime, typ, session, chunk.Data)
 		case plugin_entities.SESSION_MESSAGE_TYPE_END:
 			response.Close()
 		case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:

+ 16 - 1
internal/core/session_manager/session.go

@@ -1,9 +1,11 @@
 package session_manager
 
 import (
+	"errors"
 	"sync"
 
 	"github.com/google/uuid"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
 )
 
 var (
@@ -12,7 +14,8 @@ var (
 )
 
 type Session struct {
-	id string
+	id      string
+	runtime entities.PluginRuntimeSessionIOInterface
 
 	tenant_id       string
 	user_id         string
@@ -71,3 +74,15 @@ func (s *Session) UserID() string {
 func (s *Session) PluginIdentity() string {
 	return s.plugin_identity
 }
+
+func (s *Session) BindRuntime(runtime entities.PluginRuntimeSessionIOInterface) {
+	s.runtime = runtime
+}
+
+func (s *Session) Write(data []byte) error {
+	if s.runtime == nil {
+		return errors.New("runtime not bound")
+	}
+	s.runtime.Write(s.id, data)
+	return nil
+}

+ 2 - 4
internal/service/invoke_model.go

@@ -3,17 +3,15 @@ package service
 import (
 	"github.com/gin-gonic/gin"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
-	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
-	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
 func InvokeTool(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTool], ctx *gin.Context) {
 	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	session := createSession(r)
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[tool_entities.ToolResponseChunk], error) {
@@ -23,7 +21,7 @@ func InvokeTool(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeToo
 
 func ValidateToolCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateToolCredentials], ctx *gin.Context) {
 	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	session := createSession(r)
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[tool_entities.ValidateCredentialsResult], error) {

+ 16 - 8
internal/service/invoke_tool.go

@@ -3,6 +3,7 @@ package service
 import (
 	"github.com/gin-gonic/gin"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/model_entities"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
@@ -11,9 +12,16 @@ import (
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
 )
 
+func createSession[T any](r *plugin_entities.InvokePluginRequest[T]) *session_manager.Session {
+	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	runtime := plugin_manager.Get(session.PluginIdentity())
+	session.BindRuntime(runtime)
+	return session
+}
+
 func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM], ctx *gin.Context) {
 	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	session := createSession(r)
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.LLMResultChunk], error) {
@@ -23,7 +31,7 @@ func InvokeLLM(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeLLM]
 
 func InvokeTextEmbedding(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTextEmbedding], ctx *gin.Context) {
 	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	session := createSession(r)
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.TextEmbeddingResult], error) {
@@ -33,7 +41,7 @@ func InvokeTextEmbedding(r *plugin_entities.InvokePluginRequest[requests.Request
 
 func InvokeRerank(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeRerank], ctx *gin.Context) {
 	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	session := createSession(r)
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.RerankResult], error) {
@@ -43,7 +51,7 @@ func InvokeRerank(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeR
 
 func InvokeTTS(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS], ctx *gin.Context) {
 	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	session := createSession(r)
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.TTSResult], error) {
@@ -53,7 +61,7 @@ func InvokeTTS(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeTTS]
 
 func InvokeSpeech2Text(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeSpeech2Text], ctx *gin.Context) {
 	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	session := createSession(r)
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.Speech2TextResult], error) {
@@ -63,7 +71,7 @@ func InvokeSpeech2Text(r *plugin_entities.InvokePluginRequest[requests.RequestIn
 
 func InvokeModeration(r *plugin_entities.InvokePluginRequest[requests.RequestInvokeModeration], ctx *gin.Context) {
 	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	session := createSession(r)
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.ModerationResult], error) {
@@ -73,7 +81,7 @@ func InvokeModeration(r *plugin_entities.InvokePluginRequest[requests.RequestInv
 
 func ValidateProviderCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateProviderCredentials], ctx *gin.Context) {
 	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	session := createSession(r)
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {
@@ -83,7 +91,7 @@ func ValidateProviderCredentials(r *plugin_entities.InvokePluginRequest[requests
 
 func ValidateModelCredentials(r *plugin_entities.InvokePluginRequest[requests.RequestValidateModelCredentials], ctx *gin.Context) {
 	// create session
-	session := session_manager.NewSession(r.TenantId, r.UserId, parser.MarshalPluginIdentity(r.PluginName, r.PluginVersion))
+	session := createSession(r)
 	defer session.Close()
 
 	baseSSEService(r, func() (*stream.StreamResponse[model_entities.ValidateCredentialsResult], error) {

+ 1 - 0
internal/service/session.go

@@ -0,0 +1 @@
+package service

+ 29 - 0
internal/utils/parser/map2struct.go

@@ -0,0 +1,29 @@
+package parser
+
+import (
+	"fmt"
+
+	"github.com/mitchellh/mapstructure"
+)
+
+func MapToStruct[T any](m map[string]any) (*T, error) {
+	var s T
+	decoder := &mapstructure.DecoderConfig{
+		Metadata: nil,
+		Result:   &s,
+		TagName:  "json",
+		Squash:   true,
+	}
+
+	d, err := mapstructure.NewDecoder(decoder)
+	if err != nil {
+		return nil, fmt.Errorf("error creating decoder: %s", err)
+	}
+
+	err = d.Decode(m)
+	if err != nil {
+		return nil, fmt.Errorf("error decoding map: %s", err)
+	}
+
+	return &s, nil
+}

+ 48 - 0
internal/utils/parser/map2struct_test.go

@@ -0,0 +1,48 @@
+package parser
+
+import "testing"
+
+func TestMapToStruct(t *testing.T) {
+	m := map[string]any{
+		"result": "result",
+		"inherit": map[string]any{
+			"inherit_result": "result",
+		},
+		"object": map[string]any{
+			"a": 1,
+		},
+	}
+
+	type p struct {
+		Inherit struct {
+			InheritResult string `json:"inherit_result"`
+		}
+	}
+
+	type s struct {
+		p
+
+		Result string `json:"result"`
+		Object struct {
+			A int `json:"a"`
+		} `json:"object"`
+	}
+
+	result, err := MapToStruct[s](m)
+	if err != nil {
+		t.Error(err)
+	}
+
+	if result.Result != "result" {
+		t.Error("result should be result")
+	}
+
+	if result.Inherit.InheritResult != "result" {
+		t.Error("inherit_result should be result")
+	}
+
+	if result.Object.A != 1 {
+		t.Error("a should be 1")
+	}
+
+}

+ 14 - 34
internal/utils/parser/struct2map.go

@@ -1,48 +1,28 @@
 package parser
 
 import (
-	"reflect"
-	"unicode"
+	"github.com/mitchellh/mapstructure"
 )
 
 func StructToMap(data interface{}) map[string]interface{} {
 	result := make(map[string]interface{})
-	val := reflect.ValueOf(data)
-	if val.Kind() == reflect.Ptr {
-		val = val.Elem()
-	}
-	for i := 0; i < val.NumField(); i++ {
-		field := val.Field(i)
-		typeField := val.Type().Field(i)
-		fieldName := toSnakeCase(typeField.Name)
 
-		if typeField.Anonymous {
-			embeddedFields := StructToMap(field.Interface())
-			for k, v := range embeddedFields {
-				result[k] = v
-			}
-		} else {
-			result[fieldName] = field.Interface()
-		}
+	decoder := &mapstructure.DecoderConfig{
+		Metadata: nil,
+		Result:   &result,
+		TagName:  "json",
+		Squash:   true,
 	}
-	return result
-}
 
-func toSnakeCase(str string) string {
-	runes := []rune(str)
-	length := len(runes)
-	var out []rune
+	d, err := mapstructure.NewDecoder(decoder)
+	if err != nil {
+		return nil
+	}
 
-	for i := 0; i < length; i++ {
-		if unicode.IsUpper(runes[i]) {
-			if i > 0 {
-				out = append(out, '_')
-			}
-			out = append(out, unicode.ToLower(runes[i]))
-		} else {
-			out = append(out, runes[i])
-		}
+	err = d.Decode(data)
+	if err != nil {
+		return nil
 	}
 
-	return string(out)
+	return result
 }

+ 32 - 0
internal/utils/parser/struct2map_test.go

@@ -0,0 +1,32 @@
+package parser
+
+import "testing"
+
+func TestStruct2Map(t *testing.T) {
+	type Base struct {
+		A int `json:"a"`
+	}
+
+	type p struct {
+		Base
+
+		B int `json:"b"`
+	}
+
+	d := p{
+		Base: Base{
+			A: 1,
+		},
+		B: 2,
+	}
+
+	result := StructToMap(d)
+
+	if result["a"] != 1 {
+		t.Error("a should be 1")
+	}
+
+	if result["b"] != 2 {
+		t.Error("b should be 2")
+	}
+}