tool_service.go 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. package plugin_daemon
import (
	"bytes"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"

	"github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
	"github.com/xeipuuv/gojsonschema"
)
  14. func InvokeTool(
  15. session *session_manager.Session,
  16. request *requests.RequestInvokeTool,
  17. ) (
  18. *stream.Stream[tool_entities.ToolResponseChunk], error,
  19. ) {
  20. runtime := session.Runtime()
  21. if runtime == nil {
  22. return nil, errors.New("plugin not found")
  23. }
  24. response, err := GenericInvokePlugin[
  25. requests.RequestInvokeTool, tool_entities.ToolResponseChunk,
  26. ](
  27. session,
  28. request,
  29. 128,
  30. )
  31. if err != nil {
  32. return nil, err
  33. }
  34. toolDeclaration := runtime.Configuration().Tool
  35. if toolDeclaration == nil {
  36. return nil, errors.New("tool declaration not found")
  37. }
  38. var toolOutputSchema plugin_entities.ToolOutputSchema
  39. for _, v := range toolDeclaration.Tools {
  40. if v.Identity.Name == request.Tool {
  41. toolOutputSchema = v.OutputSchema
  42. }
  43. }
  44. newResponse := stream.NewStream[tool_entities.ToolResponseChunk](128)
  45. routine.Submit(map[string]string{
  46. "module": "plugin_daemon",
  47. "function": "InvokeTool",
  48. "tool_name": request.Tool,
  49. "tool_provider": request.Provider,
  50. }, func() {
  51. files := make(map[string]*bytes.Buffer)
  52. defer newResponse.Close()
  53. for response.Next() {
  54. item, err := response.Read()
  55. if err != nil {
  56. newResponse.WriteError(err)
  57. return
  58. }
  59. if item.Type == tool_entities.ToolResponseChunkTypeBlobChunk {
  60. id, ok := item.Message["id"].(string)
  61. if !ok {
  62. continue
  63. }
  64. totalLength, ok := item.Message["total_length"].(float64)
  65. if !ok {
  66. continue
  67. }
  68. // convert total_length to int
  69. totalLengthInt := int(totalLength)
  70. blob, ok := item.Message["blob"].(string)
  71. if !ok {
  72. continue
  73. }
  74. end, ok := item.Message["end"].(bool)
  75. if !ok {
  76. continue
  77. }
  78. if _, ok := files[id]; !ok {
  79. files[id] = bytes.NewBuffer(make([]byte, 0, totalLengthInt))
  80. }
  81. if end {
  82. newResponse.Write(tool_entities.ToolResponseChunk{
  83. Type: tool_entities.ToolResponseChunkTypeBlob,
  84. Message: map[string]any{
  85. "blob": files[id].Bytes(), // bytes will be encoded to base64 finally
  86. },
  87. Meta: item.Meta,
  88. })
  89. } else {
  90. if files[id].Len() > 15*1024*1024 {
  91. // delete the file if it is too large
  92. delete(files, id)
  93. newResponse.WriteError(errors.New("file is too large"))
  94. return
  95. } else {
  96. // decode the blob using base64
  97. decoded, err := base64.StdEncoding.DecodeString(blob)
  98. if err != nil {
  99. newResponse.WriteError(err)
  100. return
  101. }
  102. if len(decoded) > 8192 {
  103. // single chunk is too large, raises error
  104. newResponse.WriteError(errors.New("single file chunk is too large"))
  105. return
  106. }
  107. files[id].Write(decoded)
  108. }
  109. }
  110. } else {
  111. newResponse.Write(item)
  112. }
  113. }
  114. })
  115. // bind json schema validator
  116. bindValidator(response, toolOutputSchema)
  117. return newResponse, nil
  118. }
  119. func bindValidator(
  120. response *stream.Stream[tool_entities.ToolResponseChunk],
  121. toolOutputSchema plugin_entities.ToolOutputSchema,
  122. ) {
  123. // check if the tool_output_schema is valid
  124. variables := make(map[string]any)
  125. response.Filter(func(trc tool_entities.ToolResponseChunk) error {
  126. if trc.Type == tool_entities.ToolResponseChunkTypeVariable {
  127. variableName, ok := trc.Message["variable_name"].(string)
  128. if !ok {
  129. return errors.New("variable name is not a string")
  130. }
  131. stream, ok := trc.Message["stream"].(bool)
  132. if !ok {
  133. return errors.New("stream is not a boolean")
  134. }
  135. if stream {
  136. // ensure variable_value is a string
  137. variableValue, ok := trc.Message["variable_value"].(string)
  138. if !ok {
  139. return errors.New("variable value is not a string")
  140. }
  141. // create it if not exists
  142. if _, ok := variables[variableName]; !ok {
  143. variables[variableName] = ""
  144. }
  145. originalValue, ok := variables[variableName].(string)
  146. if !ok {
  147. return errors.New("variable value is not a string")
  148. }
  149. // add the variable value to the variable
  150. variables[variableName] = originalValue + variableValue
  151. } else {
  152. variables[variableName] = trc.Message["variable_value"]
  153. }
  154. }
  155. return nil
  156. })
  157. response.BeforeClose(func() {
  158. // validate the variables
  159. schema, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(toolOutputSchema))
  160. if err != nil {
  161. response.WriteError(err)
  162. return
  163. }
  164. // validate the variables
  165. result, err := schema.Validate(gojsonschema.NewGoLoader(variables))
  166. if err != nil {
  167. response.WriteError(err)
  168. return
  169. }
  170. if !result.Valid() {
  171. response.WriteError(errors.New("tool output schema is not valid"))
  172. return
  173. }
  174. })
  175. }
  176. func ValidateToolCredentials(
  177. session *session_manager.Session,
  178. request *requests.RequestValidateToolCredentials,
  179. ) (
  180. *stream.Stream[tool_entities.ValidateCredentialsResult], error,
  181. ) {
  182. return GenericInvokePlugin[requests.RequestValidateToolCredentials, tool_entities.ValidateCredentialsResult](
  183. session,
  184. request,
  185. 1,
  186. )
  187. }
  188. func GetToolRuntimeParameters(
  189. session *session_manager.Session,
  190. request *requests.RequestGetToolRuntimeParameters,
  191. ) (
  192. *stream.Stream[tool_entities.GetToolRuntimeParametersResponse], error,
  193. ) {
  194. return GenericInvokePlugin[requests.RequestGetToolRuntimeParameters, tool_entities.GetToolRuntimeParametersResponse](
  195. session,
  196. request,
  197. 1,
  198. )
  199. }