runner.go 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. package service
  2. import (
  3. "sync/atomic"
  4. "time"
  5. "github.com/gin-gonic/gin"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  11. )
  12. func baseSSEService[T any, R any](
  13. r *plugin_entities.InvokePluginRequest[T],
  14. generator func() (*stream.StreamResponse[R], error),
  15. ctx *gin.Context,
  16. ) {
  17. writer := ctx.Writer
  18. writer.WriteHeader(200)
  19. writer.Header().Set("Content-Type", "text/event-stream")
  20. done := make(chan bool)
  21. done_closed := new(int32)
  22. closed := new(int32)
  23. write_data := func(data interface{}) {
  24. if atomic.LoadInt32(closed) == 1 {
  25. return
  26. }
  27. writer.Write([]byte("data: "))
  28. writer.Write(parser.MarshalJsonBytes(data))
  29. writer.Write([]byte("\n\n"))
  30. writer.Flush()
  31. }
  32. plugin_daemon_response, err := generator()
  33. last_response_at := time.Now()
  34. if err != nil {
  35. write_data(entities.NewErrorResponse(-500, err.Error()))
  36. close(done)
  37. return
  38. }
  39. routine.Submit(func() {
  40. for plugin_daemon_response.Next() {
  41. last_response_at = time.Now()
  42. chunk, err := plugin_daemon_response.Read()
  43. if err != nil {
  44. write_data(entities.NewErrorResponse(-500, err.Error()))
  45. break
  46. }
  47. write_data(entities.NewSuccessResponse(chunk))
  48. }
  49. if atomic.CompareAndSwapInt32(done_closed, 0, 1) {
  50. close(done)
  51. }
  52. })
  53. ticker := time.NewTicker(15 * time.Second)
  54. defer ticker.Stop()
  55. defer func() {
  56. atomic.StoreInt32(closed, 1)
  57. }()
  58. for {
  59. select {
  60. case <-writer.CloseNotify():
  61. plugin_daemon_response.Close()
  62. return
  63. case <-done:
  64. return
  65. case <-ticker.C:
  66. if time.Since(last_response_at) > 30*time.Second {
  67. write_data(entities.NewErrorResponse(-500, "killed by timeout"))
  68. if atomic.CompareAndSwapInt32(done_closed, 0, 1) {
  69. close(done)
  70. }
  71. return
  72. }
  73. }
  74. }
  75. }