Browse Source

fix: load model and tool declaration while decoding pkg

Yeuoly 11 months ago
parent
commit
ce94626717

+ 146 - 2
internal/core/plugin_packager/decoder/decoder.go

@@ -5,6 +5,7 @@ import (
 	"fmt"
 	"io"
 	"io/fs"
+	"path/filepath"
 	"strings"
 
 	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
@@ -90,13 +91,26 @@ func (p *PluginDecoderHelper) Manifest(decoder PluginDecoder) (plugin_entities.P
 			return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to read tool file: %s", tool))
 		}
 
-		// TODO
-
 		plugin_dec, err := parser.UnmarshalYamlBytes[plugin_entities.ToolProviderDeclaration](plugin_yaml)
 		if err != nil {
 			return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to unmarshal plugin file: %s", tool))
 		}
 
+		// read tools
+		for _, tool_file := range plugin_dec.ToolFiles {
+			tool_file_content, err := decoder.ReadFile(tool_file)
+			if err != nil {
+				return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to read tool file: %s", tool_file))
+			}
+
+			tool_file_dec, err := parser.UnmarshalYamlBytes[plugin_entities.ToolDeclaration](tool_file_content)
+			if err != nil {
+				return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to unmarshal tool file: %s", tool_file))
+			}
+
+			plugin_dec.Tools = append(plugin_dec.Tools, tool_file_dec)
+		}
+
 		dec.Tool = &plugin_dec
 	}
 
@@ -127,6 +141,136 @@ func (p *PluginDecoderHelper) Manifest(decoder PluginDecoder) (plugin_entities.P
 			return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to unmarshal plugin file: %s", model))
 		}
 
+		// read model position file
+		if plugin_dec.PositionFiles != nil {
+			plugin_dec.Position = &plugin_entities.ModelPosition{}
+
+			llm_file_name, ok := plugin_dec.PositionFiles["llm"]
+			if ok {
+				llm_file, err := decoder.ReadFile(llm_file_name)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to read llm position file: %s", llm_file_name))
+				}
+
+				position, err := parser.UnmarshalYamlBytes[[]string](llm_file)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to unmarshal llm position file: %s", llm_file_name))
+				}
+
+				plugin_dec.Position.LLM = &position
+			}
+
+			text_embedding_file_name, ok := plugin_dec.PositionFiles["text_embedding"]
+			if ok {
+				text_embedding_file, err := decoder.ReadFile(text_embedding_file_name)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to read text embedding position file: %s", text_embedding_file_name))
+				}
+
+				position, err := parser.UnmarshalYamlBytes[[]string](text_embedding_file)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to unmarshal text embedding position file: %s", text_embedding_file_name))
+				}
+
+				plugin_dec.Position.TextEmbedding = &position
+			}
+
+			rerank_file_name, ok := plugin_dec.PositionFiles["rerank"]
+			if ok {
+				rerank_file, err := decoder.ReadFile(rerank_file_name)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to read rerank position file: %s", rerank_file_name))
+				}
+
+				position, err := parser.UnmarshalYamlBytes[[]string](rerank_file)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to unmarshal rerank position file: %s", rerank_file_name))
+				}
+
+				plugin_dec.Position.Rerank = &position
+			}
+
+			tts_file_name, ok := plugin_dec.PositionFiles["tts"]
+			if ok {
+				tts_file, err := decoder.ReadFile(tts_file_name)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to read tts position file: %s", tts_file_name))
+				}
+
+				position, err := parser.UnmarshalYamlBytes[[]string](tts_file)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to unmarshal tts position file: %s", tts_file_name))
+				}
+
+				plugin_dec.Position.TTS = &position
+			}
+
+			speech2text_file_name, ok := plugin_dec.PositionFiles["speech2text"]
+			if ok {
+				speech2text_file, err := decoder.ReadFile(speech2text_file_name)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to read speech2text position file: %s", speech2text_file_name))
+				}
+
+				position, err := parser.UnmarshalYamlBytes[[]string](speech2text_file)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to unmarshal speech2text position file: %s", speech2text_file_name))
+				}
+
+				plugin_dec.Position.Speech2text = &position
+			}
+
+			moderation_file_name, ok := plugin_dec.PositionFiles["moderation"]
+			if ok {
+				moderation_file, err := decoder.ReadFile(moderation_file_name)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to read moderation position file: %s", moderation_file_name))
+				}
+
+				position, err := parser.UnmarshalYamlBytes[[]string](moderation_file)
+				if err != nil {
+					return plugin_entities.PluginDeclaration{}, errors.Join(err, fmt.Errorf("failed to unmarshal moderation position file: %s", moderation_file_name))
+				}
+
+				plugin_dec.Position.Moderation = &position
+			}
+		}
+
+		// read models
+		if err := decoder.Walk(func(filename, dir string) error {
+			model_patterns := plugin_dec.ModelFiles
+			// using glob to match if dir/filename is in models
+			model_file_name := filepath.Join(dir, filename)
+			if strings.HasSuffix(model_file_name, "_position.yaml") {
+				return nil
+			}
+
+			for _, model_pattern := range model_patterns {
+				matched, err := filepath.Match(model_pattern, model_file_name)
+				if err != nil {
+					return err
+				}
+				if matched {
+					// read model file
+					model_file, err := decoder.ReadFile(model_file_name)
+					if err != nil {
+						return err
+					}
+
+					model_dec, err := parser.UnmarshalYamlBytes[plugin_entities.ModelDeclaration](model_file)
+					if err != nil {
+						return err
+					}
+
+					plugin_dec.Models = append(plugin_dec.Models, model_dec)
+				}
+			}
+
+			return nil
+		}); err != nil {
+			return plugin_entities.PluginDeclaration{}, err
+		}
+
 		dec.Model = &plugin_dec
 	}
 

+ 152 - 29
internal/types/entities/plugin_entities/model_declaration.go

@@ -1,12 +1,15 @@
 package plugin_entities
 
 import (
+	"encoding/json"
+
 	"github.com/go-playground/locales/en"
 	ut "github.com/go-playground/universal-translator"
 	"github.com/go-playground/validator/v10"
 	en_translations "github.com/go-playground/validator/v10/translations/en"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
 	"github.com/shopspring/decimal"
+	"gopkg.in/yaml.v3"
 )
 
 type ModelType string
@@ -75,17 +78,17 @@ func isModelParameterType(fl validator.FieldLevel) bool {
 }
 
 type ModelParameterRule struct {
-	Name        string              `json:"name" validate:"required,lt=256"`
-	UseTemplate *string             `json:"use_template" validate:"omitempty,lt=256"`
-	Label       *I18nObject         `json:"label" validate:"omitempty"`
-	Type        *ModelParameterType `json:"type" validate:"omitempty,model_parameter_type"`
-	Help        *I18nObject         `json:"help" validate:"omitempty"`
-	Required    bool                `json:"required"`
-	Default     *any                `json:"default" validate:"omitempty,is_basic_type"`
-	Min         *float64            `json:"min" validate:"omitempty"`
-	Max         *float64            `json:"max" validate:"omitempty"`
-	Precision   *int                `json:"precision" validate:"omitempty"`
-	Options     []string            `json:"options" validate:"omitempty,dive,lt=256"`
+	Name        string              `json:"name" yaml:"name" validate:"required,lt=256"`
+	UseTemplate *string             `json:"use_template" yaml:"use_template" validate:"omitempty,lt=256"`
+	Label       *I18nObject         `json:"label" yaml:"label" validate:"omitempty"`
+	Type        *ModelParameterType `json:"type" yaml:"type" validate:"omitempty,model_parameter_type"`
+	Help        *I18nObject         `json:"help" yaml:"help" validate:"omitempty"`
+	Required    bool                `json:"required" yaml:"required"`
+	Default     *any                `json:"default" yaml:"default" validate:"omitempty,is_basic_type"`
+	Min         *float64            `json:"min" yaml:"min" validate:"omitempty"`
+	Max         *float64            `json:"max" yaml:"max" validate:"omitempty"`
+	Precision   *int                `json:"precision" yaml:"precision" validate:"omitempty"`
+	Options     []string            `json:"options" yaml:"options" validate:"omitempty,dive,lt=256"`
 }
 
 func isParameterRule(fl validator.FieldLevel) bool {
@@ -117,15 +120,15 @@ type ModelPriceConfig struct {
 }
 
 type ModelDeclaration struct {
-	Model           string                         `json:"model" validate:"required,lt=256"`
-	Label           I18nObject                     `json:"label" validate:"required"`
-	ModelType       ModelType                      `json:"model_type" validate:"required,model_type"`
-	Features        []string                       `json:"features" validate:"omitempty,lte=256,dive,lt=256"`
-	FetchFrom       ModelProviderConfigurateMethod `json:"fetch_from" validate:"omitempty,model_provider_configurate_method"`
-	ModelProperties map[string]any                 `json:"model_properties" validate:"omitempty,dive,is_basic_type"`
-	Deprecated      bool                           `json:"deprecated"`
-	ParameterRules  []ModelParameterRule           `json:"parameter_rules" validate:"omitempty,lte=128,dive,parameter_rule"`
-	PriceConfig     *ModelPriceConfig              `json:"pricing" validate:"omitempty"`
+	Model           string                         `json:"model" yaml:"model" validate:"required,lt=256"`
+	Label           I18nObject                     `json:"label" yaml:"label" validate:"required"`
+	ModelType       ModelType                      `json:"model_type" yaml:"model_type" validate:"required,model_type"`
+	Features        []string                       `json:"features" yaml:"features" validate:"omitempty,lte=256,dive,lt=256"`
+	FetchFrom       ModelProviderConfigurateMethod `json:"fetch_from" yaml:"fetch_from" validate:"omitempty,model_provider_configurate_method"`
+	ModelProperties map[string]any                 `json:"model_properties" yaml:"model_properties" validate:"omitempty,dive,is_basic_type"`
+	Deprecated      bool                           `json:"deprecated" yaml:"deprecated"`
+	ParameterRules  []ModelParameterRule           `json:"parameter_rules" yaml:"parameter_rules" validate:"omitempty,lte=128,dive,parameter_rule"`
+	PriceConfig     *ModelPriceConfig              `json:"pricing" yaml:"pricing" validate:"omitempty"`
 }
 
 type ModelProviderFormType string
@@ -193,6 +196,15 @@ type ModelProviderHelpEntity struct {
 	URL   I18nObject `json:"url" validate:"required"`
 }
 
+type ModelPosition struct {
+	LLM           *[]string `json:"llm,omitempty" yaml:"llm,omitempty"`
+	TextEmbedding *[]string `json:"text_embedding,omitempty" yaml:"text_embedding,omitempty"`
+	Rerank        *[]string `json:"rerank,omitempty" yaml:"rerank,omitempty"`
+	TTS           *[]string `json:"tts,omitempty" yaml:"tts,omitempty"`
+	Speech2text   *[]string `json:"speech2text,omitempty" yaml:"speech2text,omitempty"`
+	Moderation    *[]string `json:"moderation,omitempty" yaml:"moderation,omitempty"`
+}
+
 type ModelProviderDeclaration struct {
 	Provider                 string                           `json:"provider" yaml:"provider" validate:"required,lt=256"`
 	Label                    I18nObject                       `json:"label" yaml:"label" validate:"required"`
@@ -205,15 +217,126 @@ type ModelProviderDeclaration struct {
 	ConfigurateMethods       []ModelProviderConfigurateMethod `json:"configurate_methods" yaml:"configurate_methods" validate:"required,lte=16,dive,model_provider_configurate_method"`
 	ProviderCredentialSchema *ModelProviderCredentialSchema   `json:"provider_credential_schema" yaml:"provider_credential_schema,omitempty" validate:"omitempty"`
 	ModelCredentialSchema    *ModelCredentialSchema           `json:"model_credential_schema" yaml:"model_credential_schema,omitempty" validate:"omitempty"`
-	Position                 *struct {
-		LLM           *[]string `json:"llm,omitempty" yaml:"llm,omitempty"`
-		TextEmbedding *[]string `json:"text_embedding,omitempty" yaml:"text_embedding,omitempty"`
-		Rerank        *[]string `json:"rerank,omitempty" yaml:"rerank,omitempty"`
-		TTS           *[]string `json:"tts,omitempty" yaml:"tts,omitempty"`
-		Speech2text   *[]string `json:"speech2text,omitempty" yaml:"speech2text,omitempty"`
-		Moderation    *[]string `json:"moderation,omitempty" yaml:"moderation,omitempty"`
-	} `json:"position,omitempty" yaml:"position,omitempty"`
-	Models []ModelDeclaration `json:"models" yaml:"model_declarations,omitempty"`
+	Position                 *ModelPosition                   `json:"position,omitempty" yaml:"position,omitempty"`
+	Models                   []ModelDeclaration               `json:"models" yaml:"model_declarations,omitempty"`
+	ModelFiles               []string                         `json:"-" yaml:"-"`
+	PositionFiles            map[string]string                `json:"-" yaml:"-"`
+}
+
+func (m *ModelProviderDeclaration) UnmarshalJSON(data []byte) error {
+	type alias ModelProviderDeclaration
+
+	var temp struct {
+		alias
+		Models json.RawMessage `json:"models"`
+	}
+
+	if err := json.Unmarshal(data, &temp); err != nil {
+		return err
+	}
+
+	*m = ModelProviderDeclaration(temp.alias)
+
+	// unmarshal models into map[string]any
+	var models map[string]any
+	if err := json.Unmarshal(temp.Models, &models); err != nil {
+		// can not unmarshal it into map, so it's a list
+		if err := json.Unmarshal(temp.Models, &m.Models); err != nil {
+			return err
+		}
+
+		return nil
+	}
+
+	m.PositionFiles = make(map[string]string)
+
+	types := []string{
+		"llm",
+		"text_embedding",
+		"tts",
+		"speech2text",
+		"moderation",
+		"rerank",
+	}
+
+	for _, model_type := range types {
+		model_type_map, ok := models[model_type].(map[string]any)
+		if ok {
+			model_type_position_file, ok := model_type_map["position"]
+			if ok {
+				model_type_position_file_path, ok := model_type_position_file.(string)
+				if ok {
+					m.PositionFiles[model_type] = model_type_position_file_path
+				}
+			}
+
+			model_type_predefined_files, ok := model_type_map["predefined"].([]string)
+			if ok {
+				m.ModelFiles = append(m.ModelFiles, model_type_predefined_files...)
+			}
+		}
+	}
+
+	return nil
+}
+
+func (m *ModelProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
+	type alias ModelProviderDeclaration
+
+	var temp struct {
+		alias  `yaml:",inline"`
+		Models yaml.Node `yaml:"models"`
+	}
+
+	if err := value.Decode(&temp); err != nil {
+		return err
+	}
+
+	*m = ModelProviderDeclaration(temp.alias)
+
+	// Check if Models is a mapping node
+	if temp.Models.Kind == yaml.MappingNode {
+		m.PositionFiles = make(map[string]string)
+
+		types := []string{
+			"llm",
+			"text_embedding",
+			"tts",
+			"speech2text",
+			"moderation",
+			"rerank",
+		}
+
+		for i := 0; i < len(temp.Models.Content); i += 2 {
+			key := temp.Models.Content[i].Value
+			value := temp.Models.Content[i+1]
+
+			for _, model_type := range types {
+				if key == model_type {
+					if value.Kind == yaml.MappingNode {
+						for j := 0; j < len(value.Content); j += 2 {
+							if value.Content[j].Value == "position" {
+								m.PositionFiles[model_type] = value.Content[j+1].Value
+							} else if value.Content[j].Value == "predefined" {
+								// get content of predefined
+								if value.Content[j+1].Kind == yaml.SequenceNode {
+									for _, file := range value.Content[j+1].Content {
+										m.ModelFiles = append(m.ModelFiles, file.Value)
+									}
+								}
+							}
+						}
+					}
+				}
+			}
+		}
+	} else if temp.Models.Kind == yaml.SequenceNode {
+		if err := temp.Models.Decode(&m.Models); err != nil {
+			return err
+		}
+	}
+
+	return nil
 }
 
 func init() {

+ 32 - 7
internal/types/entities/plugin_entities/tool_declaration.go

@@ -171,14 +171,15 @@ type ToolProviderIdentity struct {
 type ToolProviderDeclaration struct {
 	Identity          ToolProviderIdentity `json:"identity" validate:"required"`
 	CredentialsSchema []ProviderConfig     `json:"credentials_schema" validate:"omitempty,dive"`
-	Tools             []ToolDeclaration    `json:"tools" validate:"required,dive"`
+	Tools             []ToolDeclaration    `validate:"required,dive"`
+	ToolFiles         []string             `json:"-"`
 }
 
 func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
 	type alias struct {
 		Identity          ToolProviderIdentity `yaml:"identity"`
 		CredentialsSchema yaml.Node            `yaml:"credentials_schema"`
-		Tools             []ToolDeclaration    `yaml:"tools"`
+		Tools             yaml.Node            `yaml:"tools"`
 	}
 
 	var temp alias
@@ -190,7 +191,6 @@ func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
 
 	// apply identity
 	t.Identity = temp.Identity
-	t.Tools = temp.Tools
 
 	// check if credentials_schema is a map
 	if temp.CredentialsSchema.Kind != yaml.MappingNode {
@@ -201,7 +201,7 @@ func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
 		}
 		t.CredentialsSchema = credentials_schema
 	} else if temp.CredentialsSchema.Kind == yaml.MappingNode {
-		credentials_schema := make([]ProviderConfig, 0)
+		credentials_schema := make([]ProviderConfig, 0, len(temp.CredentialsSchema.Content)/2)
 		current_key := ""
 		current_value := &ProviderConfig{}
 		for _, item := range temp.CredentialsSchema.Content {
@@ -217,8 +217,21 @@ func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
 			}
 		}
 		t.CredentialsSchema = credentials_schema
-	} else {
-		return fmt.Errorf("invalid credentials_schema type: %v", temp.CredentialsSchema.Kind)
+	}
+
+	// unmarshal tools
+	if temp.Tools.Kind == yaml.SequenceNode {
+		for _, item := range temp.Tools.Content {
+			if item.Kind == yaml.ScalarNode {
+				t.ToolFiles = append(t.ToolFiles, item.Value)
+			} else if item.Kind == yaml.MappingNode {
+				tool := ToolDeclaration{}
+				if err := item.Decode(&tool); err != nil {
+					return err
+				}
+				t.Tools = append(t.Tools, tool)
+			}
+		}
 	}
 
 	return nil
@@ -229,7 +242,8 @@ func (t *ToolProviderDeclaration) UnmarshalJSON(data []byte) error {
 
 	var temp struct {
 		alias
-		CredentialsSchema json.RawMessage `json:"credentials_schema"`
+		CredentialsSchema json.RawMessage   `json:"credentials_schema"`
+		Tools             []json.RawMessage `json:"tools"`
 	}
 
 	if err := json.Unmarshal(data, &temp); err != nil {
@@ -258,6 +272,17 @@ func (t *ToolProviderDeclaration) UnmarshalJSON(data []byte) error {
 		t.CredentialsSchema = credentials_schema_array
 	}
 
+	// unmarshal tools
+	for _, item := range temp.Tools {
+		tool := ToolDeclaration{}
+		if err := json.Unmarshal(item, &tool); err != nil {
+			// try to unmarshal it as a string directly
+			t.ToolFiles = append(t.ToolFiles, string(item))
+		} else {
+			t.Tools = append(t.Tools, tool)
+		}
+	}
+
 	return nil
 }
 

+ 10 - 0
internal/types/entities/plugin_entities/tool_declaration_test.go

@@ -376,6 +376,11 @@ tools:
 		return
 	}
 
+	if len(json_declaration.Tools) != 1 {
+		t.Errorf("UnmarshalToolProviderConfiguration() error for JSON: incorrect Tools length")
+		return
+	}
+
 	yaml_declaration, yaml_err := parser.UnmarshalYamlBytes[ToolProviderDeclaration]([]byte(yaml_data))
 	if yaml_err != nil {
 		t.Errorf("UnmarshalToolProviderConfiguration() error for YAML = %v", yaml_err)
@@ -386,6 +391,11 @@ tools:
 		t.Errorf("UnmarshalToolProviderConfiguration() error for YAML: incorrect CredentialsSchema length")
 		return
 	}
+
+	if len(yaml_declaration.Tools) != 1 {
+		t.Errorf("UnmarshalToolProviderConfiguration() error for YAML: incorrect Tools length")
+		return
+	}
 }
 
 func TestWithoutAuthorToolProvider_Validate(t *testing.T) {