tool_service.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. package plugin_daemon
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "errors"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  12. "github.com/xeipuuv/gojsonschema"
  13. )
  14. func InvokeTool(
  15. session *session_manager.Session,
  16. request *requests.RequestInvokeTool,
  17. ) (
  18. *stream.Stream[tool_entities.ToolResponseChunk], error,
  19. ) {
  20. runtime := session.Runtime()
  21. if runtime == nil {
  22. return nil, errors.New("plugin not found")
  23. }
  24. response, err := GenericInvokePlugin[
  25. requests.RequestInvokeTool, tool_entities.ToolResponseChunk,
  26. ](
  27. session,
  28. request,
  29. 128,
  30. )
  31. if err != nil {
  32. return nil, err
  33. }
  34. toolDeclaration := runtime.Configuration().Tool
  35. if toolDeclaration == nil {
  36. return nil, errors.New("tool declaration not found")
  37. }
  38. var toolOutputSchema plugin_entities.ToolOutputSchema
  39. for _, v := range toolDeclaration.Tools {
  40. if v.Identity.Name == request.Tool {
  41. toolOutputSchema = v.OutputSchema
  42. }
  43. }
  44. newResponse := stream.NewStream[tool_entities.ToolResponseChunk](128)
  45. routine.Submit(func() {
  46. files := make(map[string]*bytes.Buffer)
  47. defer newResponse.Close()
  48. for response.Next() {
  49. item, err := response.Read()
  50. if err != nil {
  51. newResponse.WriteError(err)
  52. return
  53. }
  54. if item.Type == tool_entities.ToolResponseChunkTypeBlobChunk {
  55. id, ok := item.Message["id"].(string)
  56. if !ok {
  57. continue
  58. }
  59. totalLength, ok := item.Message["total_length"].(float64)
  60. if !ok {
  61. continue
  62. }
  63. // convert total_length to int
  64. totalLengthInt := int(totalLength)
  65. blob, ok := item.Message["blob"].(string)
  66. if !ok {
  67. continue
  68. }
  69. end, ok := item.Message["end"].(bool)
  70. if !ok {
  71. continue
  72. }
  73. if _, ok := files[id]; !ok {
  74. files[id] = bytes.NewBuffer(make([]byte, 0, totalLengthInt))
  75. }
  76. if end {
  77. newResponse.Write(tool_entities.ToolResponseChunk{
  78. Type: tool_entities.ToolResponseChunkTypeBlob,
  79. Message: map[string]any{
  80. "blob": files[id].Bytes(), // bytes will be encoded to base64 finally
  81. },
  82. Meta: item.Meta,
  83. })
  84. } else {
  85. if files[id].Len() > 15*1024*1024 {
  86. // delete the file if it is too large
  87. delete(files, id)
  88. newResponse.WriteError(errors.New("file is too large"))
  89. return
  90. } else {
  91. // decode the blob using base64
  92. decoded, err := base64.StdEncoding.DecodeString(blob)
  93. if err != nil {
  94. newResponse.WriteError(err)
  95. return
  96. }
  97. if len(decoded) > 8192 {
  98. // single chunk is too large, raises error
  99. newResponse.WriteError(errors.New("single file chunk is too large"))
  100. return
  101. }
  102. files[id].Write(decoded)
  103. }
  104. }
  105. } else {
  106. newResponse.Write(item)
  107. }
  108. }
  109. })
  110. // bind json schema validator
  111. bindValidator(response, toolOutputSchema)
  112. return newResponse, nil
  113. }
  114. func bindValidator(
  115. response *stream.Stream[tool_entities.ToolResponseChunk],
  116. toolOutputSchema plugin_entities.ToolOutputSchema,
  117. ) {
  118. // check if the tool_output_schema is valid
  119. variables := make(map[string]any)
  120. response.Filter(func(trc tool_entities.ToolResponseChunk) error {
  121. if trc.Type == tool_entities.ToolResponseChunkTypeVariable {
  122. variableName, ok := trc.Message["variable_name"].(string)
  123. if !ok {
  124. return errors.New("variable name is not a string")
  125. }
  126. stream, ok := trc.Message["stream"].(bool)
  127. if !ok {
  128. return errors.New("stream is not a boolean")
  129. }
  130. if stream {
  131. // ensure variable_value is a string
  132. variableValue, ok := trc.Message["variable_value"].(string)
  133. if !ok {
  134. return errors.New("variable value is not a string")
  135. }
  136. // create it if not exists
  137. if _, ok := variables[variableName]; !ok {
  138. variables[variableName] = ""
  139. }
  140. originalValue, ok := variables[variableName].(string)
  141. if !ok {
  142. return errors.New("variable value is not a string")
  143. }
  144. // add the variable value to the variable
  145. variables[variableName] = originalValue + variableValue
  146. } else {
  147. variables[variableName] = trc.Message["variable_value"]
  148. }
  149. }
  150. return nil
  151. })
  152. response.BeforeClose(func() {
  153. // validate the variables
  154. schema, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(toolOutputSchema))
  155. if err != nil {
  156. response.WriteError(err)
  157. return
  158. }
  159. // validate the variables
  160. result, err := schema.Validate(gojsonschema.NewGoLoader(variables))
  161. if err != nil {
  162. response.WriteError(err)
  163. return
  164. }
  165. if !result.Valid() {
  166. response.WriteError(errors.New("tool output schema is not valid"))
  167. return
  168. }
  169. })
  170. }
  171. func ValidateToolCredentials(
  172. session *session_manager.Session,
  173. request *requests.RequestValidateToolCredentials,
  174. ) (
  175. *stream.Stream[tool_entities.ValidateCredentialsResult], error,
  176. ) {
  177. return GenericInvokePlugin[requests.RequestValidateToolCredentials, tool_entities.ValidateCredentialsResult](
  178. session,
  179. request,
  180. 1,
  181. )
  182. }
  183. func GetToolRuntimeParameters(
  184. session *session_manager.Session,
  185. request *requests.RequestGetToolRuntimeParameters,
  186. ) (
  187. *stream.Stream[tool_entities.GetToolRuntimeParametersResponse], error,
  188. ) {
  189. return GenericInvokePlugin[requests.RequestGetToolRuntimeParameters, tool_entities.GetToolRuntimeParametersResponse](
  190. session,
  191. request,
  192. 1,
  193. )
  194. }