tool_declaration.go 14 KB


  1. package plugin_entities
import (
	"bytes"
	"encoding/json"
	"fmt"

	"github.com/go-playground/locales/en"
	ut "github.com/go-playground/universal-translator"
	"github.com/go-playground/validator/v10"
	en_translations "github.com/go-playground/validator/v10/translations/en"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
	"github.com/langgenius/dify-plugin-daemon/pkg/entities/manifest_entities"
	"github.com/langgenius/dify-plugin-daemon/pkg/validators"
	"github.com/xeipuuv/gojsonschema"
	"gopkg.in/yaml.v3"
)
  15. type ToolIdentity struct {
  16. Author string `json:"author" yaml:"author" validate:"required"`
  17. Name string `json:"name" yaml:"name" validate:"required"`
  18. Label I18nObject `json:"label" yaml:"label" validate:"required"`
  19. }
  20. type ToolParameterOption struct {
  21. Value string `json:"value" yaml:"value" validate:"required"`
  22. Label I18nObject `json:"label" yaml:"label" validate:"required"`
  23. }
  24. type ToolParameterType string
  25. const (
  26. TOOL_PARAMETER_TYPE_STRING ToolParameterType = STRING
  27. TOOL_PARAMETER_TYPE_NUMBER ToolParameterType = NUMBER
  28. TOOL_PARAMETER_TYPE_BOOLEAN ToolParameterType = BOOLEAN
  29. TOOL_PARAMETER_TYPE_SELECT ToolParameterType = SELECT
  30. TOOL_PARAMETER_TYPE_SECRET_INPUT ToolParameterType = SECRET_INPUT
  31. TOOL_PARAMETER_TYPE_FILE ToolParameterType = FILE
  32. TOOL_PARAMETER_TYPE_FILES ToolParameterType = FILES
  33. TOOL_PARAMETER_TYPE_APP_SELECTOR ToolParameterType = APP_SELECTOR
  34. TOOL_PARAMETER_TYPE_MODEL_SELECTOR ToolParameterType = MODEL_SELECTOR
  35. // TOOL_PARAMETER_TYPE_TOOL_SELECTOR ToolParameterType = TOOL_SELECTOR
  36. )
  37. func isToolParameterType(fl validator.FieldLevel) bool {
  38. value := fl.Field().String()
  39. switch value {
  40. case string(TOOL_PARAMETER_TYPE_STRING),
  41. string(TOOL_PARAMETER_TYPE_NUMBER),
  42. string(TOOL_PARAMETER_TYPE_BOOLEAN),
  43. string(TOOL_PARAMETER_TYPE_SELECT),
  44. string(TOOL_PARAMETER_TYPE_SECRET_INPUT),
  45. string(TOOL_PARAMETER_TYPE_FILE),
  46. string(TOOL_PARAMETER_TYPE_FILES),
  47. // string(TOOL_PARAMETER_TYPE_TOOL_SELECTOR),
  48. string(TOOL_PARAMETER_TYPE_APP_SELECTOR),
  49. string(TOOL_PARAMETER_TYPE_MODEL_SELECTOR):
  50. return true
  51. }
  52. return false
  53. }
  54. type ToolParameterForm string
  55. const (
  56. TOOL_PARAMETER_FORM_SCHEMA ToolParameterForm = "schema"
  57. TOOL_PARAMETER_FORM_FORM ToolParameterForm = "form"
  58. TOOL_PARAMETER_FORM_LLM ToolParameterForm = "llm"
  59. )
  60. func isToolParameterForm(fl validator.FieldLevel) bool {
  61. value := fl.Field().String()
  62. switch value {
  63. case string(TOOL_PARAMETER_FORM_SCHEMA),
  64. string(TOOL_PARAMETER_FORM_FORM),
  65. string(TOOL_PARAMETER_FORM_LLM):
  66. return true
  67. }
  68. return false
  69. }
  70. type ParameterAutoGenerateType string
  71. const (
  72. PARAMETER_AUTO_GENERATE_TYPE_PROMPT_INSTRUCTION ParameterAutoGenerateType = "prompt_instruction"
  73. )
  74. func isParameterAutoGenerateType(fl validator.FieldLevel) bool {
  75. value := fl.Field().String()
  76. switch value {
  77. case string(PARAMETER_AUTO_GENERATE_TYPE_PROMPT_INSTRUCTION):
  78. return true
  79. }
  80. return false
  81. }
  82. func init() {
  83. validators.GlobalEntitiesValidator.RegisterValidation("parameter_auto_generate_type", isParameterAutoGenerateType)
  84. }
  85. type ParameterAutoGenerate struct {
  86. Type ParameterAutoGenerateType `json:"type" yaml:"type" validate:"required,parameter_auto_generate_type"`
  87. }
  88. type ParameterTemplate struct {
  89. Enabled bool `json:"enabled" yaml:"enabled"`
  90. }
  91. type ToolParameter struct {
  92. Name string `json:"name" yaml:"name" validate:"required,gt=0,lt=1024"`
  93. Label I18nObject `json:"label" yaml:"label" validate:"required"`
  94. HumanDescription I18nObject `json:"human_description" yaml:"human_description" validate:"required"`
  95. Type ToolParameterType `json:"type" yaml:"type" validate:"required,tool_parameter_type"`
  96. Scope *string `json:"scope" yaml:"scope" validate:"omitempty,max=1024,is_scope"`
  97. Form ToolParameterForm `json:"form" yaml:"form" validate:"required,tool_parameter_form"`
  98. LLMDescription string `json:"llm_description" yaml:"llm_description" validate:"omitempty"`
  99. Required bool `json:"required" yaml:"required"`
  100. AutoGenerate *ParameterAutoGenerate `json:"auto_generate" yaml:"auto_generate" validate:"omitempty"`
  101. Template *ParameterTemplate `json:"template" yaml:"template" validate:"omitempty"`
  102. Default any `json:"default" yaml:"default" validate:"omitempty,is_basic_type"`
  103. Min *float64 `json:"min" yaml:"min" validate:"omitempty"`
  104. Max *float64 `json:"max" yaml:"max" validate:"omitempty"`
  105. Precision *int `json:"precision" yaml:"precision" validate:"omitempty"`
  106. Options []ToolParameterOption `json:"options" yaml:"options" validate:"omitempty,dive"`
  107. }
  108. type ToolDescription struct {
  109. Human I18nObject `json:"human" validate:"required"`
  110. LLM string `json:"llm" validate:"required"`
  111. }
  112. type ToolOutputSchema map[string]any
  113. type ToolDeclaration struct {
  114. Identity ToolIdentity `json:"identity" yaml:"identity" validate:"required"`
  115. Description ToolDescription `json:"description" yaml:"description" validate:"required"`
  116. Parameters []ToolParameter `json:"parameters" yaml:"parameters" validate:"omitempty,dive"`
  117. OutputSchema ToolOutputSchema `json:"output_schema" yaml:"output_schema" validate:"omitempty,json_schema"`
  118. HasRuntimeParameters bool `json:"has_runtime_parameters" yaml:"has_runtime_parameters"`
  119. }
  120. func isJSONSchema(fl validator.FieldLevel) bool {
  121. // get schema from interface
  122. schemaMapInf := fl.Field().Interface()
  123. // convert to map[string]any
  124. var schemaMap map[string]any
  125. toolSchemaMap, ok := schemaMapInf.(ToolOutputSchema)
  126. if !ok {
  127. agentSchemaMap, ok := schemaMapInf.(AgentStrategyOutputSchema)
  128. if !ok {
  129. return false
  130. }
  131. schemaMap = agentSchemaMap
  132. } else {
  133. schemaMap = toolSchemaMap
  134. }
  135. // validate root schema must be object type
  136. rootType, ok := schemaMap["type"].(string)
  137. if !ok || rootType != "object" {
  138. return false
  139. }
  140. // validate properties
  141. properties, ok := schemaMap["properties"].(map[string]any)
  142. if !ok {
  143. return false
  144. }
  145. // disallow text, json, files as property names
  146. disallowedProps := []string{"text", "json", "files"}
  147. for _, prop := range disallowedProps {
  148. if _, exists := properties[prop]; exists {
  149. return false
  150. }
  151. }
  152. _, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(fl.Field().Interface()))
  153. if err != nil {
  154. return false
  155. }
  156. return err == nil
  157. }
  158. func init() {
  159. validators.GlobalEntitiesValidator.RegisterValidation("json_schema", isJSONSchema)
  160. }
  161. type ToolProviderIdentity struct {
  162. Author string `json:"author" validate:"required"`
  163. Name string `json:"name" validate:"required"`
  164. Description I18nObject `json:"description"`
  165. Icon string `json:"icon" validate:"required"`
  166. Label I18nObject `json:"label" validate:"required"`
  167. Tags []manifest_entities.PluginTag `json:"tags" validate:"omitempty,dive,plugin_tag"`
  168. }
  169. type ToolProviderDeclaration struct {
  170. Identity ToolProviderIdentity `json:"identity" yaml:"identity" validate:"required"`
  171. CredentialsSchema []ProviderConfig `json:"credentials_schema" yaml:"credentials_schema" validate:"omitempty,dive"`
  172. Tools []ToolDeclaration `json:"tools" yaml:"tools" validate:"required,dive"`
  173. ToolFiles []string `json:"-" yaml:"-"`
  174. }
  175. func (t *ToolProviderDeclaration) MarshalJSON() ([]byte, error) {
  176. type alias ToolProviderDeclaration
  177. p := alias(*t)
  178. if p.CredentialsSchema == nil {
  179. p.CredentialsSchema = []ProviderConfig{}
  180. }
  181. if p.Tools == nil {
  182. p.Tools = []ToolDeclaration{}
  183. }
  184. return json.Marshal(p)
  185. }
  186. func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
  187. type alias struct {
  188. Identity ToolProviderIdentity `yaml:"identity"`
  189. CredentialsSchema yaml.Node `yaml:"credentials_schema"`
  190. CredentialsForProvider yaml.Node `yaml:"credentials_for_provider"`
  191. Tools yaml.Node `yaml:"tools"`
  192. }
  193. var temp alias
  194. err := value.Decode(&temp)
  195. if err != nil {
  196. return err
  197. }
  198. // apply credentials_for_provider to credentials_schema if not exists
  199. if (temp.CredentialsSchema.Kind == yaml.ScalarNode && temp.CredentialsSchema.Value == "") ||
  200. len(temp.CredentialsSchema.Content) == 0 {
  201. temp.CredentialsSchema = temp.CredentialsForProvider
  202. }
  203. // apply identity
  204. t.Identity = temp.Identity
  205. // check if credentials_schema is a map
  206. if temp.CredentialsSchema.Kind != yaml.MappingNode {
  207. // not a map, convert it into array
  208. credentialsSchema := make([]ProviderConfig, 0)
  209. if err := temp.CredentialsSchema.Decode(&credentialsSchema); err != nil {
  210. return err
  211. }
  212. t.CredentialsSchema = credentialsSchema
  213. } else if temp.CredentialsSchema.Kind == yaml.MappingNode {
  214. credentialsSchema := make([]ProviderConfig, 0, len(temp.CredentialsSchema.Content)/2)
  215. currentKey := ""
  216. currentValue := &ProviderConfig{}
  217. for _, item := range temp.CredentialsSchema.Content {
  218. if item.Kind == yaml.ScalarNode {
  219. currentKey = item.Value
  220. } else if item.Kind == yaml.MappingNode {
  221. currentValue = &ProviderConfig{}
  222. if err := item.Decode(currentValue); err != nil {
  223. return err
  224. }
  225. currentValue.Name = currentKey
  226. credentialsSchema = append(credentialsSchema, *currentValue)
  227. }
  228. }
  229. t.CredentialsSchema = credentialsSchema
  230. }
  231. if t.ToolFiles == nil {
  232. t.ToolFiles = []string{}
  233. }
  234. // unmarshal tools
  235. if temp.Tools.Kind == yaml.SequenceNode {
  236. for _, item := range temp.Tools.Content {
  237. if item.Kind == yaml.ScalarNode {
  238. t.ToolFiles = append(t.ToolFiles, item.Value)
  239. } else if item.Kind == yaml.MappingNode {
  240. tool := ToolDeclaration{}
  241. if err := item.Decode(&tool); err != nil {
  242. return err
  243. }
  244. t.Tools = append(t.Tools, tool)
  245. }
  246. }
  247. }
  248. if t.CredentialsSchema == nil {
  249. t.CredentialsSchema = []ProviderConfig{}
  250. }
  251. if t.Tools == nil {
  252. t.Tools = []ToolDeclaration{}
  253. }
  254. if t.Identity.Tags == nil {
  255. t.Identity.Tags = []manifest_entities.PluginTag{}
  256. }
  257. return nil
  258. }
  259. func (t *ToolProviderDeclaration) UnmarshalJSON(data []byte) error {
  260. type alias ToolProviderDeclaration
  261. var temp struct {
  262. alias
  263. CredentialsSchema json.RawMessage `json:"credentials_schema"`
  264. CredentialsForProvider json.RawMessage `json:"credentials_for_provider"`
  265. Tools []json.RawMessage `json:"tools"`
  266. }
  267. if err := json.Unmarshal(data, &temp); err != nil {
  268. return err
  269. }
  270. if len(temp.CredentialsSchema) == 0 {
  271. temp.CredentialsSchema = temp.CredentialsForProvider
  272. }
  273. *t = ToolProviderDeclaration(temp.alias)
  274. // Determine the type of CredentialsSchema
  275. var raw_message map[string]json.RawMessage
  276. if err := json.Unmarshal(temp.CredentialsSchema, &raw_message); err == nil {
  277. // It's an object
  278. credentialsSchemaObject := make(map[string]ProviderConfig)
  279. if err := json.Unmarshal(temp.CredentialsSchema, &credentialsSchemaObject); err != nil {
  280. return fmt.Errorf("failed to unmarshal credentials_schema as object: %v", err)
  281. }
  282. for _, value := range credentialsSchemaObject {
  283. t.CredentialsSchema = append(t.CredentialsSchema, value)
  284. }
  285. } else {
  286. // It's likely an array
  287. var credentials_schema_array []ProviderConfig
  288. if err := json.Unmarshal(temp.CredentialsSchema, &credentials_schema_array); err != nil {
  289. return fmt.Errorf("failed to unmarshal credentials_schema as array: %v", err)
  290. }
  291. t.CredentialsSchema = credentials_schema_array
  292. }
  293. if t.ToolFiles == nil {
  294. t.ToolFiles = []string{}
  295. }
  296. // unmarshal tools
  297. for _, item := range temp.Tools {
  298. tool := ToolDeclaration{}
  299. if err := json.Unmarshal(item, &tool); err != nil {
  300. // try to unmarshal it as a string directly
  301. t.ToolFiles = append(t.ToolFiles, string(item))
  302. } else {
  303. t.Tools = append(t.Tools, tool)
  304. }
  305. }
  306. if t.CredentialsSchema == nil {
  307. t.CredentialsSchema = []ProviderConfig{}
  308. }
  309. if t.Tools == nil {
  310. t.Tools = []ToolDeclaration{}
  311. }
  312. if t.Identity.Tags == nil {
  313. t.Identity.Tags = []manifest_entities.PluginTag{}
  314. }
  315. return nil
  316. }
  317. func init() {
  318. // init validator
  319. en := en.New()
  320. uni := ut.New(en, en)
  321. translator, _ := uni.GetTranslator("en")
  322. // register translations for default validators
  323. en_translations.RegisterDefaultTranslations(validators.GlobalEntitiesValidator, translator)
  324. validators.GlobalEntitiesValidator.RegisterValidation("tool_parameter_type", isToolParameterType)
  325. validators.GlobalEntitiesValidator.RegisterTranslation(
  326. "tool_parameter_type",
  327. translator,
  328. func(ut ut.Translator) error {
  329. return ut.Add("tool_parameter_type", "{0} is not a valid tool parameter type", true)
  330. },
  331. func(ut ut.Translator, fe validator.FieldError) string {
  332. t, _ := ut.T("tool_parameter_type", fe.Field())
  333. return t
  334. },
  335. )
  336. validators.GlobalEntitiesValidator.RegisterValidation("tool_parameter_form", isToolParameterForm)
  337. validators.GlobalEntitiesValidator.RegisterTranslation(
  338. "tool_parameter_form",
  339. translator,
  340. func(ut ut.Translator) error {
  341. return ut.Add("tool_parameter_form", "{0} is not a valid tool parameter form", true)
  342. },
  343. func(ut ut.Translator, fe validator.FieldError) string {
  344. t, _ := ut.T("tool_parameter_form", fe.Field())
  345. return t
  346. },
  347. )
  348. validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isBasicType)
  349. }
  350. func UnmarshalToolProviderDeclaration(data []byte) (*ToolProviderDeclaration, error) {
  351. obj, err := parser.UnmarshalJsonBytes[ToolProviderDeclaration](data)
  352. if err != nil {
  353. return nil, fmt.Errorf("failed to unmarshal tool provider configuration: %w", err)
  354. }
  355. if err := validators.GlobalEntitiesValidator.Struct(obj); err != nil {
  356. return nil, fmt.Errorf("failed to validate tool provider configuration: %w", err)
  357. }
  358. return &obj, nil
  359. }