generic.go 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_manager"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  12. )
  13. func genericInvokePlugin[Req any, Rsp any](
  14. session *session_manager.Session,
  15. request *Req,
  16. response_buffer_size int,
  17. ) (*stream.Stream[Rsp], error) {
  18. runtime := plugin_manager.GetGlobalPluginManager().Get(session.PluginUniqueIdentifier)
  19. if runtime == nil {
  20. return nil, errors.New("plugin not found")
  21. }
  22. response := stream.NewStreamResponse[Rsp](response_buffer_size)
  23. listener := runtime.Listen(session.ID)
  24. listener.Listen(func(chunk plugin_entities.SessionMessage) {
  25. switch chunk.Type {
  26. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  27. chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
  28. if err != nil {
  29. log.Error("unmarshal json failed: %s", err.Error())
  30. response.WriteError(err)
  31. } else {
  32. response.Write(chunk)
  33. }
  34. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  35. // check if the request contains a aws_event_id
  36. if runtime.Type() == plugin_entities.PLUGIN_RUNTIME_TYPE_AWS {
  37. response.WriteError(errors.New("aws event is not supported by full duplex"))
  38. response.Close()
  39. return
  40. }
  41. if err := backwards_invocation.InvokeDify(
  42. runtime.Configuration(),
  43. session.InvokeFrom,
  44. session,
  45. transaction.NewFullDuplexEventWriter(session),
  46. chunk.Data,
  47. ); err != nil {
  48. log.Error("invoke dify failed: %s", err.Error())
  49. return
  50. }
  51. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  52. response.Close()
  53. case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
  54. e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
  55. if err != nil {
  56. break
  57. }
  58. response.WriteError(errors.New(e.Error))
  59. response.Close()
  60. default:
  61. response.WriteError(errors.New("unknown stream message type: " + string(chunk.Type)))
  62. response.Close()
  63. }
  64. })
  65. response.OnClose(func() {
  66. listener.Close()
  67. })
  68. session.Write(
  69. session_manager.PLUGIN_IN_STREAM_EVENT_REQUEST,
  70. getInvokePluginMap(
  71. session,
  72. request,
  73. ),
  74. )
  75. return response, nil
  76. }
  77. func getInvokePluginMap(
  78. session *session_manager.Session,
  79. request any,
  80. ) map[string]any {
  81. req := getBasicPluginAccessMap(session.UserID, session.InvokeFrom, session.Action)
  82. for k, v := range parser.StructToMap(request) {
  83. req[k] = v
  84. }
  85. return req
  86. }