generic.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  8. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  11. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  12. )
  13. func GenericInvokePlugin[Req any, Rsp any](
  14. session *session_manager.Session,
  15. request *Req,
  16. response_buffer_size int,
  17. ) (*stream.Stream[Rsp], error) {
  18. runtime := session.Runtime()
  19. if runtime == nil {
  20. return nil, errors.New("plugin runtime not found")
  21. }
  22. response := stream.NewStream[Rsp](response_buffer_size)
  23. listener := runtime.Listen(session.ID)
  24. listener.Listen(func(chunk plugin_entities.SessionMessage) {
  25. switch chunk.Type {
  26. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  27. chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
  28. if err != nil {
  29. log.Error("unmarshal json failed: %s", err.Error())
  30. response.WriteError(errors.New(parser.MarshalJson(map[string]string{
  31. "error_type": "unmarshal_error",
  32. "message": fmt.Sprintf("unmarshal json failed: %s", err.Error()),
  33. })))
  34. response.Close()
  35. return
  36. } else {
  37. response.Write(chunk)
  38. }
  39. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  40. // check if the request contains a aws_event_id
  41. if runtime.Type() == plugin_entities.PLUGIN_RUNTIME_TYPE_AWS {
  42. response.WriteError(errors.New(parser.MarshalJson(map[string]string{
  43. "error_type": "aws_event_not_supported",
  44. "message": "aws event is not supported by full duplex",
  45. })))
  46. response.Close()
  47. return
  48. }
  49. if err := backwards_invocation.InvokeDify(
  50. runtime.Configuration(),
  51. session.InvokeFrom,
  52. session,
  53. transaction.NewFullDuplexEventWriter(session),
  54. chunk.Data,
  55. ); err != nil {
  56. response.WriteError(errors.New(parser.MarshalJson(map[string]string{
  57. "error_type": "invoke_dify_error",
  58. "message": fmt.Sprintf("invoke dify failed: %s", err.Error()),
  59. })))
  60. response.Close()
  61. return
  62. }
  63. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  64. response.Close()
  65. case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
  66. e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
  67. if err != nil {
  68. break
  69. }
  70. response.WriteError(errors.New(e.Error()))
  71. response.Close()
  72. default:
  73. response.WriteError(errors.New(parser.MarshalJson(map[string]string{
  74. "error_type": "unknown_stream_message_type",
  75. "message": "unknown stream message type: " + string(chunk.Type),
  76. })))
  77. response.Close()
  78. }
  79. })
  80. response.OnClose(func() {
  81. listener.Close()
  82. })
  83. session.Write(
  84. session_manager.PLUGIN_IN_STREAM_EVENT_REQUEST,
  85. getInvokePluginMap(
  86. session,
  87. request,
  88. ),
  89. )
  90. return response, nil
  91. }
  92. func getInvokePluginMap(
  93. session *session_manager.Session,
  94. request any,
  95. ) map[string]any {
  96. req := getBasicPluginAccessMap(
  97. session.UserID,
  98. session.InvokeFrom,
  99. session.Action,
  100. )
  101. for k, v := range parser.StructToMap(request) {
  102. req[k] = v
  103. }
  104. return req
  105. }