tool_service.go 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  9. "github.com/xeipuuv/gojsonschema"
  10. )
  11. func InvokeTool(
  12. session *session_manager.Session,
  13. request *requests.RequestInvokeTool,
  14. ) (
  15. *stream.Stream[tool_entities.ToolResponseChunk], error,
  16. ) {
  17. runtime := session.Runtime()
  18. if runtime == nil {
  19. return nil, errors.New("plugin not found")
  20. }
  21. response, err := GenericInvokePlugin[
  22. requests.RequestInvokeTool, tool_entities.ToolResponseChunk,
  23. ](
  24. session,
  25. request,
  26. 128,
  27. )
  28. if err != nil {
  29. return nil, err
  30. }
  31. tool_declaration := runtime.Configuration().Tool
  32. if tool_declaration == nil {
  33. return nil, errors.New("tool declaration not found")
  34. }
  35. var tool_output_schema plugin_entities.ToolOutputSchema
  36. for _, v := range tool_declaration.Tools {
  37. if v.Identity.Name == request.Tool {
  38. tool_output_schema = v.OutputSchema
  39. }
  40. }
  41. // bind json schema validator
  42. bindValidator(response, tool_output_schema)
  43. return response, nil
  44. }
  45. func bindValidator(
  46. response *stream.Stream[tool_entities.ToolResponseChunk],
  47. tool_output_schema plugin_entities.ToolOutputSchema,
  48. ) {
  49. // check if the tool_output_schema is valid
  50. variables := make(map[string]any)
  51. response.Filter(func(trc tool_entities.ToolResponseChunk) error {
  52. if trc.Type == tool_entities.ToolResponseChunkTypeVariable {
  53. variable_name, ok := trc.Message["variable_name"].(string)
  54. if !ok {
  55. return errors.New("variable name is not a string")
  56. }
  57. stream, ok := trc.Message["stream"].(bool)
  58. if !ok {
  59. return errors.New("stream is not a boolean")
  60. }
  61. if stream {
  62. // ensure variable_value is a string
  63. variable_value, ok := trc.Message["variable_value"].(string)
  64. if !ok {
  65. return errors.New("variable value is not a string")
  66. }
  67. // create it if not exists
  68. if _, ok := variables[variable_name]; !ok {
  69. variables[variable_name] = ""
  70. }
  71. original_value, ok := variables[variable_name].(string)
  72. if !ok {
  73. return errors.New("variable value is not a string")
  74. }
  75. // add the variable value to the variable
  76. variables[variable_name] = original_value + variable_value
  77. } else {
  78. variables[variable_name] = trc.Message["variable_value"]
  79. }
  80. }
  81. return nil
  82. })
  83. response.BeforeClose(func() {
  84. // validate the variables
  85. schema, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(tool_output_schema))
  86. if err != nil {
  87. response.WriteError(err)
  88. return
  89. }
  90. // validate the variables
  91. result, err := schema.Validate(gojsonschema.NewGoLoader(variables))
  92. if err != nil {
  93. response.WriteError(err)
  94. return
  95. }
  96. if !result.Valid() {
  97. response.WriteError(errors.New("tool output schema is not valid"))
  98. return
  99. }
  100. })
  101. }
  102. func ValidateToolCredentials(
  103. session *session_manager.Session,
  104. request *requests.RequestValidateToolCredentials,
  105. ) (
  106. *stream.Stream[tool_entities.ValidateCredentialsResult], error,
  107. ) {
  108. return GenericInvokePlugin[requests.RequestValidateToolCredentials, tool_entities.ValidateCredentialsResult](
  109. session,
  110. request,
  111. 1,
  112. )
  113. }
  114. func GetToolRuntimeParameters(
  115. session *session_manager.Session,
  116. request *requests.RequestGetToolRuntimeParameters,
  117. ) (
  118. *stream.Stream[tool_entities.GetToolRuntimeParametersResponse], error,
  119. ) {
  120. return GenericInvokePlugin[requests.RequestGetToolRuntimeParameters, tool_entities.GetToolRuntimeParametersResponse](
  121. session,
  122. request,
  123. 1,
  124. )
  125. }