tool_declaration.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. package plugin_entities
import (
	"bytes"
	"encoding/json"
	"fmt"
	"regexp"

	"github.com/go-playground/locales/en"
	ut "github.com/go-playground/universal-translator"
	"github.com/go-playground/validator/v10"
	en_translations "github.com/go-playground/validator/v10/translations/en"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
	"github.com/langgenius/dify-plugin-daemon/pkg/entities/manifest_entities"
	"github.com/langgenius/dify-plugin-daemon/pkg/validators"
	"github.com/xeipuuv/gojsonschema"
	"gopkg.in/yaml.v3"
)
  16. type ToolIdentity struct {
  17. Author string `json:"author" yaml:"author" validate:"required"`
  18. Name string `json:"name" yaml:"name" validate:"required,tool_identity_name"`
  19. Label I18nObject `json:"label" yaml:"label" validate:"required"`
  20. }
  21. var toolIdentityNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
  22. func isToolIdentityName(fl validator.FieldLevel) bool {
  23. value := fl.Field().String()
  24. return toolIdentityNameRegex.MatchString(value)
  25. }
  26. func init() {
  27. validators.GlobalEntitiesValidator.RegisterValidation("tool_identity_name", isToolIdentityName)
  28. }
  29. type ToolParameterOption struct {
  30. Value string `json:"value" yaml:"value" validate:"required"`
  31. Label I18nObject `json:"label" yaml:"label" validate:"required"`
  32. }
  33. type ToolParameterType string
  34. const (
  35. TOOL_PARAMETER_TYPE_STRING ToolParameterType = STRING
  36. TOOL_PARAMETER_TYPE_NUMBER ToolParameterType = NUMBER
  37. TOOL_PARAMETER_TYPE_BOOLEAN ToolParameterType = BOOLEAN
  38. TOOL_PARAMETER_TYPE_SELECT ToolParameterType = SELECT
  39. TOOL_PARAMETER_TYPE_SECRET_INPUT ToolParameterType = SECRET_INPUT
  40. TOOL_PARAMETER_TYPE_FILE ToolParameterType = FILE
  41. TOOL_PARAMETER_TYPE_FILES ToolParameterType = FILES
  42. TOOL_PARAMETER_TYPE_APP_SELECTOR ToolParameterType = APP_SELECTOR
  43. TOOL_PARAMETER_TYPE_MODEL_SELECTOR ToolParameterType = MODEL_SELECTOR
  44. // TOOL_PARAMETER_TYPE_TOOL_SELECTOR ToolParameterType = TOOL_SELECTOR
  45. )
  46. func isToolParameterType(fl validator.FieldLevel) bool {
  47. value := fl.Field().String()
  48. switch value {
  49. case string(TOOL_PARAMETER_TYPE_STRING),
  50. string(TOOL_PARAMETER_TYPE_NUMBER),
  51. string(TOOL_PARAMETER_TYPE_BOOLEAN),
  52. string(TOOL_PARAMETER_TYPE_SELECT),
  53. string(TOOL_PARAMETER_TYPE_SECRET_INPUT),
  54. string(TOOL_PARAMETER_TYPE_FILE),
  55. string(TOOL_PARAMETER_TYPE_FILES),
  56. // string(TOOL_PARAMETER_TYPE_TOOL_SELECTOR),
  57. string(TOOL_PARAMETER_TYPE_APP_SELECTOR),
  58. string(TOOL_PARAMETER_TYPE_MODEL_SELECTOR):
  59. return true
  60. }
  61. return false
  62. }
  63. type ToolParameterForm string
  64. const (
  65. TOOL_PARAMETER_FORM_SCHEMA ToolParameterForm = "schema"
  66. TOOL_PARAMETER_FORM_FORM ToolParameterForm = "form"
  67. TOOL_PARAMETER_FORM_LLM ToolParameterForm = "llm"
  68. )
  69. func isToolParameterForm(fl validator.FieldLevel) bool {
  70. value := fl.Field().String()
  71. switch value {
  72. case string(TOOL_PARAMETER_FORM_SCHEMA),
  73. string(TOOL_PARAMETER_FORM_FORM),
  74. string(TOOL_PARAMETER_FORM_LLM):
  75. return true
  76. }
  77. return false
  78. }
  79. type ParameterAutoGenerateType string
  80. const (
  81. PARAMETER_AUTO_GENERATE_TYPE_PROMPT_INSTRUCTION ParameterAutoGenerateType = "prompt_instruction"
  82. )
  83. func isParameterAutoGenerateType(fl validator.FieldLevel) bool {
  84. value := fl.Field().String()
  85. switch value {
  86. case string(PARAMETER_AUTO_GENERATE_TYPE_PROMPT_INSTRUCTION):
  87. return true
  88. }
  89. return false
  90. }
  91. func init() {
  92. validators.GlobalEntitiesValidator.RegisterValidation("parameter_auto_generate_type", isParameterAutoGenerateType)
  93. }
  94. type ParameterAutoGenerate struct {
  95. Type ParameterAutoGenerateType `json:"type" yaml:"type" validate:"required,parameter_auto_generate_type"`
  96. }
  97. type ParameterTemplate struct {
  98. Enabled bool `json:"enabled" yaml:"enabled"`
  99. }
  100. type ToolParameter struct {
  101. Name string `json:"name" yaml:"name" validate:"required,gt=0,lt=1024"`
  102. Label I18nObject `json:"label" yaml:"label" validate:"required"`
  103. HumanDescription I18nObject `json:"human_description" yaml:"human_description" validate:"required"`
  104. Type ToolParameterType `json:"type" yaml:"type" validate:"required,tool_parameter_type"`
  105. Scope *string `json:"scope" yaml:"scope" validate:"omitempty,max=1024,is_scope"`
  106. Form ToolParameterForm `json:"form" yaml:"form" validate:"required,tool_parameter_form"`
  107. LLMDescription string `json:"llm_description" yaml:"llm_description" validate:"omitempty"`
  108. Required bool `json:"required" yaml:"required"`
  109. AutoGenerate *ParameterAutoGenerate `json:"auto_generate" yaml:"auto_generate" validate:"omitempty"`
  110. Template *ParameterTemplate `json:"template" yaml:"template" validate:"omitempty"`
  111. Default any `json:"default" yaml:"default" validate:"omitempty,is_basic_type"`
  112. Min *float64 `json:"min" yaml:"min" validate:"omitempty"`
  113. Max *float64 `json:"max" yaml:"max" validate:"omitempty"`
  114. Precision *int `json:"precision" yaml:"precision" validate:"omitempty"`
  115. Options []ToolParameterOption `json:"options" yaml:"options" validate:"omitempty,dive"`
  116. }
  117. type ToolDescription struct {
  118. Human I18nObject `json:"human" validate:"required"`
  119. LLM string `json:"llm" validate:"required"`
  120. }
  121. type ToolOutputSchema map[string]any
  122. type ToolDeclaration struct {
  123. Identity ToolIdentity `json:"identity" yaml:"identity" validate:"required"`
  124. Description ToolDescription `json:"description" yaml:"description" validate:"required"`
  125. Parameters []ToolParameter `json:"parameters" yaml:"parameters" validate:"omitempty,dive"`
  126. OutputSchema ToolOutputSchema `json:"output_schema" yaml:"output_schema" validate:"omitempty,json_schema"`
  127. HasRuntimeParameters bool `json:"has_runtime_parameters" yaml:"has_runtime_parameters"`
  128. }
  129. func isJSONSchema(fl validator.FieldLevel) bool {
  130. // get schema from interface
  131. schemaMapInf := fl.Field().Interface()
  132. // convert to map[string]any
  133. var schemaMap map[string]any
  134. toolSchemaMap, ok := schemaMapInf.(ToolOutputSchema)
  135. if !ok {
  136. agentSchemaMap, ok := schemaMapInf.(AgentStrategyOutputSchema)
  137. if !ok {
  138. return false
  139. }
  140. schemaMap = agentSchemaMap
  141. } else {
  142. schemaMap = toolSchemaMap
  143. }
  144. // validate root schema must be object type
  145. rootType, ok := schemaMap["type"].(string)
  146. if !ok || rootType != "object" {
  147. return false
  148. }
  149. // validate properties
  150. properties, ok := schemaMap["properties"].(map[string]any)
  151. if !ok {
  152. return false
  153. }
  154. // disallow text, json, files as property names
  155. disallowedProps := []string{"text", "json", "files"}
  156. for _, prop := range disallowedProps {
  157. if _, exists := properties[prop]; exists {
  158. return false
  159. }
  160. }
  161. _, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(fl.Field().Interface()))
  162. if err != nil {
  163. return false
  164. }
  165. return err == nil
  166. }
  167. func init() {
  168. validators.GlobalEntitiesValidator.RegisterValidation("json_schema", isJSONSchema)
  169. }
  170. type ToolProviderIdentity struct {
  171. Author string `json:"author" validate:"required"`
  172. Name string `json:"name" validate:"required,tool_provider_identity_name"`
  173. Description I18nObject `json:"description"`
  174. Icon string `json:"icon" validate:"required"`
  175. Label I18nObject `json:"label" validate:"required"`
  176. Tags []manifest_entities.PluginTag `json:"tags" validate:"omitempty,dive,plugin_tag"`
  177. }
  178. var toolProviderIdentityNameRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
  179. func isToolProviderIdentityName(fl validator.FieldLevel) bool {
  180. value := fl.Field().String()
  181. return toolProviderIdentityNameRegex.MatchString(value)
  182. }
  183. func init() {
  184. validators.GlobalEntitiesValidator.RegisterValidation("tool_provider_identity_name", isToolProviderIdentityName)
  185. }
  186. type ToolProviderDeclaration struct {
  187. Identity ToolProviderIdentity `json:"identity" yaml:"identity" validate:"required"`
  188. CredentialsSchema []ProviderConfig `json:"credentials_schema" yaml:"credentials_schema" validate:"omitempty,dive"`
  189. Tools []ToolDeclaration `json:"tools" yaml:"tools" validate:"required,dive"`
  190. ToolFiles []string `json:"-" yaml:"-"`
  191. }
  192. func (t *ToolProviderDeclaration) MarshalJSON() ([]byte, error) {
  193. type alias ToolProviderDeclaration
  194. p := alias(*t)
  195. if p.CredentialsSchema == nil {
  196. p.CredentialsSchema = []ProviderConfig{}
  197. }
  198. if p.Tools == nil {
  199. p.Tools = []ToolDeclaration{}
  200. }
  201. return json.Marshal(p)
  202. }
  203. func (t *ToolProviderDeclaration) UnmarshalYAML(value *yaml.Node) error {
  204. type alias struct {
  205. Identity ToolProviderIdentity `yaml:"identity"`
  206. CredentialsSchema yaml.Node `yaml:"credentials_schema"`
  207. CredentialsForProvider yaml.Node `yaml:"credentials_for_provider"`
  208. Tools yaml.Node `yaml:"tools"`
  209. }
  210. var temp alias
  211. err := value.Decode(&temp)
  212. if err != nil {
  213. return err
  214. }
  215. // apply credentials_for_provider to credentials_schema if not exists
  216. if (temp.CredentialsSchema.Kind == yaml.ScalarNode && temp.CredentialsSchema.Value == "") ||
  217. len(temp.CredentialsSchema.Content) == 0 {
  218. temp.CredentialsSchema = temp.CredentialsForProvider
  219. }
  220. // apply identity
  221. t.Identity = temp.Identity
  222. // check if credentials_schema is a map
  223. if temp.CredentialsSchema.Kind != yaml.MappingNode {
  224. // not a map, convert it into array
  225. credentialsSchema := make([]ProviderConfig, 0)
  226. if err := temp.CredentialsSchema.Decode(&credentialsSchema); err != nil {
  227. return err
  228. }
  229. t.CredentialsSchema = credentialsSchema
  230. } else if temp.CredentialsSchema.Kind == yaml.MappingNode {
  231. credentialsSchema := make([]ProviderConfig, 0, len(temp.CredentialsSchema.Content)/2)
  232. currentKey := ""
  233. currentValue := &ProviderConfig{}
  234. for _, item := range temp.CredentialsSchema.Content {
  235. if item.Kind == yaml.ScalarNode {
  236. currentKey = item.Value
  237. } else if item.Kind == yaml.MappingNode {
  238. currentValue = &ProviderConfig{}
  239. if err := item.Decode(currentValue); err != nil {
  240. return err
  241. }
  242. currentValue.Name = currentKey
  243. credentialsSchema = append(credentialsSchema, *currentValue)
  244. }
  245. }
  246. t.CredentialsSchema = credentialsSchema
  247. }
  248. if t.ToolFiles == nil {
  249. t.ToolFiles = []string{}
  250. }
  251. // unmarshal tools
  252. if temp.Tools.Kind == yaml.SequenceNode {
  253. for _, item := range temp.Tools.Content {
  254. if item.Kind == yaml.ScalarNode {
  255. t.ToolFiles = append(t.ToolFiles, item.Value)
  256. } else if item.Kind == yaml.MappingNode {
  257. tool := ToolDeclaration{}
  258. if err := item.Decode(&tool); err != nil {
  259. return err
  260. }
  261. t.Tools = append(t.Tools, tool)
  262. }
  263. }
  264. }
  265. if t.CredentialsSchema == nil {
  266. t.CredentialsSchema = []ProviderConfig{}
  267. }
  268. if t.Tools == nil {
  269. t.Tools = []ToolDeclaration{}
  270. }
  271. if t.Identity.Tags == nil {
  272. t.Identity.Tags = []manifest_entities.PluginTag{}
  273. }
  274. return nil
  275. }
  276. func (t *ToolProviderDeclaration) UnmarshalJSON(data []byte) error {
  277. type alias ToolProviderDeclaration
  278. var temp struct {
  279. alias
  280. CredentialsSchema json.RawMessage `json:"credentials_schema"`
  281. CredentialsForProvider json.RawMessage `json:"credentials_for_provider"`
  282. Tools []json.RawMessage `json:"tools"`
  283. }
  284. if err := json.Unmarshal(data, &temp); err != nil {
  285. return err
  286. }
  287. if len(temp.CredentialsSchema) == 0 {
  288. temp.CredentialsSchema = temp.CredentialsForProvider
  289. }
  290. *t = ToolProviderDeclaration(temp.alias)
  291. // Determine the type of CredentialsSchema
  292. var raw_message map[string]json.RawMessage
  293. if err := json.Unmarshal(temp.CredentialsSchema, &raw_message); err == nil {
  294. // It's an object
  295. credentialsSchemaObject := make(map[string]ProviderConfig)
  296. if err := json.Unmarshal(temp.CredentialsSchema, &credentialsSchemaObject); err != nil {
  297. return fmt.Errorf("failed to unmarshal credentials_schema as object: %v", err)
  298. }
  299. for _, value := range credentialsSchemaObject {
  300. t.CredentialsSchema = append(t.CredentialsSchema, value)
  301. }
  302. } else {
  303. // It's likely an array
  304. var credentials_schema_array []ProviderConfig
  305. if err := json.Unmarshal(temp.CredentialsSchema, &credentials_schema_array); err != nil {
  306. return fmt.Errorf("failed to unmarshal credentials_schema as array: %v", err)
  307. }
  308. t.CredentialsSchema = credentials_schema_array
  309. }
  310. if t.ToolFiles == nil {
  311. t.ToolFiles = []string{}
  312. }
  313. // unmarshal tools
  314. for _, item := range temp.Tools {
  315. tool := ToolDeclaration{}
  316. if err := json.Unmarshal(item, &tool); err != nil {
  317. // try to unmarshal it as a string directly
  318. t.ToolFiles = append(t.ToolFiles, string(item))
  319. } else {
  320. t.Tools = append(t.Tools, tool)
  321. }
  322. }
  323. if t.CredentialsSchema == nil {
  324. t.CredentialsSchema = []ProviderConfig{}
  325. }
  326. if t.Tools == nil {
  327. t.Tools = []ToolDeclaration{}
  328. }
  329. if t.Identity.Tags == nil {
  330. t.Identity.Tags = []manifest_entities.PluginTag{}
  331. }
  332. return nil
  333. }
  334. func init() {
  335. // init validator
  336. en := en.New()
  337. uni := ut.New(en, en)
  338. translator, _ := uni.GetTranslator("en")
  339. // register translations for default validators
  340. en_translations.RegisterDefaultTranslations(validators.GlobalEntitiesValidator, translator)
  341. validators.GlobalEntitiesValidator.RegisterValidation("tool_parameter_type", isToolParameterType)
  342. validators.GlobalEntitiesValidator.RegisterTranslation(
  343. "tool_parameter_type",
  344. translator,
  345. func(ut ut.Translator) error {
  346. return ut.Add("tool_parameter_type", "{0} is not a valid tool parameter type", true)
  347. },
  348. func(ut ut.Translator, fe validator.FieldError) string {
  349. t, _ := ut.T("tool_parameter_type", fe.Field())
  350. return t
  351. },
  352. )
  353. validators.GlobalEntitiesValidator.RegisterValidation("tool_parameter_form", isToolParameterForm)
  354. validators.GlobalEntitiesValidator.RegisterTranslation(
  355. "tool_parameter_form",
  356. translator,
  357. func(ut ut.Translator) error {
  358. return ut.Add("tool_parameter_form", "{0} is not a valid tool parameter form", true)
  359. },
  360. func(ut ut.Translator, fe validator.FieldError) string {
  361. t, _ := ut.T("tool_parameter_form", fe.Field())
  362. return t
  363. },
  364. )
  365. validators.GlobalEntitiesValidator.RegisterValidation("is_basic_type", isBasicType)
  366. }
  367. func UnmarshalToolProviderDeclaration(data []byte) (*ToolProviderDeclaration, error) {
  368. obj, err := parser.UnmarshalJsonBytes[ToolProviderDeclaration](data)
  369. if err != nil {
  370. return nil, fmt.Errorf("failed to unmarshal tool provider configuration: %w", err)
  371. }
  372. if err := validators.GlobalEntitiesValidator.Struct(obj); err != nil {
  373. return nil, fmt.Errorf("failed to validate tool provider configuration: %w", err)
  374. }
  375. return &obj, nil
  376. }