tool_service.go 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/requests"
  8. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  11. )
  12. func getInvokeToolMap(
  13. session *session_manager.Session,
  14. action PluginAccessAction,
  15. request *requests.RequestInvokeTool,
  16. ) map[string]any {
  17. req := getBasicPluginAccessMap(session.ID(), session.UserID(), PLUGIN_ACCESS_TYPE_TOOL, action)
  18. data := req["data"].(map[string]any)
  19. data["provider"] = request.Provider
  20. data["tool"] = request.Tool
  21. data["parameters"] = request.ToolParameters
  22. data["credentials"] = request.Credentials
  23. return req
  24. }
  25. func InvokeTool(
  26. session *session_manager.Session,
  27. request *requests.RequestInvokeTool,
  28. ) (
  29. *stream.StreamResponse[plugin_entities.ToolResponseChunk], error,
  30. ) {
  31. runtime := plugin_manager.Get(session.PluginIdentity())
  32. if runtime == nil {
  33. return nil, errors.New("plugin not found")
  34. }
  35. response := stream.NewStreamResponse[plugin_entities.ToolResponseChunk](512)
  36. listener := runtime.Listen(session.ID())
  37. listener.AddListener(func(message []byte) {
  38. chunk, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](message)
  39. if err != nil {
  40. log.Error("unmarshal json failed: %s", err.Error())
  41. return
  42. }
  43. switch chunk.Type {
  44. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  45. chunk, err := parser.UnmarshalJsonBytes[plugin_entities.ToolResponseChunk](chunk.Data)
  46. if err != nil {
  47. log.Error("unmarshal json failed: %s", err.Error())
  48. return
  49. }
  50. response.Write(chunk)
  51. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  52. invokeDify(runtime, session, chunk.Data)
  53. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  54. response.Close()
  55. case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
  56. e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
  57. if err != nil {
  58. break
  59. }
  60. response.WriteError(errors.New(e.Error))
  61. response.Close()
  62. default:
  63. response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
  64. response.Close()
  65. }
  66. })
  67. response.OnClose(func() {
  68. listener.Close()
  69. })
  70. runtime.Write(session.ID(), []byte(parser.MarshalJson(
  71. getInvokeToolMap(session, PLUGIN_ACCESS_ACTION_INVOKE_TOOL, request)),
  72. ))
  73. return response, nil
  74. }