Bläddra i källkod

fix: disbaled spaces and uppercase in tool identity

Yeuoly 6 månader sedan
förälder
incheckning
39449ca3de

+ 25 - 2
pkg/entities/plugin_entities/tool_declaration.go

@@ -3,6 +3,7 @@ package plugin_entities
 import (
 	"encoding/json"
 	"fmt"
+	"regexp"
 
 	"github.com/go-playground/locales/en"
 	ut "github.com/go-playground/universal-translator"
@@ -17,10 +18,21 @@ import (
 
 type ToolIdentity struct {
 	Author string     `json:"author" yaml:"author" validate:"required"`
-	Name   string     `json:"name" yaml:"name" validate:"required"`
+	Name   string     `json:"name" yaml:"name" validate:"required,tool_identity_name"`
 	Label  I18nObject `json:"label" yaml:"label" validate:"required"`
 }
 
+var toolIdentityNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
+
+func isToolIdentityName(fl validator.FieldLevel) bool {
+	value := fl.Field().String()
+	return toolIdentityNameRegex.MatchString(value)
+}
+
+func init() {
+	validators.GlobalEntitiesValidator.RegisterValidation("tool_identity_name", isToolIdentityName)
+}
+
 type ToolParameterOption struct {
 	Value string     `json:"value" yaml:"value" validate:"required"`
 	Label I18nObject `json:"label" yaml:"label" validate:"required"`
@@ -188,13 +200,24 @@ func init() {
 
 type ToolProviderIdentity struct {
 	Author      string                        `json:"author" validate:"required"`
-	Name        string                        `json:"name" validate:"required"`
+	Name        string                        `json:"name" validate:"required,tool_provider_identity_name"`
 	Description I18nObject                    `json:"description"`
 	Icon        string                        `json:"icon" validate:"required"`
 	Label       I18nObject                    `json:"label" validate:"required"`
 	Tags        []manifest_entities.PluginTag `json:"tags" validate:"omitempty,dive,plugin_tag"`
 }
 
+var toolProviderIdentityNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
+
+func isToolProviderIdentityName(fl validator.FieldLevel) bool {
+	value := fl.Field().String()
+	return toolProviderIdentityNameRegex.MatchString(value)
+}
+
+func init() {
+	validators.GlobalEntitiesValidator.RegisterValidation("tool_provider_identity_name", isToolProviderIdentityName)
+}
+
 type ToolProviderDeclaration struct {
 	Identity          ToolProviderIdentity `json:"identity" yaml:"identity" validate:"required"`
 	CredentialsSchema []ProviderConfig     `json:"credentials_schema" yaml:"credentials_schema" validate:"omitempty,dive"`

+ 59 - 0
pkg/entities/plugin_entities/tool_declaration_test.go

@@ -1266,3 +1266,62 @@ func TestInvalidJSONSchemaToolProvider_Validate(t *testing.T) {
 		t.Errorf("TestInvalidJSONSchemaToolProvider_Validate() error = %v", err)
 	}
 }
+
+func TestToolName_Validate(t *testing.T) {
+	data := parser.MarshalJsonBytes(ToolProviderIdentity{
+		Author: "author",
+		Name:   "tool-name",
+		Description: I18nObject{
+			EnUS:   "description",
+			ZhHans: "描述",
+		},
+		Icon: "icon",
+		Label: I18nObject{
+			EnUS:   "label",
+			ZhHans: "标签",
+		},
+	})
+
+	if _, err := parser.UnmarshalJsonBytes[ToolProviderIdentity](data); err != nil {
+		t.Errorf("TestToolName_Validate() error = %v", err)
+	}
+
+	data = parser.MarshalJsonBytes(ToolProviderIdentity{
+		Author: "author",
+		Name:   "tool AA",
+		Label: I18nObject{
+			EnUS:   "label",
+			ZhHans: "标签",
+		},
+	})
+
+	if _, err := parser.UnmarshalJsonBytes[ToolProviderIdentity](data); err == nil {
+		t.Errorf("TestToolName_Validate() error = %v", err)
+	}
+
+	data = parser.MarshalJsonBytes(ToolIdentity{
+		Author: "author",
+		Name:   "tool-name-123",
+		Label: I18nObject{
+			EnUS:   "label",
+			ZhHans: "标签",
+		},
+	})
+
+	if _, err := parser.UnmarshalJsonBytes[ToolIdentity](data); err != nil {
+		t.Errorf("TestToolName_Validate() error = %v", err)
+	}
+
+	data = parser.MarshalJsonBytes(ToolIdentity{
+		Author: "author",
+		Name:   "tool name-123",
+		Label: I18nObject{
+			EnUS:   "label",
+			ZhHans: "标签",
+		},
+	})
+
+	if _, err := parser.UnmarshalJsonBytes[ToolIdentity](data); err == nil {
+		t.Errorf("TestToolName_Validate() error = %v", err)
+	}
+}