Browse Source

fix: missing model parameter rules default template transformers in default plugin decoder

Yeuoly 9 months ago
parent
commit
f66f763392

+ 3 - 3
internal/server/controllers/plugins.go

@@ -66,10 +66,10 @@ func InstallPluginFromIdentifier(app *app.Config) gin.HandlerFunc {
 
 func UninstallPlugin(c *gin.Context) {
 	BindRequest(c, func(request struct {
-		TenantID               string                                 `uri:"tenant_id" validate:"required"`
-		PluginUniqueIdentifier plugin_entities.PluginUniqueIdentifier `json:"plugin_unique_identifier" validate:"required,plugin_unique_identifier"`
+		TenantID             string `uri:"tenant_id" validate:"required"`
+		PluginInstallationID string `json:"plugin_installation_id" validate:"required"`
 	}) {
-		c.JSON(http.StatusOK, service.UninstallPlugin(request.TenantID, request.PluginUniqueIdentifier))
+		c.JSON(http.StatusOK, service.UninstallPlugin(request.TenantID, request.PluginInstallationID))
 	})
 }
 

+ 8 - 4
internal/service/install_plugin.go

@@ -99,22 +99,26 @@ func FetchPluginFromIdentifier(
 
 func UninstallPlugin(
 	tenant_id string,
-	plugin_unique_identifier plugin_entities.PluginUniqueIdentifier,
+	plugin_installation_id string,
 ) *entities.Response {
 	// Check if the plugin exists for the tenant
 	installation, err := db.GetOne[models.PluginInstallation](
 		db.Equal("tenant_id", tenant_id),
-		db.Equal("plugin_unique_identifier", plugin_unique_identifier.String()),
+		db.Equal("id", plugin_installation_id),
 	)
 	if err == db.ErrDatabaseNotFound {
-		return entities.NewErrorResponse(-404, "Plugin not found for this tenant")
+		return entities.NewErrorResponse(-404, "Plugin installation not found for this tenant")
 	}
 	if err != nil {
 		return entities.NewErrorResponse(-500, err.Error())
 	}
 
 	// Uninstall the plugin
-	_, err = curd.UninstallPlugin(tenant_id, plugin_unique_identifier, installation.ID)
+	_, err = curd.UninstallPlugin(
+		tenant_id,
+		plugin_entities.PluginUniqueIdentifier(installation.PluginUniqueIdentifier),
+		installation.ID,
+	)
 	if err != nil {
 		return entities.NewErrorResponse(-500, fmt.Sprintf("Failed to uninstall plugin: %s", err.Error()))
 	}

+ 279 - 0
internal/types/entities/plugin_entities/model_declaration.go

@@ -2,12 +2,14 @@ package plugin_entities
 
 import (
 	"encoding/json"
+	"fmt"
 
 	"github.com/go-playground/locales/en"
 	ut "github.com/go-playground/universal-translator"
 	"github.com/go-playground/validator/v10"
 	en_translations "github.com/go-playground/validator/v10/translations/en"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
+	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
 	"github.com/shopspring/decimal"
 	"gopkg.in/yaml.v3"
 )
@@ -91,6 +93,219 @@ type ModelParameterRule struct {
 	Options     []string            `json:"options" yaml:"options" validate:"omitempty,dive,lt=256"`
 }
 
+type DefaultParameterName string
+
+const (
+	TEMPERATURE       DefaultParameterName = "temperature"
+	TOP_P             DefaultParameterName = "top_p"
+	PRESENCE_PENALTY  DefaultParameterName = "presence_penalty"
+	FREQUENCY_PENALTY DefaultParameterName = "frequency_penalty"
+	MAX_TOKENS        DefaultParameterName = "max_tokens"
+	RESPONSE_FORMAT   DefaultParameterName = "response_format"
+)
+
+var PARAMETER_RULE_TEMPLATE = map[DefaultParameterName]ModelParameterRule{
+	TEMPERATURE: {
+		Label: &I18nObject{
+			EnUS:   "Temperature",
+			ZhHans: "温度",
+			JaJp:   "温度",
+			PtBr:   "Temperatura",
+		},
+		Type: parser.ToPtr(PARAMETER_TYPE_FLOAT),
+		Help: &I18nObject{
+			EnUS:   "Controls randomness. Lower temperature results in less random completions. As the temperature approaches zero, the model will become deterministic and repetitive. Higher temperature results in more random completions.",
+			ZhHans: "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。较高的温度会导致更多的随机完成。",
+			JaJp:   "温度はランダム性を制御します。温度が低いほどランダムな完成が少なくなります。温度がゼロに近づくと、モデルは決定論的で繰り返しになります。温度が高いほどランダムな完成が多くなります。",
+			PtBr:   "A temperatura controla a aleatoriedade. Menores temperaturas resultam em menos conclusões aleatórias. À medida que a temperatura se aproxima de zero, o modelo se tornará determinístico e repetitivo. Temperaturas mais altas resultam em mais conclusões aleatórias.",
+		},
+		Required:  false,
+		Default:   parser.ToPtr(any(0.0)),
+		Min:       parser.ToPtr(0.0),
+		Max:       parser.ToPtr(1.0),
+		Precision: parser.ToPtr(2),
+	},
+	TOP_P: {
+		Label: &I18nObject{
+			EnUS:   "Top P",
+			ZhHans: "Top P",
+			JaJp:   "Top P",
+			PtBr:   "Top P",
+		},
+		Type: parser.ToPtr(PARAMETER_TYPE_FLOAT),
+		Help: &I18nObject{
+			EnUS:   "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options are considered.",
+			ZhHans: "通过核心采样控制多样性:0.5表示考虑了一半的所有可能性加权选项。",
+			JaJp:   "核サンプリングを通じて多様性を制御します:0.5は、すべての可能性加权オプションの半分を考慮します。",
+			PtBr:   "Controla a diversidade via amostragem de núcleo: 0.5 significa que metade das opções com maior probabilidade são consideradas.",
+		},
+		Required:  false,
+		Default:   parser.ToPtr(any(1.0)),
+		Min:       parser.ToPtr(0.0),
+		Max:       parser.ToPtr(1.0),
+		Precision: parser.ToPtr(2),
+	},
+	PRESENCE_PENALTY: {
+		Label: &I18nObject{
+			EnUS:   "Presence Penalty",
+			ZhHans: "存在惩罚",
+			JaJp:   "存在ペナルティ",
+			PtBr:   "Penalidade de presença",
+		},
+		Type: parser.ToPtr(PARAMETER_TYPE_FLOAT),
+		Help: &I18nObject{
+			EnUS:   "Applies a penalty to the log-probability of tokens already in the text.",
+			ZhHans: "对文本中已有的标记的对数概率施加惩罚。",
+			JaJp:   "テキストに既に存在するトークンの対数確率にペナルティを適用します。",
+			PtBr:   "Aplica uma penalidade à probabilidade logarítmica de tokens já presentes no texto.",
+		},
+		Required:  false,
+		Default:   parser.ToPtr(any(0.0)),
+		Min:       parser.ToPtr(0.0),
+		Max:       parser.ToPtr(1.0),
+		Precision: parser.ToPtr(2),
+	},
+	FREQUENCY_PENALTY: {
+		Label: &I18nObject{
+			EnUS:   "Frequency Penalty",
+			ZhHans: "频率惩罚",
+			JaJp:   "頻度ペナルティ",
+			PtBr:   "Penalidade de frequência",
+		},
+		Type: parser.ToPtr(PARAMETER_TYPE_FLOAT),
+		Help: &I18nObject{
+			EnUS:   "Applies a penalty to the log-probability of tokens that appear in the text.",
+			ZhHans: "对文本中出现的标记的对数概率施加惩罚。",
+			JaJp:   "テキストに出現するトークンの対数確率にペナルティを適用します。",
+			PtBr:   "Aplica uma penalidade à probabilidade logarítmica de tokens que aparecem no texto.",
+		},
+		Required:  false,
+		Default:   parser.ToPtr(any(0.0)),
+		Min:       parser.ToPtr(0.0),
+		Max:       parser.ToPtr(1.0),
+		Precision: parser.ToPtr(2),
+	},
+	MAX_TOKENS: {
+		Label: &I18nObject{
+			EnUS:   "Max Tokens",
+			ZhHans: "最大标记",
+			JaJp:   "最大トークン",
+			PtBr:   "Máximo de tokens",
+		},
+		Type: parser.ToPtr(PARAMETER_TYPE_INT),
+		Help: &I18nObject{
+			EnUS:   "Specifies the upper limit on the length of generated results. If the generated results are truncated, you can increase this parameter.",
+			ZhHans: "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。",
+			JaJp:   "生成結果の長さの上限を指定します。生成結果が切り捨てられた場合は、このパラメータを大きくすることができます。",
+			PtBr:   "Especifica o limite superior para o comprimento dos resultados gerados. Se os resultados gerados forem truncados, você pode aumentar este parâmetro.",
+		},
+		Required:  false,
+		Default:   parser.ToPtr(any(64)),
+		Min:       parser.ToPtr(1.0),
+		Max:       parser.ToPtr(2048.0),
+		Precision: parser.ToPtr(0),
+	},
+	RESPONSE_FORMAT: {
+		Label: &I18nObject{
+			EnUS:   "Response Format",
+			ZhHans: "回复格式",
+			JaJp:   "応答形式",
+			PtBr:   "Formato de resposta",
+		},
+		Type: parser.ToPtr(PARAMETER_TYPE_STRING),
+		Help: &I18nObject{
+			EnUS:   "Set a response format, ensure the output from llm is a valid code block as possible, such as JSON, XML, etc.",
+			ZhHans: "设置一个返回格式,确保llm的输出尽可能是有效的代码块,如JSON、XML等",
+			JaJp:   "応答形式を設定します。llmの出力が可能な限り有効なコードブロックであることを確認します。",
+			PtBr:   "Defina um formato de resposta para garantir que a saída do llm seja um bloco de código válido o mais possível, como JSON, XML, etc.",
+		},
+		Required: false,
+		Options:  []string{"JSON", "XML"},
+	},
+}
+
+func (m *ModelParameterRule) TransformTemplate() error {
+	// if use_template is not empty, transform to use default value
+	if m.UseTemplate != nil && *m.UseTemplate != "" {
+		// get the value of use_template
+		use_template_value := m.UseTemplate
+		// get the template
+		template, ok := PARAMETER_RULE_TEMPLATE[DefaultParameterName(*use_template_value)]
+		if !ok {
+			return fmt.Errorf("use_template %s not found", *use_template_value)
+		}
+		// transform to default value
+		if m.Label == nil {
+			m.Label = template.Label
+		}
+		if m.Type == nil {
+			m.Type = template.Type
+		}
+		if m.Help == nil {
+			m.Help = template.Help
+		}
+		if m.Default == nil {
+			m.Default = template.Default
+		}
+		if m.Min == nil {
+			m.Min = template.Min
+		}
+		if m.Max == nil {
+			m.Max = template.Max
+		}
+		if m.Precision == nil {
+			m.Precision = template.Precision
+		}
+		if m.Options == nil {
+			m.Options = template.Options
+		}
+	}
+	if m.Options == nil {
+		m.Options = []string{}
+	}
+	return nil
+}
+
+func (m *ModelParameterRule) UnmarshalJSON(data []byte) error {
+	type alias ModelParameterRule
+
+	temp := &struct {
+		*alias
+	}{
+		alias: (*alias)(m),
+	}
+
+	if err := json.Unmarshal(data, &temp); err != nil {
+		return err
+	}
+
+	if err := m.TransformTemplate(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (m *ModelParameterRule) UnmarshalYAML(value *yaml.Node) error {
+	type alias ModelParameterRule
+
+	temp := &struct {
+		*alias `yaml:",inline"`
+	}{
+		alias: (*alias)(m),
+	}
+
+	if err := value.Decode(&temp); err != nil {
+		return err
+	}
+
+	if err := m.TransformTemplate(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
 func isParameterRule(fl validator.FieldLevel) bool {
 	// if use_template is empty, then label, type should be required
 	// try get the value of use_template
@@ -131,6 +346,54 @@ type ModelDeclaration struct {
 	PriceConfig     *ModelPriceConfig              `json:"pricing" yaml:"pricing" validate:"omitempty"`
 }
 
+func (m *ModelDeclaration) UnmarshalJSON(data []byte) error {
+	type alias ModelDeclaration
+
+	temp := &struct {
+		*alias
+	}{
+		alias: (*alias)(m),
+	}
+
+	if err := json.Unmarshal(data, &temp); err != nil {
+		return err
+	}
+
+	if m.FetchFrom == "" {
+		m.FetchFrom = CONFIGURATE_METHOD_PREDEFINED_MODEL
+	}
+
+	if m.ParameterRules == nil {
+		m.ParameterRules = []ModelParameterRule{}
+	}
+
+	return nil
+}
+
+func (m *ModelDeclaration) UnmarshalYAML(value *yaml.Node) error {
+	type alias ModelDeclaration
+
+	temp := &struct {
+		*alias `yaml:",inline"`
+	}{
+		alias: (*alias)(m),
+	}
+
+	if err := value.Decode(&temp); err != nil {
+		return err
+	}
+
+	if m.FetchFrom == "" {
+		m.FetchFrom = CONFIGURATE_METHOD_PREDEFINED_MODEL
+	}
+
+	if m.ParameterRules == nil {
+		m.ParameterRules = []ModelParameterRule{}
+	}
+
+	return nil
+}
+
 type ModelProviderFormType string
 
 const (
@@ -237,6 +500,14 @@ func (m *ModelProviderDeclaration) UnmarshalJSON(data []byte) error {
 
 	*m = ModelProviderDeclaration(temp.alias)
 
+	if m.ModelCredentialSchema != nil && m.ModelCredentialSchema.CredentialFormSchemas == nil {
+		m.ModelCredentialSchema.CredentialFormSchemas = []ModelProviderCredentialFormSchema{}
+	}
+
+	if m.ProviderCredentialSchema != nil && m.ProviderCredentialSchema.CredentialFormSchemas == nil {
+		m.ProviderCredentialSchema.CredentialFormSchemas = []ModelProviderCredentialFormSchema{}
+	}
+
 	// unmarshal models into map[string]any
 	var models map[string]any
 	if err := json.Unmarshal(temp.Models, &models); err != nil {
@@ -294,6 +565,14 @@ func (m *ModelProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
 
 	*m = ModelProviderDeclaration(temp.alias)
 
+	if m.ModelCredentialSchema != nil && m.ModelCredentialSchema.CredentialFormSchemas == nil {
+		m.ModelCredentialSchema.CredentialFormSchemas = []ModelProviderCredentialFormSchema{}
+	}
+
+	if m.ProviderCredentialSchema != nil && m.ProviderCredentialSchema.CredentialFormSchemas == nil {
+		m.ProviderCredentialSchema.CredentialFormSchemas = []ModelProviderCredentialFormSchema{}
+	}
+
 	// Check if Models is a mapping node
 	if temp.Models.Kind == yaml.MappingNode {
 		m.PositionFiles = make(map[string]string)

+ 149 - 95
internal/types/entities/plugin_entities/model_declaration_test.go

@@ -23,102 +23,101 @@ func parse_yaml_to_json(data []byte) ([]byte, error) {
 	return json_data, nil
 }
 
-const (
-	template = `
-provider: openai
-label:
-  en_US: OpenAI
-description:
-  en_US: Models provided by OpenAI, such as GPT-3.5-Turbo and GPT-4.
-  zh_Hans: OpenAI 提供的模型,例如 GPT-3.5-Turbo 和 GPT-4。
-icon_small:
-  en_US: icon_s_en.svg
-icon_large:
-  en_US: icon_l_en.svg
-background: "#E5E7EB"
-help:
-  title:
-    en_US: Get your API Key from OpenAI
-    zh_Hans: 从 OpenAI 获取 API Key
-  url:
-    en_US: https://platform.openai.com/account/api-keys
-supported_model_types:
-  - llm
-  - text-embedding
-  - speech2text
-  - moderation
-  - tts
-configurate_methods:
-  - predefined-model
-  - customizable-model
-model_credential_schema:
-  model:
-    label:
-      en_US: Model Name
-      zh_Hans: 模型名称
-    placeholder:
-      en_US: Enter your model name
-      zh_Hans: 输入模型名称
-  credential_form_schemas:
-    - variable: openai_api_key
-      label:
-        en_US: API Key
-      type: secret-input
-      required: true
-      placeholder:
-        zh_Hans: 在此输入您的 API Key
-        en_US: Enter your API Key
-    - variable: openai_organization
-      label:
-        zh_Hans: 组织 ID
-        en_US: Organization
-      type: text-input
-      required: false
-      placeholder:
-        zh_Hans: 在此输入您的组织 ID
-        en_US: Enter your Organization ID
-    - variable: openai_api_base
-      label:
-        zh_Hans: API Base
-        en_US: API Base
-      type: text-input
-      required: false
-      placeholder:
-        zh_Hans: 在此输入您的 API Base
-        en_US: Enter your API Base
-provider_credential_schema:
-  credential_form_schemas:
-    - variable: openai_api_key
-      label:
-        en_US: API Key
-      type: secret-input
-      required: true
-      placeholder:
-        zh_Hans: 在此输入您的 API Key
-        en_US: Enter your API Key
-    - variable: openai_organization
-      label:
-        zh_Hans: 组织 ID
-        en_US: Organization
-      type: text-input
-      required: false
-      placeholder:
-        zh_Hans: 在此输入您的组织 ID
-        en_US: Enter your Organization ID
-    - variable: openai_api_base
-      label:
-        zh_Hans: API Base
-        en_US: API Base
-      type: text-input
-      required: false
-      placeholder:
-        zh_Hans: 在此输入您的 API Base, 如:https://api.openai.com
-        en_US: Enter your API Base, e.g. https://api.openai.com
-    `
-)
-
 func TestFullFunctionModelProvider_Validate(t *testing.T) {
-	json_data, err := parse_yaml_to_json([]byte(template))
+	const (
+		model_provider_template = `
+    provider: openai
+    label:
+      en_US: OpenAI
+    description:
+      en_US: Models provided by OpenAI, such as GPT-3.5-Turbo and GPT-4.
+      zh_Hans: OpenAI 提供的模型,例如 GPT-3.5-Turbo 和 GPT-4。
+    icon_small:
+      en_US: icon_s_en.svg
+    icon_large:
+      en_US: icon_l_en.svg
+    background: "#E5E7EB"
+    help:
+      title:
+        en_US: Get your API Key from OpenAI
+        zh_Hans: 从 OpenAI 获取 API Key
+      url:
+        en_US: https://platform.openai.com/account/api-keys
+    supported_model_types:
+      - llm
+      - text-embedding
+      - speech2text
+      - moderation
+      - tts
+    configurate_methods:
+      - predefined-model
+      - customizable-model
+    model_credential_schema:
+      model:
+        label:
+          en_US: Model Name
+          zh_Hans: 模型名称
+        placeholder:
+          en_US: Enter your model name
+          zh_Hans: 输入模型名称
+      credential_form_schemas:
+        - variable: openai_api_key
+          label:
+            en_US: API Key
+          type: secret-input
+          required: true
+          placeholder:
+            zh_Hans: 在此输入您的 API Key
+            en_US: Enter your API Key
+        - variable: openai_organization
+          label:
+            zh_Hans: 组织 ID
+            en_US: Organization
+          type: text-input
+          required: false
+          placeholder:
+            zh_Hans: 在此输入您的组织 ID
+            en_US: Enter your Organization ID
+        - variable: openai_api_base
+          label:
+            zh_Hans: API Base
+            en_US: API Base
+          type: text-input
+          required: false
+          placeholder:
+            zh_Hans: 在此输入您的 API Base
+            en_US: Enter your API Base
+    provider_credential_schema:
+      credential_form_schemas:
+        - variable: openai_api_key
+          label:
+            en_US: API Key
+          type: secret-input
+          required: true
+          placeholder:
+            zh_Hans: 在此输入您的 API Key
+            en_US: Enter your API Key
+        - variable: openai_organization
+          label:
+            zh_Hans: 组织 ID
+            en_US: Organization
+          type: text-input
+          required: false
+          placeholder:
+            zh_Hans: 在此输入您的组织 ID
+            en_US: Enter your Organization ID
+        - variable: openai_api_base
+          label:
+            zh_Hans: API Base
+            en_US: API Base
+          type: text-input
+          required: false
+          placeholder:
+            zh_Hans: 在此输入您的 API Base, 如:https://api.openai.com
+            en_US: Enter your API Base, e.g. https://api.openai.com
+        `
+	)
+	json_data, err := parse_yaml_to_json([]byte(model_provider_template))
 	if err != nil {
 		t.Error(err)
 	}
@@ -128,3 +127,58 @@ func TestFullFunctionModelProvider_Validate(t *testing.T) {
 		t.Errorf("UnmarshalModelProviderConfiguration() error = %v", err)
 	}
 }
+
+func TestModelParameterRule_UseTemplateYAML(t *testing.T) {
+	const (
+		model_parameter_rule_template = `
+name: temperature
+use_template: temperature
+`
+	)
+
+	yaml_data := []byte(model_parameter_rule_template)
+
+	model, err := parser.UnmarshalYamlBytes[ModelParameterRule](yaml_data)
+	if err != nil {
+		t.Errorf("UnmarshalModelParameterRule() error = %v", err)
+		return
+	}
+
+	if model.Type == nil {
+		t.Errorf("UnmarshalModelParameterRule() error = %v", err)
+		return
+	}
+
+	if *model.Type != PARAMETER_TYPE_FLOAT {
+		t.Errorf("UnmarshalModelParameterRule() error = %v", err)
+	}
+
+	if model.Min == nil || model.Max == nil || model.Precision == nil {
+		t.Errorf("Missing default value")
+	}
+}
+
+func TestModelParameterRule_UseTemplateJSON(t *testing.T) {
+	const (
+		model_parameter_rule_template = `{"name": "temperature", "use_template": "temperature"}`
+	)
+
+	json_data := []byte(model_parameter_rule_template)
+
+	model, err := parser.UnmarshalJsonBytes[ModelParameterRule](json_data)
+	if err != nil {
+		t.Errorf("UnmarshalModelParameterRule() error = %v", err)
+	}
+
+	if model.Type == nil {
+		t.Errorf("UnmarshalModelParameterRule() error = %v", err)
+	}
+
+	if *model.Type != PARAMETER_TYPE_FLOAT {
+		t.Errorf("UnmarshalModelParameterRule() error = %v", err)
+	}
+
+	if model.Min == nil || model.Max == nil || model.Precision == nil {
+		t.Errorf("Missing default value")
+	}
+}

+ 5 - 0
internal/utils/parser/ptr.go

@@ -0,0 +1,5 @@
+package parser
+
+func ToPtr[T any](value T) *T {
+	return &value
+}