瀏覽代碼

feat: tests for model entities

Yeuoly 1 年之前
父節點
當前提交
994c27dbf4

+ 69 - 29
internal/types/entities/plugin_entities/model_configuration.go

@@ -1,6 +1,8 @@
 package plugin_entities
 
 import (
+	"encoding/json"
+
 	"github.com/go-playground/locales/en"
 	ut "github.com/go-playground/universal-translator"
 	"github.com/go-playground/validator/v10"
@@ -12,7 +14,7 @@ type ModelType string
 
 const (
 	MODEL_TYPE_LLM            ModelType = "llm"
-	MODEL_TYPE_TEXT_EMBEDDING ModelType = "text_embedding"
+	MODEL_TYPE_TEXT_EMBEDDING ModelType = "text-embedding"
 	MODEL_TYPE_RERANKING      ModelType = "rerank"
 	MODEL_TYPE_SPEECH2TEXT    ModelType = "speech2text"
 	MODEL_TYPE_MODERATION     ModelType = "moderation"
@@ -38,8 +40,8 @@ func isModelType(fl validator.FieldLevel) bool {
 type ModelProviderConfigurateMethod string
 
 const (
-	CONFIGURATE_METHOD_PREDEFINED_MODEL   ModelProviderConfigurateMethod = "predefined_model"
-	CONFIGURATE_METHOD_CUSTOMIZABLE_MODEL ModelProviderConfigurateMethod = "customizable_model"
+	CONFIGURATE_METHOD_PREDEFINED_MODEL   ModelProviderConfigurateMethod = "predefined-model"
+	CONFIGURATE_METHOD_CUSTOMIZABLE_MODEL ModelProviderConfigurateMethod = "customizable-model"
 )
 
 func isModelProviderConfigurateMethod(fl validator.FieldLevel) bool {
@@ -74,17 +76,38 @@ func isModelParameterType(fl validator.FieldLevel) bool {
 }
 
 type ModelParameterRule struct {
-	Name        string             `json:"name" validate:"required,lt=256"`
-	UseTemplate *string            `json:"use_template" validate:"omitempty,lt=256"`
-	Label       I18nObject         `json:"label" validate:"required"`
-	Type        ModelParameterType `json:"type" validate:"required,model_parameter_type"`
-	Help        *I18nObject        `json:"help" validate:"omitempty"`
-	Required    bool               `json:"required"`
-	Default     *any               `json:"default" validate:"omitempty,is_basic_type"`
-	Min         *float64           `json:"min" validate:"omitempty"`
-	Max         *float64           `json:"max" validate:"omitempty"`
-	Precision   *int               `json:"precision" validate:"omitempty"`
-	Options     []string           `json:"options" validate:"omitempty,dive,lt=256"`
+	Name        string              `json:"name" validate:"required,lt=256"`
+	UseTemplate *string             `json:"use_template" validate:"omitempty,lt=256"`
+	Label       *I18nObject         `json:"label" validate:"omitempty"`
+	Type        *ModelParameterType `json:"type" validate:"omitempty,model_parameter_type"`
+	Help        *I18nObject         `json:"help" validate:"omitempty"`
+	Required    bool                `json:"required"`
+	Default     *any                `json:"default" validate:"omitempty,is_basic_type"`
+	Min         *float64            `json:"min" validate:"omitempty"`
+	Max         *float64            `json:"max" validate:"omitempty"`
+	Precision   *int                `json:"precision" validate:"omitempty"`
+	Options     []string            `json:"options" validate:"omitempty,dive,lt=256"`
+}
+
+func isParameterRule(fl validator.FieldLevel) bool {
+	// if use_template is empty, then label, type should be required
+	// try get the value of use_template
+	use_template_handle := fl.Field().FieldByName("UseTemplate")
+	// check if use_template is null pointer
+	if use_template_handle.IsNil() {
+		// label and type should be required
+		// try get the value of label
+		if fl.Field().FieldByName("Label").IsNil() {
+			return false
+		}
+
+		// try get the value of type
+		if fl.Field().FieldByName("Type").IsNil() {
+			return false
+		}
+	}
+
+	return true
 }
 
 type ModelPriceConfig struct {
@@ -102,19 +125,19 @@ type ModelConfiguration struct {
 	Model           string                         `json:"model" validate:"required,lt=256"`
 	Label           I18nObject                     `json:"label" validate:"required"`
 	ModelType       ModelType                      `json:"model_type" validate:"required,model_type"`
-	Features        []string                       `json:"features" validate:"omitempty,dive,lt=256"`
-	FetchFrom       ModelProviderConfigurateMethod `json:"fetch_from" validate:"required,model_provider_configurate_method"`
+	Features        []string                       `json:"features" validate:"omitempty,lte=256,dive,lt=256"`
+	FetchFrom       ModelProviderConfigurateMethod `json:"fetch_from" validate:"omitempty,model_provider_configurate_method"`
 	ModelProperties map[string]any                 `json:"model_properties" validate:"omitempty,dive,is_basic_type"`
 	Deprecated      bool                           `json:"deprecated"`
-	ParameterRules  []ModelParameterRule           `json:"parameter_rules" validate:"omitempty,dive"`
+	ParameterRules  []ModelParameterRule           `json:"parameter_rules" validate:"omitempty,lte=128,dive,parameter_rule"`
 	PriceConfig     *ModelPriceConfig              `json:"price_config" validate:"omitempty"`
 }
 
 type ModelProviderFormType string
 
 const (
-	FORM_TYPE_TEXT_INPUT   ModelProviderFormType = "text_input"
-	FORM_TYPE_SECRET_INPUT ModelProviderFormType = "secret_input"
+	FORM_TYPE_TEXT_INPUT   ModelProviderFormType = "text-input"
+	FORM_TYPE_SECRET_INPUT ModelProviderFormType = "secret-input"
 	FORM_TYPE_SELECT       ModelProviderFormType = "select"
 	FORM_TYPE_RADIO        ModelProviderFormType = "radio"
 	FORM_TYPE_SWITCH       ModelProviderFormType = "switch"
@@ -141,7 +164,7 @@ type ModelProviderFormShowOnObject struct {
 type ModelProviderFormOption struct {
 	Label  I18nObject                      `json:"label" validate:"required"`
 	Value  string                          `json:"value" validate:"required,lt=256"`
-	ShowOn []ModelProviderFormShowOnObject `json:"show_on" validate:"omitempty,dive,lt=16"`
+	ShowOn []ModelProviderFormShowOnObject `json:"show_on" validate:"omitempty,lte=16,dive"`
 }
 
 type ModelProviderCredentialFormSchema struct {
@@ -150,14 +173,14 @@ type ModelProviderCredentialFormSchema struct {
 	Type        ModelProviderFormType           `json:"type" validate:"required,model_provider_form_type"`
 	Required    bool                            `json:"required"`
 	Default     *string                         `json:"default" validate:"omitempty,lt=256"`
-	Options     []ModelProviderFormOption       `json:"options" validate:"omitempty,dive,lt=128"`
+	Options     []ModelProviderFormOption       `json:"options" validate:"omitempty,lte=128,dive"`
 	Placeholder *I18nObject                     `json:"placeholder" validate:"omitempty"`
 	MaxLength   int                             `json:"max_length"`
-	ShowOn      []ModelProviderFormShowOnObject `json:"show_on" validate:"omitempty,dive,lt=16"`
+	ShowOn      []ModelProviderFormShowOnObject `json:"show_on" validate:"omitempty,lte=16,dive"`
 }
 
 type ModelProviderCredentialSchema struct {
-	CredentialFormSchemas []ModelProviderCredentialFormSchema `json:"credential_form_schemas" validate:"omitempty,dive,lt=32"`
+	CredentialFormSchemas []ModelProviderCredentialFormSchema `json:"credential_form_schemas" validate:"omitempty,lte=32,dive"`
 }
 
 type FieldModelSchema struct {
@@ -167,12 +190,12 @@ type FieldModelSchema struct {
 
 type ModelCredentialSchema struct {
 	Model                  FieldModelSchema                    `json:"model" validate:"required"`
-	CredentialsFormSchemas []ModelProviderCredentialFormSchema `json:"credentials_form_schemas" validate:"omitempty,dive,lt=32"`
+	CredentialsFormSchemas []ModelProviderCredentialFormSchema `json:"credentials_form_schemas" validate:"omitempty,lte=32,dive"`
 }
 
 type ModelProviderHelpEntity struct {
-	Title I18nObject `json:"title" validate:"required,lt=256"`
-	URL   string     `json:"url" validate:"required,lt=256"`
+	Title I18nObject `json:"title" validate:"required"`
+	URL   I18nObject `json:"url" validate:"required"`
 }
 
 type ModelProviderConfiguration struct {
@@ -183,9 +206,9 @@ type ModelProviderConfiguration struct {
 	IconLarge                *I18nObject                      `json:"icon_large" validate:"omitempty"`
 	Background               *string                          `json:"background" validate:"omitempty"`
 	Help                     *ModelProviderHelpEntity         `json:"help" validate:"omitempty"`
-	SupportedModelTypes      []ModelType                      `json:"supported_model_types" validate:"required,dive,model_type,unique"`
-	ConfigurateMethods       []ModelProviderConfigurateMethod `json:"configurate_methods" validate:"required,dive,model_provider_configurate_method,unique"`
-	Models                   []ModelConfiguration             `json:"models" validate:"omitempty,dive,lt=1024"`
+	SupportedModelTypes      []ModelType                      `json:"supported_model_types" validate:"required,lte=16,dive,model_type"`
+	ConfigurateMethods       []ModelProviderConfigurateMethod `json:"configurate_methods" validate:"required,lte=16,dive,model_provider_configurate_method"`
+	Models                   []ModelConfiguration             `json:"models" validate:"omitempty,lte=1024,dive"`
 	ProviderCredentialSchema *ModelProviderCredentialSchema   `json:"provider_credential_schema" validate:"omitempty"`
 	ModelCredentialSchema    *ModelCredentialSchema           `json:"model_credential_schema" validate:"omitempty"`
 }
@@ -254,5 +277,22 @@ func init() {
 		},
 	)
 
+	global_model_provider_validator.RegisterValidation("parameter_rule", isParameterRule)
+
 	global_model_provider_validator.RegisterValidation("is_basic_type", isGenericType)
 }
+
+func UnmarshalModelProviderConfiguration(data []byte) (*ModelProviderConfiguration, error) {
+	var modelProviderConfiguration ModelProviderConfiguration
+	err := json.Unmarshal(data, &modelProviderConfiguration)
+	if err != nil {
+		return nil, err
+	}
+
+	err = global_model_provider_validator.Struct(modelProviderConfiguration)
+	if err != nil {
+		return nil, err
+	}
+
+	return &modelProviderConfiguration, nil
+}

+ 164 - 0
internal/types/entities/plugin_entities/model_configuration_test.go

@@ -0,0 +1,164 @@
+package plugin_entities
+
+import (
+	"encoding/json"
+	"testing"
+
+	"gopkg.in/yaml.v3"
+)
+
+func parse_yaml_to_json(data []byte) ([]byte, error) {
+	var obj interface{}
+	err := yaml.Unmarshal(data, &obj)
+	if err != nil {
+		return nil, err
+	}
+
+	json_data, err := json.Marshal(obj)
+	if err != nil {
+		return nil, err
+	}
+
+	return json_data, nil
+}
+
+const (
+	template = `
+provider: openai
+label:
+  en_US: OpenAI
+description:
+  en_US: Models provided by OpenAI, such as GPT-3.5-Turbo and GPT-4.
+  zh_Hans: OpenAI 提供的模型,例如 GPT-3.5-Turbo 和 GPT-4。
+icon_small:
+  en_US: icon_s_en.svg
+icon_large:
+  en_US: icon_l_en.svg
+background: "#E5E7EB"
+help:
+  title:
+    en_US: Get your API Key from OpenAI
+    zh_Hans: 从 OpenAI 获取 API Key
+  url:
+    en_US: https://platform.openai.com/account/api-keys
+supported_model_types:
+  - llm
+  - text-embedding
+  - speech2text
+  - moderation
+  - tts
+configurate_methods:
+  - predefined-model
+  - customizable-model
+model_credential_schema:
+  model:
+    label:
+      en_US: Model Name
+      zh_Hans: 模型名称
+    placeholder:
+      en_US: Enter your model name
+      zh_Hans: 输入模型名称
+  credential_form_schemas:
+    - variable: openai_api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: openai_organization
+      label:
+        zh_Hans: 组织 ID
+        en_US: Organization
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的组织 ID
+        en_US: Enter your Organization ID
+    - variable: openai_api_base
+      label:
+        zh_Hans: API Base
+        en_US: API Base
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的 API Base
+        en_US: Enter your API Base
+provider_credential_schema:
+  credential_form_schemas:
+    - variable: openai_api_key
+      label:
+        en_US: API Key
+      type: secret-input
+      required: true
+      placeholder:
+        zh_Hans: 在此输入您的 API Key
+        en_US: Enter your API Key
+    - variable: openai_organization
+      label:
+        zh_Hans: 组织 ID
+        en_US: Organization
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的组织 ID
+        en_US: Enter your Organization ID
+    - variable: openai_api_base
+      label:
+        zh_Hans: API Base
+        en_US: API Base
+      type: text-input
+      required: false
+      placeholder:
+        zh_Hans: 在此输入您的 API Base, 如:https://api.openai.com
+        en_US: Enter your API Base, e.g. https://api.openai.com
+models:
+  - model: gpt-3.5-turbo-16k-0613
+    label:
+      zh_Hans: gpt-3.5-turbo-16k-0613
+      en_US: gpt-3.5-turbo-16k-0613
+    model_type: llm
+    features:
+      - multi-tool-call
+      - agent-thought
+      - stream-tool-call
+    model_properties:
+      mode: chat
+      context_size: 16385
+    parameter_rules:
+      - name: temperature
+        use_template: temperature
+      - name: top_p
+        use_template: top_p
+      - name: presence_penalty
+        use_template: presence_penalty
+      - name: frequency_penalty
+        use_template: frequency_penalty
+      - name: max_tokens
+        use_template: max_tokens
+        default: 512
+        min: 1
+        max: 16385
+      - name: response_format
+        use_template: response_format
+    pricing:
+      input: '0.003'
+      output: '0.004'
+      unit: '0.001'
+      currency: USD
+    `
+)
+
+func TestFullFunctionModelProvider_Validate(t *testing.T) {
+	json_data, err := parse_yaml_to_json([]byte(template))
+	if err != nil {
+		t.Error(err)
+	}
+
+	_, err = UnmarshalModelProviderConfiguration(json_data)
+	if err != nil {
+		t.Errorf("UnmarshalModelProviderConfiguration() error = %v", err)
+	}
+
+}