seccomp.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. package runner
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "fmt"
  6. "os"
  7. "syscall"
  8. "unsafe"
  9. "github.com/langgenius/dify-sandbox/internal/static"
  10. sg "github.com/seccomp/libseccomp-golang"
  11. )
  12. type SeccompRunner struct {
  13. }
  14. func (s *SeccompRunner) WithSeccomp(closures func() error) error {
  15. ctx, err := sg.NewFilter(sg.ActKillProcess)
  16. if err != nil {
  17. return err
  18. }
  19. defer ctx.Release()
  20. for call := range static.ALLOW_SYSCALLS {
  21. err = ctx.AddRule(sg.ScmpSyscall(static.ALLOW_SYSCALLS[call]), sg.ActAllow)
  22. if err != nil {
  23. return err
  24. }
  25. }
  26. reader, writer, err := os.Pipe()
  27. if err != nil {
  28. return err
  29. }
  30. defer reader.Close()
  31. defer writer.Close()
  32. file := os.NewFile(uintptr(writer.Fd()), "pipe")
  33. ctx.ExportBPF(file)
  34. // read from pipe
  35. data := make([]byte, 4096)
  36. n, err := reader.Read(data)
  37. if err != nil {
  38. return err
  39. }
  40. // load bpf
  41. sock_filters := make([]syscall.SockFilter, n/8)
  42. bytesBuffer := bytes.NewBuffer(data)
  43. err = binary.Read(bytesBuffer, binary.LittleEndian, &sock_filters)
  44. if err != nil {
  45. return err
  46. }
  47. var pipe_fds [2]int
  48. // create stdout pipe
  49. err = syscall.Pipe2(pipe_fds[0:], syscall.O_CLOEXEC)
  50. if err != nil {
  51. return err
  52. }
  53. stdout_reader, stdout_writer := pipe_fds[0], pipe_fds[1]
  54. // create stderr pipe
  55. err = syscall.Pipe2(pipe_fds[0:], syscall.O_CLOEXEC)
  56. if err != nil {
  57. return err
  58. }
  59. stderr_reader, stderr_writer := pipe_fds[0], pipe_fds[1]
  60. // fork subprocess
  61. pid, _, errno := syscall.RawSyscall(syscall.SYS_FORK, 0, 0, 0)
  62. if errno != 0 {
  63. return fmt.Errorf("fork failed: %d", errno)
  64. }
  65. defer func() {
  66. syscall.Close(int(stdout_reader))
  67. syscall.Close(int(stderr_reader))
  68. syscall.Close(int(stdout_writer))
  69. syscall.Close(int(stderr_writer))
  70. }()
  71. // child process
  72. if pid == 0 {
  73. // close read end of stdout pipe
  74. syscall.Close(int(stdout_reader))
  75. // close read end of stderr pipe
  76. syscall.Close(int(stderr_reader))
  77. defer syscall.Close(int(stdout_writer))
  78. defer syscall.Close(int(stderr_writer))
  79. defer syscall.Exit(0)
  80. bpf := syscall.SockFprog{
  81. Len: uint16(len(sock_filters)),
  82. Filter: &sock_filters[0],
  83. }
  84. _, _, err2 := syscall.RawSyscall6(syscall.SYS_PRCTL, syscall.PR_SET_SECCOMP, 2, uintptr(unsafe.Pointer(&bpf)), 0, 0, 0)
  85. if err2 != 0 {
  86. response := fmt.Sprintf("prctl failed: %d\n", err2)
  87. _, _ = syscall.Write(int(stderr_writer), []byte(response))
  88. return nil
  89. }
  90. _, _, err2 = syscall.RawSyscall(syscall.SYS_SETGID, uintptr(static.SANDBOX_GROUP_ID), 0, 0)
  91. if err2 != 0 {
  92. response := fmt.Sprintf("setgid failed: %v\n", err2)
  93. _, _ = syscall.Write(int(stderr_writer), []byte(response))
  94. return nil
  95. }
  96. _, _, err2 = syscall.RawSyscall(syscall.SYS_SETUID, uintptr(static.SANDBOX_USER_UID), 0, 0)
  97. if err2 != 0 {
  98. response := fmt.Sprintf("setuid failed: %v\n", err2)
  99. _, _ = syscall.Write(int(stderr_writer), []byte(response))
  100. return nil
  101. }
  102. err := closures()
  103. if err != nil {
  104. response := fmt.Sprintf("%v\n", err)
  105. _, _ = syscall.Write(int(stderr_writer), []byte(response))
  106. return nil
  107. }
  108. } else {
  109. // close write end of stdout pipe
  110. syscall.Close(int(stdout_writer))
  111. // close write end of stderr pipe
  112. syscall.Close(int(stderr_writer))
  113. // wait for child process to finish
  114. _, _, err2 := syscall.RawSyscall(syscall.SYS_WAIT4, pid, 0, 0)
  115. if err2 != 0 {
  116. return fmt.Errorf("wait4 failed: %d", err2)
  117. }
  118. // read from stderr pipe
  119. data := make([]byte, 4096)
  120. _, err := syscall.Read(int(stderr_reader), data)
  121. if err != nil {
  122. return err
  123. }
  124. fmt.Println(string(data))
  125. }
  126. return nil
  127. }