tool_service.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  10. "github.com/xeipuuv/gojsonschema"
  11. )
  12. func InvokeTool(
  13. session *session_manager.Session,
  14. request *requests.RequestInvokeTool,
  15. ) (
  16. *stream.Stream[tool_entities.ToolResponseChunk], error,
  17. ) {
  18. runtime := plugin_manager.Manager().Get(session.PluginUniqueIdentifier)
  19. if runtime == nil {
  20. return nil, errors.New("plugin not found")
  21. }
  22. response, err := genericInvokePlugin[
  23. requests.RequestInvokeTool, tool_entities.ToolResponseChunk,
  24. ](
  25. session,
  26. request,
  27. 128,
  28. )
  29. if err != nil {
  30. return nil, err
  31. }
  32. tool_declaration := runtime.Configuration().Tool
  33. if tool_declaration == nil {
  34. return nil, errors.New("tool declaration not found")
  35. }
  36. var tool_output_schema plugin_entities.ToolOutputSchema
  37. for _, v := range tool_declaration.Tools {
  38. if v.Identity.Name == request.Tool {
  39. tool_output_schema = v.OutputSchema
  40. }
  41. }
  42. // check if the tool_output_schema is valid
  43. variables := make(map[string]any)
  44. response.Filter(func(trc tool_entities.ToolResponseChunk) error {
  45. if trc.Type == tool_entities.ToolResponseChunkTypeVariable {
  46. variable_name, ok := trc.Message["variable_name"].(string)
  47. if !ok {
  48. return errors.New("variable name is not a string")
  49. }
  50. stream, ok := trc.Message["stream"].(bool)
  51. if !ok {
  52. return errors.New("stream is not a boolean")
  53. }
  54. if stream {
  55. // ensure variable_value is a string
  56. variable_value, ok := trc.Message["variable_value"].(string)
  57. if !ok {
  58. return errors.New("variable value is not a string")
  59. }
  60. // create it if not exists
  61. if _, ok := variables[variable_name]; !ok {
  62. variables[variable_name] = ""
  63. }
  64. original_value, ok := variables[variable_name].(string)
  65. if !ok {
  66. return errors.New("variable value is not a string")
  67. }
  68. // add the variable value to the variable
  69. variables[variable_name] = original_value + variable_value
  70. } else {
  71. variables[variable_name] = trc.Message["variable_value"]
  72. }
  73. }
  74. return nil
  75. })
  76. response.BeforeClose(func() {
  77. // validate the variables
  78. schema, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(tool_output_schema))
  79. if err != nil {
  80. response.WriteError(err)
  81. return
  82. }
  83. // validate the variables
  84. result, err := schema.Validate(gojsonschema.NewGoLoader(variables))
  85. if err != nil {
  86. response.WriteError(err)
  87. return
  88. }
  89. if !result.Valid() {
  90. response.WriteError(errors.New("tool output schema is not valid"))
  91. return
  92. }
  93. })
  94. return response, nil
  95. }
  96. func ValidateToolCredentials(
  97. session *session_manager.Session,
  98. request *requests.RequestValidateToolCredentials,
  99. ) (
  100. *stream.Stream[tool_entities.ValidateCredentialsResult], error,
  101. ) {
  102. return genericInvokePlugin[requests.RequestValidateToolCredentials, tool_entities.ValidateCredentialsResult](
  103. session,
  104. request,
  105. 1,
  106. )
  107. }