Explorar o código

feat: support tool output schema

Yeuoly hai 1 ano
pai
achega
c8fa710768

+ 12 - 0
internal/types/entities/plugin_entities/provider_declaration.go

@@ -1 +1,13 @@
 package plugin_entities
+
+type GenericProviderDeclaration struct {
+}
+
+type ModelProviderDeclaration struct {
+}
+
+type ToolProviderDeclaration struct {
+}
+
+type EndpointProviderDeclaration struct {
+}

+ 19 - 3
internal/types/entities/plugin_entities/tool_configuration.go

@@ -9,6 +9,7 @@ import (
 	en_translations "github.com/go-playground/validator/v10/translations/en"
 	"github.com/langgenius/dify-plugin-daemon/internal/types/validators"
 	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
+	"github.com/xeipuuv/gojsonschema"
 )
 
 type ToolIdentity struct {
@@ -85,10 +86,25 @@ type ToolDescription struct {
 	LLM   string     `json:"llm" validate:"required"`
 }
 
+type ToolOutputSchema map[string]any
+
 type ToolConfiguration struct {
-	Identity    ToolIdentity    `json:"identity" validate:"required"`
-	Description ToolDescription `json:"description" validate:"required"`
-	Parameters  []ToolParameter `json:"parameters" validate:"omitempty,dive"`
+	Identity     ToolIdentity     `json:"identity" validate:"required"`
+	Description  ToolDescription  `json:"description" validate:"required"`
+	Parameters   []ToolParameter  `json:"parameters" validate:"omitempty,dive"`
+	OutputSchema ToolOutputSchema `json:"output_schema" validate:"omitempty,json_schema"`
+}
+
+func isJSONSchema(fl validator.FieldLevel) bool {
+	_, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(fl.Field().Interface()))
+	if err != nil {
+		return false
+	}
+	return true
+}
+
+func init() {
+	validators.GlobalEntitiesValidator.RegisterValidation("json_schema", isJSONSchema)
 }
 
 type ToolCredentialsOption struct {

+ 118 - 0
internal/types/entities/plugin_entities/tool_configuration_test.go

@@ -539,3 +539,121 @@ func TestWrongToolParameterFormToolProvider_Validate(t *testing.T) {
 		return
 	}
 }
+
+func TestJSONSchemaTypeToolProvider_Validate(t *testing.T) {
+	const data = `
+{
+	"identity": {
+		"author": "author",
+		"name": "name",
+		"description": {
+			"en_US": "description",
+			"zh_Hans": "描述",
+			"pt_BR": "descrição"
+		},
+		"icon": "icon",
+		"label": {
+			"en_US": "label",
+			"zh_Hans": "标签",
+			"pt_BR": "etiqueta"
+		},
+		"tags": []
+	},
+	"credentials_schema": {},
+	"tools": [
+		{
+			"identity": {
+				"author": "author",
+				"name": "tool",
+				"label": {
+					"en_US": "label",
+					"zh_Hans": "标签",
+					"pt_BR": "etiqueta"
+				}
+			},
+			"description": {
+				"human": {
+					"en_US": "description",
+					"zh_Hans": "描述",
+					"pt_BR": "descrição"
+				},
+				"llm": "description"
+			},
+			"output_schema": {
+				"type": "object",
+				"properties": {
+					"name": {
+						"type": "string"
+					}
+				}
+			}
+		}
+	]
+}
+	`
+
+	_, err := UnmarshalToolProviderConfiguration([]byte(data))
+	if err != nil {
+		t.Errorf("UnmarshalToolProviderConfiguration() error = %v, wantErr %v", err, true)
+		return
+	}
+}
+
+func TestWrongJSONSchemaToolProvider_Validate(t *testing.T) {
+	const data = `
+{
+	"identity": {
+		"author": "author",
+		"name": "name",
+		"description": {
+			"en_US": "description",
+			"zh_Hans": "描述",
+			"pt_BR": "descrição"
+		},
+		"icon": "icon",
+		"label": {
+			"en_US": "label",
+			"zh_Hans": "标签",
+			"pt_BR": "etiqueta"
+		},
+		"tags": []
+	},
+	"credentials_schema": {},
+	"tools": [
+		{
+			"identity": {
+				"author": "author",
+				"name": "tool",
+				"label": {
+					"en_US": "label",
+					"zh_Hans": "标签",
+					"pt_BR": "etiqueta"
+				}
+			},
+			"description": {
+				"human": {
+					"en_US": "description",
+					"zh_Hans": "描述",
+					"pt_BR": "descrição"
+				},
+				"llm": "description"
+			},
+			"output_schema": {
+				"type": "object",
+				"properties": {
+					"name": {
+						"type": "aaa"
+					}
+				}
+			}
+		}
+	]
+}
+	`
+
+	_, err := UnmarshalToolProviderConfiguration([]byte(data))
+	if err == nil {
+		t.Errorf("UnmarshalToolProviderConfiguration() error = %v, wantErr %v", err, true)
+		return
+	}
+}