base_sse.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. package service
  import (
	"errors"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/exception"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
)
  13. // baseSSEService is a helper function to handle SSE service
  14. // it accepts a generator function that returns a stream response to gin context
  15. func baseSSEService[R any](
  16. generator func() (*stream.Stream[R], error),
  17. ctx *gin.Context,
  18. max_timeout_seconds int,
  19. ) {
  20. writer := ctx.Writer
  21. writer.WriteHeader(200)
  22. writer.Header().Set("Content-Type", "text/event-stream")
  23. done := make(chan bool)
  24. doneClosed := new(int32)
  25. closed := new(int32)
  26. writeData := func(data interface{}) {
  27. if atomic.LoadInt32(closed) == 1 {
  28. return
  29. }
  30. writer.Write([]byte("data: "))
  31. writer.Write(parser.MarshalJsonBytes(data))
  32. writer.Write([]byte("\n\n"))
  33. writer.Flush()
  34. }
  35. pluginDaemonResponse, err := generator()
  36. if err != nil {
  37. writeData(exception.InternalServerError(err).ToResponse())
  38. close(done)
  39. return
  40. }
  41. routine.Submit(func() {
  42. for pluginDaemonResponse.Next() {
  43. chunk, err := pluginDaemonResponse.Read()
  44. if err != nil {
  45. writeData(exception.InvokePluginError(err).ToResponse())
  46. break
  47. }
  48. writeData(entities.NewSuccessResponse(chunk))
  49. }
  50. if atomic.CompareAndSwapInt32(doneClosed, 0, 1) {
  51. close(done)
  52. }
  53. })
  54. timer := time.NewTimer(time.Duration(max_timeout_seconds) * time.Second)
  55. defer timer.Stop()
  56. defer func() {
  57. atomic.StoreInt32(closed, 1)
  58. }()
  59. select {
  60. case <-writer.CloseNotify():
  61. pluginDaemonResponse.Close()
  62. return
  63. case <-done:
  64. return
  65. case <-timer.C:
  66. writeData(exception.InternalServerError(errors.New("killed by timeout")).ToResponse())
  67. if atomic.CompareAndSwapInt32(doneClosed, 0, 1) {
  68. close(done)
  69. }
  70. return
  71. }
  72. }