base_sse.go 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. package service
import (
	"errors"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/langgenius/dify-plugin-daemon/internal/types/entities"
	"github.com/langgenius/dify-plugin-daemon/internal/types/exception"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
	"github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
)
  13. // baseSSEService is a helper function to handle SSE service
  14. // it accepts a generator function that returns a stream response to gin context
  15. func baseSSEService[R any](
  16. generator func() (*stream.Stream[R], error),
  17. ctx *gin.Context,
  18. max_timeout_seconds int,
  19. ) {
  20. writer := ctx.Writer
  21. writer.WriteHeader(200)
  22. writer.Header().Set("Content-Type", "text/event-stream")
  23. done := make(chan bool)
  24. doneClosed := new(int32)
  25. closed := new(int32)
  26. writeData := func(data interface{}) {
  27. if atomic.LoadInt32(closed) == 1 {
  28. return
  29. }
  30. writer.Write([]byte("data: "))
  31. writer.Write(parser.MarshalJsonBytes(data))
  32. writer.Write([]byte("\n\n"))
  33. writer.Flush()
  34. }
  35. pluginDaemonResponse, err := generator()
  36. if err != nil {
  37. writeData(exception.InternalServerError(err).ToResponse())
  38. close(done)
  39. return
  40. }
  41. routine.Submit(map[string]string{
  42. "module": "service",
  43. "function": "baseSSEService",
  44. }, func() {
  45. for pluginDaemonResponse.Next() {
  46. chunk, err := pluginDaemonResponse.Read()
  47. if err != nil {
  48. writeData(exception.InvokePluginError(err).ToResponse())
  49. break
  50. }
  51. writeData(entities.NewSuccessResponse(chunk))
  52. }
  53. if atomic.CompareAndSwapInt32(doneClosed, 0, 1) {
  54. close(done)
  55. }
  56. })
  57. timer := time.NewTimer(time.Duration(max_timeout_seconds) * time.Second)
  58. defer timer.Stop()
  59. defer func() {
  60. atomic.StoreInt32(closed, 1)
  61. }()
  62. select {
  63. case <-writer.CloseNotify():
  64. pluginDaemonResponse.Close()
  65. return
  66. case <-done:
  67. return
  68. case <-timer.C:
  69. writeData(exception.InternalServerError(errors.New("killed by timeout")).ToResponse())
  70. if atomic.CompareAndSwapInt32(doneClosed, 0, 1) {
  71. close(done)
  72. }
  73. return
  74. }
  75. }