tool_service.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. package plugin_daemon
  2. import (
  3. "bytes"
  4. "encoding/base64"
  5. "errors"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  9. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/tool_entities"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  12. "github.com/xeipuuv/gojsonschema"
  13. )
  14. func InvokeTool(
  15. session *session_manager.Session,
  16. request *requests.RequestInvokeTool,
  17. ) (
  18. *stream.Stream[tool_entities.ToolResponseChunk], error,
  19. ) {
  20. runtime := session.Runtime()
  21. if runtime == nil {
  22. return nil, errors.New("plugin not found")
  23. }
  24. response, err := GenericInvokePlugin[
  25. requests.RequestInvokeTool, tool_entities.ToolResponseChunk,
  26. ](
  27. session,
  28. request,
  29. 128,
  30. )
  31. if err != nil {
  32. return nil, err
  33. }
  34. tool_declaration := runtime.Configuration().Tool
  35. if tool_declaration == nil {
  36. return nil, errors.New("tool declaration not found")
  37. }
  38. var tool_output_schema plugin_entities.ToolOutputSchema
  39. for _, v := range tool_declaration.Tools {
  40. if v.Identity.Name == request.Tool {
  41. tool_output_schema = v.OutputSchema
  42. }
  43. }
  44. new_response := stream.NewStream[tool_entities.ToolResponseChunk](128)
  45. routine.Submit(func() {
  46. files := make(map[string]*bytes.Buffer)
  47. defer new_response.Close()
  48. for response.Next() {
  49. item, err := response.Read()
  50. if err != nil {
  51. return
  52. }
  53. if item.Type == tool_entities.ToolResponseChunkTypeBlobChunk {
  54. id, ok := item.Message["id"].(string)
  55. if !ok {
  56. continue
  57. }
  58. total_length, ok := item.Message["total_length"].(float64)
  59. if !ok {
  60. continue
  61. }
  62. // convert total_length to int
  63. total_length_int := int(total_length)
  64. blob, ok := item.Message["blob"].(string)
  65. if !ok {
  66. continue
  67. }
  68. end, ok := item.Message["end"].(bool)
  69. if !ok {
  70. continue
  71. }
  72. if _, ok := files[id]; !ok {
  73. files[id] = bytes.NewBuffer(make([]byte, 0, total_length_int))
  74. }
  75. if end {
  76. new_response.Write(tool_entities.ToolResponseChunk{
  77. Type: tool_entities.ToolResponseChunkTypeBlob,
  78. Message: map[string]any{
  79. "blob": files[id].Bytes(), // bytes will be encoded to base64 finally
  80. },
  81. Meta: item.Meta,
  82. })
  83. } else {
  84. if files[id].Len() > 15*1024*1024 {
  85. // delete the file if it is too large
  86. delete(files, id)
  87. new_response.WriteError(errors.New("file is too large"))
  88. return
  89. } else {
  90. // decode the blob using base64
  91. decoded, err := base64.StdEncoding.DecodeString(blob)
  92. if err != nil {
  93. new_response.WriteError(err)
  94. return
  95. }
  96. if len(decoded) > 8192 {
  97. // single chunk is too large, raises error
  98. new_response.WriteError(errors.New("single file chunk is too large"))
  99. return
  100. }
  101. files[id].Write(decoded)
  102. }
  103. }
  104. } else {
  105. new_response.Write(item)
  106. }
  107. }
  108. })
  109. // bind json schema validator
  110. bindValidator(response, tool_output_schema)
  111. return new_response, nil
  112. }
  113. func bindValidator(
  114. response *stream.Stream[tool_entities.ToolResponseChunk],
  115. tool_output_schema plugin_entities.ToolOutputSchema,
  116. ) {
  117. // check if the tool_output_schema is valid
  118. variables := make(map[string]any)
  119. response.Filter(func(trc tool_entities.ToolResponseChunk) error {
  120. if trc.Type == tool_entities.ToolResponseChunkTypeVariable {
  121. variable_name, ok := trc.Message["variable_name"].(string)
  122. if !ok {
  123. return errors.New("variable name is not a string")
  124. }
  125. stream, ok := trc.Message["stream"].(bool)
  126. if !ok {
  127. return errors.New("stream is not a boolean")
  128. }
  129. if stream {
  130. // ensure variable_value is a string
  131. variable_value, ok := trc.Message["variable_value"].(string)
  132. if !ok {
  133. return errors.New("variable value is not a string")
  134. }
  135. // create it if not exists
  136. if _, ok := variables[variable_name]; !ok {
  137. variables[variable_name] = ""
  138. }
  139. original_value, ok := variables[variable_name].(string)
  140. if !ok {
  141. return errors.New("variable value is not a string")
  142. }
  143. // add the variable value to the variable
  144. variables[variable_name] = original_value + variable_value
  145. } else {
  146. variables[variable_name] = trc.Message["variable_value"]
  147. }
  148. }
  149. return nil
  150. })
  151. response.BeforeClose(func() {
  152. // validate the variables
  153. schema, err := gojsonschema.NewSchema(gojsonschema.NewGoLoader(tool_output_schema))
  154. if err != nil {
  155. response.WriteError(err)
  156. return
  157. }
  158. // validate the variables
  159. result, err := schema.Validate(gojsonschema.NewGoLoader(variables))
  160. if err != nil {
  161. response.WriteError(err)
  162. return
  163. }
  164. if !result.Valid() {
  165. response.WriteError(errors.New("tool output schema is not valid"))
  166. return
  167. }
  168. })
  169. }
  170. func ValidateToolCredentials(
  171. session *session_manager.Session,
  172. request *requests.RequestValidateToolCredentials,
  173. ) (
  174. *stream.Stream[tool_entities.ValidateCredentialsResult], error,
  175. ) {
  176. return GenericInvokePlugin[requests.RequestValidateToolCredentials, tool_entities.ValidateCredentialsResult](
  177. session,
  178. request,
  179. 1,
  180. )
  181. }
  182. func GetToolRuntimeParameters(
  183. session *session_manager.Session,
  184. request *requests.RequestGetToolRuntimeParameters,
  185. ) (
  186. *stream.Stream[tool_entities.GetToolRuntimeParametersResponse], error,
  187. ) {
  188. return GenericInvokePlugin[requests.RequestGetToolRuntimeParameters, tool_entities.GetToolRuntimeParametersResponse](
  189. session,
  190. request,
  191. 1,
  192. )
  193. }