io.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. package aws_manager
  2. import (
  3. "bufio"
  4. "bytes"
  5. "context"
  6. "fmt"
  7. "net/http"
  8. "net/url"
  9. "time"
  10. "github.com/langgenius/dify-plugin-daemon/internal/types/entities"
  11. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  12. "github.com/langgenius/dify-plugin-daemon/internal/utils/log"
  13. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  14. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  15. )
  16. func (r *AWSPluginRuntime) Listen(session_id string) *entities.Broadcast[plugin_entities.SessionMessage] {
  17. l := entities.NewBroadcast[plugin_entities.SessionMessage]()
  18. // store the listener
  19. r.listeners.Store(session_id, l)
  20. return l
  21. }
  22. // For AWS Lambda, write is equivalent to http request, it's not a normal stream like stdio and tcp
  23. func (r *AWSPluginRuntime) Write(session_id string, data []byte) {
  24. l, ok := r.listeners.Load(session_id)
  25. if !ok {
  26. log.Error("session %s not found", session_id)
  27. return
  28. }
  29. url, err := url.JoinPath(r.LambdaURL, "invoke")
  30. if err != nil {
  31. l.Send(plugin_entities.SessionMessage{
  32. Type: plugin_entities.SESSION_MESSAGE_TYPE_ERROR,
  33. Data: parser.MarshalJsonBytes(plugin_entities.ErrorResponse{
  34. Error: fmt.Sprintf("Error creating request: %v", err),
  35. }),
  36. })
  37. l.Close()
  38. r.Error(fmt.Sprintf("Error creating request: %v", err))
  39. return
  40. }
  41. connect_time := 240 * time.Second
  42. // create a new http request
  43. ctx, cancel := context.WithTimeout(context.Background(), connect_time)
  44. time.AfterFunc(connect_time, cancel)
  45. req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(data))
  46. if err != nil {
  47. r.Error(fmt.Sprintf("Error creating request: %v", err))
  48. return
  49. }
  50. req.Header.Set("Content-Type", "application/json")
  51. req.Header.Set("Accept", "text/event-stream")
  52. req.Header.Set("Dify-Plugin-Session-ID", session_id)
  53. routine.Submit(func() {
  54. // remove the session from listeners
  55. defer r.listeners.Delete(session_id)
  56. defer l.Close()
  57. defer l.Send(plugin_entities.SessionMessage{
  58. Type: plugin_entities.SESSION_MESSAGE_TYPE_END,
  59. Data: []byte(""),
  60. })
  61. response, err := r.client.Do(req)
  62. if err != nil {
  63. l.Send(plugin_entities.SessionMessage{
  64. Type: plugin_entities.SESSION_MESSAGE_TYPE_ERROR,
  65. Data: parser.MarshalJsonBytes(plugin_entities.ErrorResponse{
  66. Error: "failed to establish connection to plugin",
  67. }),
  68. })
  69. r.Error(fmt.Sprintf("Error sending request to aws lambda: %v", err))
  70. return
  71. }
  72. // write to data stream
  73. scanner := bufio.NewScanner(response.Body)
  74. session_alive := true
  75. for scanner.Scan() && session_alive {
  76. bytes := scanner.Bytes()
  77. if len(bytes) == 0 {
  78. continue
  79. }
  80. plugin_entities.ParsePluginUniversalEvent(
  81. bytes,
  82. func(session_id string, data []byte) {
  83. session_message, err := parser.UnmarshalJsonBytes[plugin_entities.SessionMessage](data)
  84. if err != nil {
  85. l.Send(plugin_entities.SessionMessage{
  86. Type: plugin_entities.SESSION_MESSAGE_TYPE_ERROR,
  87. Data: parser.MarshalJsonBytes(plugin_entities.ErrorResponse{
  88. Error: fmt.Sprintf("failed to parse session message %s, err: %v", bytes, err),
  89. }),
  90. })
  91. session_alive = false
  92. }
  93. l.Send(session_message)
  94. },
  95. func() {},
  96. func(err string) {
  97. l.Send(plugin_entities.SessionMessage{
  98. Type: plugin_entities.SESSION_MESSAGE_TYPE_ERROR,
  99. Data: parser.MarshalJsonBytes(plugin_entities.ErrorResponse{
  100. Error: fmt.Sprintf("encountered an error: %v", err),
  101. }),
  102. })
  103. },
  104. func(message string) {},
  105. )
  106. }
  107. if scanner.Err() != nil {
  108. l.Send(plugin_entities.SessionMessage{
  109. Type: plugin_entities.SESSION_MESSAGE_TYPE_ERROR,
  110. Data: parser.MarshalJsonBytes(plugin_entities.ErrorResponse{
  111. Error: fmt.Sprintf("failed to read response body: %v", scanner.Err()),
  112. }),
  113. })
  114. }
  115. })
  116. }