tool_service.go 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  10. "github.com/xeipuuv/gojsonschema"
  11. )
  12. func InvokeTool(
  13. session *session_manager.Session,
  14. request *requests.RequestInvokeTool,
  15. ) (
  16. *stream.Stream[tool_entities.ToolResponseChunk], error,
  17. ) {
  18. runtime := plugin_manager.Manager().Get(session.PluginUniqueIdentifier)
  19. if runtime == nil {
  20. return nil, errors.New("plugin not found")
  21. }
  22. response, err := genericInvokePlugin[
  23. requests.RequestInvokeTool, tool_entities.ToolResponseChunk,
  24. ](
  25. session,
  26. request,
  27. 128,
  28. )
  29. if err != nil {
  30. return nil, err
  31. }
  32. tool_declaration := runtime.Configuration().Tool
  33. if tool_declaration == nil {
  34. return nil, errors.New("tool declaration not found")
  35. }
  36. var tool_output_schema plugin_entities.ToolOutputSchema
  37. for _, v := range tool_declaration.Tools {
  38. if v.Identity.Name == request.Tool {
  39. tool_output_schema = v.OutputSchema
  40. }
  41. }
  42. // bind json schema validator
  43. bindValidator(response, tool_output_schema)
  44. return response, nil
  45. }
  46. func bindValidator(
  47. response *stream.Stream[tool_entities.ToolResponseChunk],
  48. tool_output_schema plugin_entities.ToolOutputSchema,
  49. ) {
  50. // check if the tool_output_schema is valid
  51. variables := make(map[string]any)
  52. response.Filter(func(trc tool_entities.ToolResponseChunk) error {
  53. if trc.Type == tool_entities.ToolResponseChunkTypeVariable {
  54. variable_name, ok := trc.Message["variable_name"].(string)
  55. if !ok {
  56. return errors.New("variable name is not a string")
  57. }
  58. stream, ok := trc.Message["stream"].(bool)
  59. if !ok {
  60. return errors.New("stream is not a boolean")
  61. }
  62. if stream {
  63. // ensure variable_value is a string
  64. variable_value, ok := trc.Message["variable_value"].(string)
  65. if !ok {
  66. return errors.New("variable value is not a string")
  67. }
  68. // create it if not exists
  69. if _, ok := variables[variable_name]; !ok {
  70. variables[variable_name] = ""
  71. }
  72. original_value, ok := variables[variable_name].(string)
  73. if !ok {
  74. return errors.New("variable value is not a string")
  75. }
  76. // add the variable value to the variable
  77. variables[variable_name] = original_value + variable_value
  78. } else {
  79. variables[variable_name] = trc.Message["variable_value"]
  80. }
  81. }
  82. return nil
  83. })
  84. response.BeforeClose(func() {
  85. // validate the variables
  86. schema, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(tool_output_schema))
  87. if err != nil {
  88. response.WriteError(err)
  89. return
  90. }
  91. // validate the variables
  92. result, err := schema.Validate(gojsonschema.NewGoLoader(variables))
  93. if err != nil {
  94. response.WriteError(err)
  95. return
  96. }
  97. if !result.Valid() {
  98. response.WriteError(errors.New("tool output schema is not valid"))
  99. return
  100. }
  101. })
  102. }
  103. func ValidateToolCredentials(
  104. session *session_manager.Session,
  105. request *requests.RequestValidateToolCredentials,
  106. ) (
  107. *stream.Stream[tool_entities.ValidateCredentialsResult], error,
  108. ) {
  109. return genericInvokePlugin[requests.RequestValidateToolCredentials, tool_entities.ValidateCredentialsResult](
  110. session,
  111. request,
  112. 1,
  113. )
  114. }