generic.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. package plugin_daemon
  2. import (
  3. "errors"
  4. "fmt"
  5. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation"
  6. "github.com/langgenius/dify-plugin-daemon/internal/core/plugin_daemon/backwards_invocation/transaction"
  7. "github.com/langgenius/dify-plugin-daemon/internal/core/session_manager"
  8. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  10. "github.com/langgenius/dify-plugin-daemon/pkg/entities/plugin_entities"
  11. )
  12. func GenericInvokePlugin[Req any, Rsp any](
  13. session *session_manager.Session,
  14. request *Req,
  15. response_buffer_size int,
  16. ) (*stream.Stream[Rsp], error) {
  17. runtime := session.Runtime()
  18. if runtime == nil {
  19. return nil, errors.New("plugin runtime not found")
  20. }
  21. response := stream.NewStream[Rsp](response_buffer_size)
  22. listener := runtime.Listen(session.ID)
  23. listener.Listen(func(chunk plugin_entities.SessionMessage) {
  24. switch chunk.Type {
  25. case plugin_entities.SESSION_MESSAGE_TYPE_STREAM:
  26. chunk, err := parser.UnmarshalJsonBytes[Rsp](chunk.Data)
  27. if err != nil {
  28. response.WriteError(errors.New(parser.MarshalJson(map[string]string{
  29. "error_type": "unmarshal_error",
  30. "message": fmt.Sprintf("unmarshal json failed: %s", err.Error()),
  31. })))
  32. response.Close()
  33. return
  34. } else {
  35. response.Write(chunk)
  36. }
  37. case plugin_entities.SESSION_MESSAGE_TYPE_INVOKE:
  38. // check if the request contains a aws_event_id
  39. if runtime.Type() == plugin_entities.PLUGIN_RUNTIME_TYPE_SERVERLESS {
  40. response.WriteError(errors.New(parser.MarshalJson(map[string]string{
  41. "error_type": "aws_event_not_supported",
  42. "message": "aws event is not supported by full duplex",
  43. })))
  44. response.Close()
  45. return
  46. }
  47. if err := backwards_invocation.InvokeDify(
  48. runtime.Configuration(),
  49. session.InvokeFrom,
  50. session,
  51. transaction.NewFullDuplexEventWriter(session),
  52. chunk.Data,
  53. ); err != nil {
  54. response.WriteError(errors.New(parser.MarshalJson(map[string]string{
  55. "error_type": "invoke_dify_error",
  56. "message": fmt.Sprintf("invoke dify failed: %s", err.Error()),
  57. })))
  58. response.Close()
  59. return
  60. }
  61. case plugin_entities.SESSION_MESSAGE_TYPE_END:
  62. response.Close()
  63. case plugin_entities.SESSION_MESSAGE_TYPE_ERROR:
  64. e, err := parser.UnmarshalJsonBytes[plugin_entities.ErrorResponse](chunk.Data)
  65. if err != nil {
  66. break
  67. }
  68. response.WriteError(errors.New(e.Error()))
  69. response.Close()
  70. default:
  71. response.WriteError(errors.New(parser.MarshalJson(map[string]string{
  72. "error_type": "unknown_stream_message_type",
  73. "message": "unknown stream message type: " + string(chunk.Type),
  74. })))
  75. response.Close()
  76. }
  77. })
  78. // close the listener if stream outside is closed due to close of connection
  79. response.OnClose(func() {
  80. listener.Close()
  81. })
  82. session.Write(
  83. session_manager.PLUGIN_IN_STREAM_EVENT_REQUEST,
  84. getInvokePluginMap(
  85. session,
  86. request,
  87. ),
  88. )
  89. return response, nil
  90. }
  91. func getInvokePluginMap(
  92. session *session_manager.Session,
  93. request any,
  94. ) map[string]any {
  95. req := getBasicPluginAccessMap(
  96. session.UserID,
  97. session.InvokeFrom,
  98. session.Action,
  99. )
  100. for k, v := range parser.StructToMap(request) {
  101. req[k] = v
  102. }
  103. return req
  104. }