Quellcode durchsuchen

feat: agent pluign

Yeuoly vor 7 Monaten
Ursprung
Commit
7bb4705a48

+ 6 - 2
internal/core/plugin_daemon/access_types/access.go

@@ -6,12 +6,14 @@ const (
 	PLUGIN_ACCESS_TYPE_TOOL     PluginAccessType = "tool"
 	PLUGIN_ACCESS_TYPE_MODEL    PluginAccessType = "model"
 	PLUGIN_ACCESS_TYPE_ENDPOINT PluginAccessType = "endpoint"
+	PLUGIN_ACCESS_TYPE_AGENT    PluginAccessType = "agent"
 )
 
 func (p PluginAccessType) IsValid() bool {
 	return p == PLUGIN_ACCESS_TYPE_TOOL ||
 		p == PLUGIN_ACCESS_TYPE_MODEL ||
-		p == PLUGIN_ACCESS_TYPE_ENDPOINT
+		p == PLUGIN_ACCESS_TYPE_ENDPOINT ||
+		p == PLUGIN_ACCESS_TYPE_AGENT
 }
 
 type PluginAccessAction string
@@ -33,6 +35,7 @@ const (
 	PLUGIN_ACCESS_ACTION_GET_TEXT_EMBEDDING_NUM_TOKENS PluginAccessAction = "get_text_embedding_num_tokens"
 	PLUGIN_ACCESS_ACTION_GET_AI_MODEL_SCHEMAS          PluginAccessAction = "get_ai_model_schemas"
 	PLUGIN_ACCESS_ACTION_GET_LLM_NUM_TOKENS            PluginAccessAction = "get_llm_num_tokens"
+	PLUGIN_ACCESS_ACTION_INVOKE_AGENT                  PluginAccessAction = "invoke_agent"
 )
 
 func (p PluginAccessAction) IsValid() bool {
@@ -51,5 +54,6 @@ func (p PluginAccessAction) IsValid() bool {
 		p == PLUGIN_ACCESS_ACTION_GET_TTS_MODEL_VOICES ||
 		p == PLUGIN_ACCESS_ACTION_GET_TEXT_EMBEDDING_NUM_TOKENS ||
 		p == PLUGIN_ACCESS_ACTION_GET_AI_MODEL_SCHEMAS ||
-		p == PLUGIN_ACCESS_ACTION_GET_LLM_NUM_TOKENS
+		p == PLUGIN_ACCESS_ACTION_GET_LLM_NUM_TOKENS ||
+		p == PLUGIN_ACCESS_ACTION_INVOKE_AGENT
 }

+ 205 - 0
internal/core/plugin_daemon/agent_service.go

@@ -0,0 +1,205 @@
+package plugin_daemon
+
+import (
+	"bytes"
+	"encoding/base64"
+	"errors"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/agent_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
+	"github.com/xeipuuv/gojsonschema"
+)
+
+func InvokeAgent(
+	session *session_manager.Session,
+	r *requests.RequestInvokeAgent,
+) (*stream.Stream[agent_entities.AgentResponseChunk], error) {
+	runtime := session.Runtime()
+	if runtime == nil {
+		return nil, errors.New("plugin not found")
+	}
+
+	response, err := GenericInvokePlugin[
+		requests.RequestInvokeAgent, agent_entities.AgentResponseChunk,
+	](
+		session,
+		r,
+		128,
+	)
+
+	if err != nil {
+		return nil, err
+	}
+
+	agentDeclaration := runtime.Configuration().Agent
+	if agentDeclaration == nil {
+		return nil, errors.New("agent declaration not found")
+	}
+
+	var agentOutputSchema plugin_entities.AgentOutputSchema
+	for _, v := range agentDeclaration.Strategies {
+		if v.Identity.Name == r.Strategy {
+			agentOutputSchema = v.OutputSchema
+		}
+	}
+
+	newResponse := stream.NewStream[agent_entities.AgentResponseChunk](128)
+	routine.Submit(map[string]string{
+		"module":         "plugin_daemon",
+		"function":       "InvokeAgent",
+		"agent_name":     r.Strategy,
+		"agent_provider": r.Provider,
+	}, func() {
+		files := make(map[string]*bytes.Buffer)
+		defer newResponse.Close()
+
+		for response.Next() {
+			item, err := response.Read()
+			if err != nil {
+				newResponse.WriteError(err)
+				return
+			}
+
+			if item.Type == tool_entities.ToolResponseChunkTypeBlobChunk {
+				id, ok := item.Message["id"].(string)
+				if !ok {
+					continue
+				}
+
+				totalLength, ok := item.Message["total_length"].(float64)
+				if !ok {
+					continue
+				}
+
+				// convert total_length to int
+				totalLengthInt := int(totalLength)
+
+				blob, ok := item.Message["blob"].(string)
+				if !ok {
+					continue
+				}
+
+				end, ok := item.Message["end"].(bool)
+				if !ok {
+					continue
+				}
+
+				if _, ok := files[id]; !ok {
+					files[id] = bytes.NewBuffer(make([]byte, 0, totalLengthInt))
+				}
+
+				if end {
+					newResponse.Write(agent_entities.AgentResponseChunk{
+						ToolResponseChunk: tool_entities.ToolResponseChunk{
+							Type: tool_entities.ToolResponseChunkTypeBlob,
+							Message: map[string]any{
+								"blob": files[id].Bytes(), // bytes will be encoded to base64 finally
+							},
+							Meta: item.Meta,
+						},
+					})
+				} else {
+					if files[id].Len() > 15*1024*1024 {
+						// delete the file if it is too large
+						delete(files, id)
+						newResponse.WriteError(errors.New("file is too large"))
+						return
+					} else {
+						// decode the blob using base64
+						decoded, err := base64.StdEncoding.DecodeString(blob)
+						if err != nil {
+							newResponse.WriteError(err)
+							return
+						}
+						if len(decoded) > 8192 {
+							// single chunk is too large, raises error
+							newResponse.WriteError(errors.New("single file chunk is too large"))
+							return
+						}
+						files[id].Write(decoded)
+					}
+				}
+			} else {
+				newResponse.Write(item)
+			}
+		}
+	})
+
+	// bind json schema validator
+	bindAgentValidator(response, agentOutputSchema)
+
+	return newResponse, nil
+}
+
+// TODO: reduce implementation of bindAgentValidator, it's a copy of bindToolValidator now
+func bindAgentValidator(
+	response *stream.Stream[agent_entities.AgentResponseChunk],
+	agentOutputSchema plugin_entities.AgentOutputSchema,
+) {
+	// check if the tool_output_schema is valid
+	variables := make(map[string]any)
+
+	response.Filter(func(trc agent_entities.AgentResponseChunk) error {
+		if trc.Type == tool_entities.ToolResponseChunkTypeVariable {
+			variableName, ok := trc.Message["variable_name"].(string)
+			if !ok {
+				return errors.New("variable name is not a string")
+			}
+			stream, ok := trc.Message["stream"].(bool)
+			if !ok {
+				return errors.New("stream is not a boolean")
+			}
+
+			if stream {
+				// ensure variable_value is a string
+				variableValue, ok := trc.Message["variable_value"].(string)
+				if !ok {
+					return errors.New("variable value is not a string")
+				}
+
+				// create it if not exists
+				if _, ok := variables[variableName]; !ok {
+					variables[variableName] = ""
+				}
+
+				originalValue, ok := variables[variableName].(string)
+				if !ok {
+					return errors.New("variable value is not a string")
+				}
+
+				// add the variable value to the variable
+				variables[variableName] = originalValue + variableValue
+			} else {
+				variables[variableName] = trc.Message["variable_value"]
+			}
+		}
+
+		return nil
+	})
+
+	response.BeforeClose(func() {
+		// validate the variables
+		schema, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(agentOutputSchema))
+		if err != nil {
+			response.WriteError(err)
+			return
+		}
+
+		// validate the variables
+		result, err := schema.Validate(gojsonschema.NewGoLoader(variables))
+		if err != nil {
+			response.WriteError(err)
+			return
+		}
+
+		if !result.Valid() {
+			response.WriteError(errors.New("tool output schema is not valid"))
+			return
+		}
+	})
+}

+ 2 - 2
internal/core/plugin_daemon/tool_service.go

@@ -130,12 +130,12 @@ func InvokeTool(
 	})
 
 	// bind json schema validator
-	bindValidator(response, toolOutputSchema)
+	bindToolValidator(response, toolOutputSchema)
 
 	return newResponse, nil
 }
 
-func bindValidator(
+func bindToolValidator(
 	response *stream.Stream[tool_entities.ToolResponseChunk],
 	toolOutputSchema plugin_entities.ToolOutputSchema,
 ) {

+ 2 - 2
internal/core/plugin_daemon/tool_service_test.go

@@ -10,7 +10,7 @@ import (
 func TestToolInvokeJSONSchemaValidator(t *testing.T) {
 	response := stream.NewStream[tool_entities.ToolResponseChunk](128)
 
-	bindValidator(response, map[string]any{
+	bindToolValidator(response, map[string]any{
 		"output_schema": map[string]any{
 			"type": "object",
 			"properties": map[string]any{
@@ -44,7 +44,7 @@ func TestToolInvokeJSONSchemaValidator(t *testing.T) {
 func TestToolInvokeJSONSchemaValidatorWithInvalidSchema(t *testing.T) {
 	response := stream.NewStream[tool_entities.ToolResponseChunk](128)
 
-	bindValidator(response, map[string]any{
+	bindToolValidator(response, map[string]any{
 		"output_schema": map[string]any{
 			"type": "object",
 			"properties": map[string]any{

+ 21 - 5
internal/core/plugin_manager/remote_manager/hooks.go

@@ -20,6 +20,7 @@ import (
 
 var (
 	// mode is only used for testing
+	// TODO: simplify this ugly code
 	_mode pluginRuntimeMode
 )
 
@@ -270,7 +271,8 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 		} else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_END {
 			if !runtime.modelsRegistrationTransferred &&
 				!runtime.endpointsRegistrationTransferred &&
-				!runtime.toolsRegistrationTransferred {
+				!runtime.toolsRegistrationTransferred &&
+				!runtime.agentsRegistrationTransferred {
 				closeConn([]byte("no registration transferred, cannot initialize\n"))
 				return
 			}
@@ -400,10 +402,24 @@ func (s *DifyServer) onMessage(runtime *RemotePluginRuntime, message []byte) {
 				declaration.Endpoint = &endpoints[0]
 				runtime.Config = declaration
 			}
-		} else {
-			// unknown event type
-			closeConn([]byte("unknown initialization event type\n"))
-			return
+		} else if registerPayload.Type == plugin_entities.REGISTER_EVENT_TYPE_AGENT_DECLARATION {
+			if runtime.agentsRegistrationTransferred {
+				return
+			}
+
+			agents, err := parser.UnmarshalJsonBytes2Slice[plugin_entities.AgentProviderDeclaration](registerPayload.Data)
+			if err != nil {
+				closeConn([]byte(fmt.Sprintf("agents register failed, invalid agents declaration: %v\n", err)))
+				return
+			}
+
+			runtime.agentsRegistrationTransferred = true
+
+			if len(agents) > 0 {
+				declaration := runtime.Config
+				declaration.Agent = &agents[0]
+				runtime.Config = declaration
+			}
 		}
 	} else {
 		// continue handle messages if handshake completed

+ 1 - 0
internal/core/plugin_manager/remote_manager/type.go

@@ -52,6 +52,7 @@ type RemotePluginRuntime struct {
 	toolsRegistrationTransferred     bool
 	modelsRegistrationTransferred    bool
 	endpointsRegistrationTransferred bool
+	agentsRegistrationTransferred    bool
 	assetsTransferred                bool
 
 	// tenant id

+ 22 - 0
internal/server/controllers/agent.go

@@ -0,0 +1,22 @@
+package controllers
+
+import (
+	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/service"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/app"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
+)
+
+func InvokeAgent(config *app.Config) gin.HandlerFunc {
+	type request = plugin_entities.InvokePluginRequest[requests.RequestInvokeAgent]
+
+	return func(c *gin.Context) {
+		BindPluginDispatchRequest(
+			c,
+			func(itr request) {
+				service.InvokeAgent(&itr, c, config.PluginMaxExecutionTimeout)
+			},
+		)
+	}
+}

+ 1 - 0
internal/server/http_server.go

@@ -60,6 +60,7 @@ func (app *App) pluginDispatchGroup(group *gin.RouterGroup, config *app.Config)
 	group.POST("/tool/invoke", controllers.InvokeTool(config))
 	group.POST("/tool/validate_credentials", controllers.ValidateToolCredentials(config))
 	group.POST("/tool/get_runtime_parameters", controllers.GetToolRuntimeParameters(config))
+	group.POST("/agent/invoke", controllers.InvokeAgent(config))
 	group.POST("/llm/invoke", controllers.InvokeLLM(config))
 	group.POST("/llm/num_tokens", controllers.GetLLMNumTokens(config))
 	group.POST("/text_embedding/invoke", controllers.InvokeTextEmbedding(config))

+ 42 - 0
internal/service/invoke_agent.go

@@ -0,0 +1,42 @@
+package service
+
+import (
+	"github.com/gin-gonic/gin"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/access_types"
+	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/agent_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
+	"github.com/langgenius/dify-plugin-daemon/internal/types/exception"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
+)
+
+func InvokeAgent(
+	r *plugin_entities.InvokePluginRequest[requests.RequestInvokeAgent],
+	ctx *gin.Context,
+	max_timeout_seconds int,
+) {
+	// create session
+	session, err := createSession(
+		r,
+		access_types.PLUGIN_ACCESS_TYPE_AGENT,
+		access_types.PLUGIN_ACCESS_ACTION_INVOKE_AGENT,
+		ctx.GetString("cluster_id"),
+	)
+	if err != nil {
+		ctx.JSON(500, exception.InternalServerError(err).ToResponse())
+		return
+	}
+	defer session.Close(session_manager.CloseSessionPayload{
+		IgnoreCache: false,
+	})
+
+	baseSSEService(
+		func() (*stream.Stream[agent_entities.AgentResponseChunk], error) {
+			return plugin_daemon.InvokeAgent(session, &r.Data)
+		},
+		ctx,
+		max_timeout_seconds,
+	)
+}

+ 7 - 0
internal/types/entities/agent_entities/agent.go

@@ -0,0 +1,7 @@
+package agent_entities
+
+import "github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
+
+type AgentResponseChunk struct {
+	tool_entities.ToolResponseChunk `json:",inline"`
+}

+ 127 - 0
internal/types/entities/plugin_entities/agent_declaration.go

@@ -0,0 +1,127 @@
+package plugin_entities
+
+import (
+	"encoding/json"
+
+	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/manifest_entities"
+	"gopkg.in/yaml.v3"
+)
+
+type AgentProviderIdentity struct {
+	ToolProviderIdentity `json:",inline" yaml:",inline"`
+}
+
+type AgentIdentity struct {
+	ToolIdentity `json:",inline" yaml:",inline"`
+}
+
+type AgentParameter struct {
+	ToolParameter `json:",inline" yaml:",inline"`
+}
+
+type AgentOutputSchema struct {
+	ToolOutputSchema `json:",inline" yaml:",inline"`
+}
+
+type AgentStrategyDeclaration struct {
+	Identity     AgentIdentity     `json:"identity" yaml:"identity" validate:"required"`
+	Description  I18nObject        `json:"description" yaml:"description" validate:"required"`
+	Parameters   []AgentParameter  `json:"parameters" yaml:"parameters" validate:"omitempty,dive"`
+	OutputSchema AgentOutputSchema `json:"output_schema" yaml:"output_schema" validate:"omitempty,json_schema"`
+}
+
+type AgentProviderDeclaration struct {
+	Identity      AgentProviderIdentity      `json:"identity" yaml:"identity" validate:"required"`
+	Strategies    []AgentStrategyDeclaration `json:"strategies" yaml:"strategies" validate:"required,dive"`
+	StrategyFiles []string                   `json:"-" yaml:"-"`
+}
+
+func (a *AgentProviderDeclaration) MarshalJSON() ([]byte, error) {
+	type alias AgentProviderDeclaration
+	p := alias(*a)
+	if p.Strategies == nil {
+		p.Strategies = []AgentStrategyDeclaration{}
+	}
+	return json.Marshal(p)
+}
+
+func (a *AgentProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
+	type alias struct {
+		Identity   AgentProviderIdentity `yaml:"identity"`
+		Strategies yaml.Node             `yaml:"strategies"`
+	}
+
+	var temp alias
+
+	err := value.Decode(&temp)
+	if err != nil {
+		return err
+	}
+
+	// apply identity
+	a.Identity = temp.Identity
+
+	if a.StrategyFiles == nil {
+		a.StrategyFiles = []string{}
+	}
+
+	// unmarshal strategies
+	if temp.Strategies.Kind == yaml.SequenceNode {
+		for _, item := range temp.Strategies.Content {
+			if item.Kind == yaml.ScalarNode {
+				a.StrategyFiles = append(a.StrategyFiles, item.Value)
+			} else if item.Kind == yaml.MappingNode {
+				strategy := AgentStrategyDeclaration{}
+				if err := item.Decode(&strategy); err != nil {
+					return err
+				}
+				a.Strategies = append(a.Strategies, strategy)
+			}
+		}
+	}
+
+	if a.Strategies == nil {
+		a.Strategies = []AgentStrategyDeclaration{}
+	}
+
+	if a.Identity.Tags == nil {
+		a.Identity.Tags = []manifest_entities.PluginTag{}
+	}
+
+	return nil
+}
+
+func (a *AgentProviderDeclaration) UnmarshalJSON(data []byte) error {
+	type alias AgentProviderDeclaration
+
+	var temp struct {
+		alias
+		Strategies []json.RawMessage `json:"strategies"`
+	}
+
+	if err := json.Unmarshal(data, &temp); err != nil {
+		return err
+	}
+
+	*a = AgentProviderDeclaration(temp.alias)
+
+	// unmarshal strategies
+	for _, item := range temp.Strategies {
+		strategy := AgentStrategyDeclaration{}
+		if err := json.Unmarshal(item, &strategy); err != nil {
+			a.StrategyFiles = append(a.StrategyFiles, string(item))
+		} else {
+			a.Strategies = append(a.Strategies, strategy)
+		}
+	}
+
+	if a.Strategies == nil {
+		a.Strategies = []AgentStrategyDeclaration{}
+	}
+
+	if a.Identity.Tags == nil {
+		a.Identity.Tags = []manifest_entities.PluginTag{}
+	}
+
+	return nil
+}

+ 1 - 0
internal/types/entities/plugin_entities/plugin_declaration.go

@@ -176,6 +176,7 @@ type PluginDeclaration struct {
 	Endpoint                               *EndpointProviderDeclaration `json:"endpoint,omitempty" yaml:"endpoint,omitempty" validate:"omitempty"`
 	Model                                  *ModelProviderDeclaration    `json:"model,omitempty" yaml:"model,omitempty" validate:"omitempty"`
 	Tool                                   *ToolProviderDeclaration     `json:"tool,omitempty" yaml:"tool,omitempty" validate:"omitempty"`
+	Agent                                  *AgentProviderDeclaration    `json:"agent,omitempty" yaml:"agent,omitempty" validate:"omitempty"`
 }
 
 func (p *PluginDeclaration) Category() PluginCategory {

+ 1 - 0
internal/types/entities/plugin_entities/remote_entities.go

@@ -16,6 +16,7 @@ const (
 	REGISTER_EVENT_TYPE_TOOL_DECLARATION     RemotePluginRegisterEventType = "tool_declaration"
 	REGISTER_EVENT_TYPE_MODEL_DECLARATION    RemotePluginRegisterEventType = "model_declaration"
 	REGISTER_EVENT_TYPE_ENDPOINT_DECLARATION RemotePluginRegisterEventType = "endpoint_declaration"
+	REGISTER_EVENT_TYPE_AGENT_DECLARATION    RemotePluginRegisterEventType = "agent_declaration"
 	REGISTER_EVENT_TYPE_END                  RemotePluginRegisterEventType = "end"
 )
 

+ 8 - 0
internal/types/entities/plugin_entities/tool_declaration.go

@@ -196,6 +196,10 @@ func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
 		t.CredentialsSchema = credentialsSchema
 	}
 
+	if t.ToolFiles == nil {
+		t.ToolFiles = []string{}
+	}
+
 	// unmarshal tools
 	if temp.Tools.Kind == yaml.SequenceNode {
 		for _, item := range temp.Tools.Content {
@@ -266,6 +270,10 @@ func (t *ToolProviderDeclaration) UnmarshalJSON(data []byte) error {
 		t.CredentialsSchema = credentials_schema_array
 	}
 
+	if t.ToolFiles == nil {
+		t.ToolFiles = []string{}
+	}
+
 	// unmarshal tools
 	for _, item := range temp.Tools {
 		tool := ToolDeclaration{}

+ 11 - 0
internal/types/entities/requests/agent.go

@@ -0,0 +1,11 @@
+package requests
+
+type InvokeAgentSchema struct {
+	Provider    string         `json:"provider" validate:"required"`
+	Strategy    string         `json:"strategy" validate:"required"`
+	AgentParams map[string]any `json:"agent_params" validate:"omitempty"`
+}
+
+type RequestInvokeAgent struct {
+	InvokeAgentSchema
+}