runner.go 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. package service
  2. import (
  3. "sync/atomic"
  4. "time"
  5. "github.com/gin-gonic/gin"
  6. "github.com/langgenius/dify-plugin-daemon/internal/types/entities"
  7. "github.com/langgenius/dify-plugin-daemon/internal/types/entities/plugin_entities"
  8. "github.com/langgenius/dify-plugin-daemon/internal/utils/parser"
  9. "github.com/langgenius/dify-plugin-daemon/internal/utils/routine"
  10. "github.com/langgenius/dify-plugin-daemon/internal/utils/stream"
  11. )
  12. func baseSSEService[T any, R any](
  13. r *plugin_entities.InvokePluginRequest[T],
  14. generator func() (*stream.StreamResponse[R], error),
  15. ctx *gin.Context,
  16. ) {
  17. writer := ctx.Writer
  18. writer.WriteHeader(200)
  19. writer.Header().Set("Content-Type", "text/event-stream")
  20. done := make(chan bool)
  21. done_closed := new(int32)
  22. write_data := func(data interface{}) {
  23. writer.Write([]byte("data: "))
  24. writer.Write(parser.MarshalJsonBytes(data))
  25. writer.Write([]byte("\n\n"))
  26. writer.Flush()
  27. }
  28. plugin_daemon_response, err := generator()
  29. last_response_at := time.Now()
  30. if err != nil {
  31. write_data(entities.NewErrorResponse(-500, err.Error()))
  32. close(done)
  33. return
  34. }
  35. routine.Submit(func() {
  36. for plugin_daemon_response.Next() {
  37. last_response_at = time.Now()
  38. chunk, err := plugin_daemon_response.Read()
  39. if err != nil {
  40. write_data(entities.NewErrorResponse(-500, err.Error()))
  41. break
  42. }
  43. write_data(entities.NewSuccessResponse(chunk))
  44. }
  45. if atomic.CompareAndSwapInt32(done_closed, 0, 1) {
  46. close(done)
  47. }
  48. })
  49. ticker := time.NewTicker(15 * time.Second)
  50. defer ticker.Stop()
  51. for {
  52. select {
  53. case <-writer.CloseNotify():
  54. plugin_daemon_response.Close()
  55. return
  56. case <-done:
  57. return
  58. case <-ticker.C:
  59. if time.Since(last_response_at) > 30*time.Second {
  60. write_data(entities.NewErrorResponse(-500, "killed by timeout"))
  61. if atomic.CompareAndSwapInt32(done_closed, 0, 1) {
  62. close(done)
  63. }
  64. return
  65. }
  66. }
  67. }
  68. }