瀏覽代碼

fix: nodejs

Yeuoly 1 年之前
父節點
當前提交
3620c21560

+ 55 - 8
cmd/test/fuzz/main.go

@@ -4,27 +4,74 @@ import (
 	"fmt"
 	"os"
 	"os/exec"
+	"strconv"
+	"strings"
 	"sync"
+
+	"github.com/langgenius/dify-sandbox/internal/static/nodejs_syscall"
 )
 
 const (
 	SYSCALL_NUMS = 400
 )
 
-func run(i int) {
-	os.Setenv("DISABLE_SYSCALL", fmt.Sprintf("%d", i))
-	_, err := exec.Command("node", "test.js").Output()
-	if err != nil {
-		fmt.Println(i)
+func run(allowed_syscalls []int) {
+	os.Chdir("/tmp/sandbox-463ec16c-8796-4e8f-988a-f61de7dc6976/tmp/sandbox-nodejs-project/node_temp/node_temp")
+
+	nums := []string{}
+	for _, syscall := range allowed_syscalls {
+		nums = append(nums, strconv.Itoa(syscall))
+	}
+	os.Setenv("ALLOWED_SYSCALLS", strings.Join(nums, ","))
+	_, err := exec.Command("node", "test.js", "65537", "1001", "{\"enable_network\":true}").Output()
+	if err == nil {
+		fmt.Println("success")
+	} else {
+		fmt.Println("failed")
 	}
 }
 
+func find_syscall(syscall int, syscalls []int) int {
+	for i, s := range syscalls {
+		if s == syscall {
+			return i
+		}
+	}
+	return -1
+}
+
 func main() {
-	os.Chdir(".node_temp")
+	original := nodejs_syscall.ALLOW_SYSCALLS
+	original = append(original, nodejs_syscall.ALLOW_NETWORK_SYSCALLS...)
+	original = append(original, nodejs_syscall.ALLOW_ERROR_SYSCALLS...)
+
 	// generate task list
-	list := make([]int, SYSCALL_NUMS)
+	list := make([][]int, SYSCALL_NUMS)
 	for i := 0; i < SYSCALL_NUMS; i++ {
-		list[i] = i
+		list[i] = make([]int, len(original))
+		copy(list[i], original)
+		// add i
+		if find_syscall(i, original) == -1 {
+			list[i] = append(list[i], i)
+		}
+
+		for j := 124; j < 125; j++ {
+			if find_syscall(j, list[i]) == -1 {
+				list[i] = append(list[i], j)
+			}
+		}
+
+		for j := 220; j < 221; j++ {
+			if find_syscall(j, list[i]) == -1 {
+				list[i] = append(list[i], j)
+			}
+		}
+
+		for j := 293; j < 294; j++ {
+			if find_syscall(j, list[i]) == -1 {
+				list[i] = append(list[i], j)
+			}
+		}
 	}
 
 	lock := sync.Mutex{}

+ 78 - 17
internal/core/lib/nodejs/add_seccomp.go

@@ -1,15 +1,26 @@
 package nodejs
 
 import (
+	"bytes"
+	"encoding/binary"
+	"errors"
 	"os"
 	"strconv"
+	"strings"
 	"syscall"
+	"unsafe"
 
+	"github.com/langgenius/dify-sandbox/internal/core/lib"
 	"github.com/langgenius/dify-sandbox/internal/static/nodejs_syscall"
 	sg "github.com/seccomp/libseccomp-golang"
 )
 
-// var allow_syscalls = []int{}
+const (
+	seccompSetModeFilter   = 0x1
+	seccompFilterFlagTSYNC = 0x1
+)
+
+//var allow_syscalls = []int{}
 
 func InitSeccomp(uid int, gid int, enable_network bool) error {
 	err := syscall.Chroot(".")
@@ -21,40 +32,90 @@ func InitSeccomp(uid int, gid int, enable_network bool) error {
 		return err
 	}
 
-	disabled_syscall, err := strconv.Atoi(os.Getenv("DISABLE_SYSCALL"))
-	if err != nil {
-		disabled_syscall = -1
-	}
+	lib.SetNoNewPrivs()
 
 	ctx, err := sg.NewFilter(sg.ActKillProcess)
 	if err != nil {
 		return err
 	}
-	defer ctx.Release()
 
-	for _, syscall := range nodejs_syscall.ALLOW_SYSCALLS {
-		if syscall == disabled_syscall {
-			continue
+	allowed_syscall := os.Getenv("ALLOWED_SYSCALLS")
+	if allowed_syscall != "" {
+		nums := strings.Split(allowed_syscall, ",")
+		for num := range nums {
+			syscall, err := strconv.Atoi(nums[num])
+			if err != nil {
+				return err
+			}
+			err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
+			if err != nil {
+				return err
+			}
 		}
-		err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
-		if err != nil {
-			return err
+	} else {
+		for _, syscall := range nodejs_syscall.ALLOW_SYSCALLS {
+			err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
+			if err != nil {
+				return err
+			}
 		}
-	}
 
-	if enable_network {
-		for _, syscall := range nodejs_syscall.ALLOW_NETWORK_SYSCALLS {
-			err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
+		for _, syscall := range nodejs_syscall.ALLOW_ERROR_SYSCALLS {
+			err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActErrno)
 			if err != nil {
 				return err
 			}
 		}
+
+		if enable_network {
+			for _, syscall := range nodejs_syscall.ALLOW_NETWORK_SYSCALLS {
+				err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
+				if err != nil {
+					return err
+				}
+			}
+		}
+	}
+
+	reader, writer, err := os.Pipe()
+	if err != nil {
+		return err
 	}
+	defer reader.Close()
+	defer writer.Close()
+
+	file := os.NewFile(uintptr(writer.Fd()), "pipe")
+	ctx.ExportBPF(file)
 
-	err = ctx.Load()
+	// read from pipe
+	data := make([]byte, 4096)
+	n, err := reader.Read(data)
 	if err != nil {
 		return err
 	}
+	// load bpf
+	sock_filters := make([]syscall.SockFilter, n/8)
+	bytesBuffer := bytes.NewBuffer(data)
+	err = binary.Read(bytesBuffer, binary.LittleEndian, &sock_filters)
+	if err != nil {
+		return err
+	}
+
+	bpf := syscall.SockFprog{
+		Len:    uint16(len(sock_filters)),
+		Filter: &sock_filters[0],
+	}
+
+	_, _, err2 := syscall.Syscall(
+		syscall.SYS_SECCOMP,
+		uintptr(seccompSetModeFilter),
+		uintptr(seccompFilterFlagTSYNC),
+		uintptr(unsafe.Pointer(&bpf)),
+	)
+
+	if err2 != 0 {
+		return errors.New("seccomp error")
+	}
 
 	// setuid
 	err = syscall.Setuid(uid)

+ 3 - 0
internal/core/lib/python/add_seccomp.go

@@ -3,6 +3,7 @@ package python
 import (
 	"syscall"
 
+	"github.com/langgenius/dify-sandbox/internal/core/lib"
 	"github.com/langgenius/dify-sandbox/internal/static/python_syscall"
 	sg "github.com/seccomp/libseccomp-golang"
 )
@@ -19,6 +20,8 @@ func InitSeccomp(uid int, gid int, enable_network bool) error {
 		return err
 	}
 
+	lib.SetNoNewPrivs()
+
 	ctx, err := sg.NewFilter(sg.ActKillProcess)
 	if err != nil {
 		return err

+ 13 - 0
internal/core/lib/set_no_new_privs.go

@@ -0,0 +1,13 @@
+package lib
+
+import (
+	"syscall"
+)
+
+func SetNoNewPrivs() error {
+	_, _, e := syscall.Syscall6(syscall.SYS_PRCTL, 0x26, 1, 0, 0, 0, 0)
+	if e != 0 {
+		return e
+	}
+	return nil
+}

+ 1 - 1
internal/core/runner/nodejs/prescript.js

@@ -2,7 +2,7 @@ const argv = process.argv
 
 const koffi = require('koffi')
 const lib = koffi.load('/tmp/sandbox-nodejs/nodejs.so')
-const difySeccomp = lib.func('void DifySeccomp(int, int)')
+const difySeccomp = lib.func('void DifySeccomp(int, int, bool)')
 
 const uid = parseInt(argv[2])
 const gid = parseInt(argv[3])

+ 4 - 0
internal/static/nodejs_syscall/syscalls_amd64.go

@@ -29,6 +29,10 @@ var ALLOW_SYSCALLS = []int{
 	syscall.SYS_DUP3,
 }
 
+var ALLOW_ERROR_SYSCALLS = []int{
+	syscall.SYS_CLONE,
+}
+
 var ALLOW_NETWORK_SYSCALLS = []int{
 	syscall.SYS_SOCKET, syscall.SYS_CONNECT, syscall.SYS_BIND, syscall.SYS_LISTEN, syscall.SYS_ACCEPT, syscall.SYS_SENDTO, syscall.SYS_RECVFROM,
 	syscall.SYS_GETSOCKNAME, syscall.SYS_RECVMSG, syscall.SYS_GETPEERNAME, syscall.SYS_SETSOCKOPT, syscall.SYS_PPOLL, syscall.SYS_UNAME,

+ 17 - 7
internal/static/nodejs_syscall/syscalls_arm64.go

@@ -11,24 +11,34 @@ var ALLOW_SYSCALLS = []int{
 	syscall.SYS_READLINKAT, syscall.SYS_OPENAT,
 
 	// process
-	syscall.SYS_GETPID, syscall.SYS_TGKILL, syscall.SYS_FUTEX, syscall.SYS_EXIT_GROUP,
+	syscall.SYS_GETPID, syscall.SYS_TGKILL, syscall.SYS_FUTEX, syscall.SYS_IOCTL,
+	syscall.SYS_EXIT, syscall.SYS_EXIT_GROUP,
+	syscall.SYS_SET_ROBUST_LIST, syscall.SYS_NANOSLEEP, syscall.SYS_SCHED_GETAFFINITY,
+	syscall.SYS_SCHED_YIELD,
 
 	// memory
 	syscall.SYS_RT_SIGPROCMASK, syscall.SYS_SIGALTSTACK, syscall.SYS_RT_SIGACTION,
 	syscall.SYS_MMAP, syscall.SYS_MUNMAP, syscall.SYS_MADVISE, syscall.SYS_MPROTECT,
+	syscall.SYS_RT_SIGRETURN,
 
 	//user/group
-	syscall.SYS_SETUID, syscall.SYS_SETGID,
+	syscall.SYS_SETUID, syscall.SYS_SETGID, syscall.SYS_GETTID,
 	syscall.SYS_GETUID, syscall.SYS_GETGID,
 
 	// epoll
 	syscall.SYS_EPOLL_CTL, syscall.SYS_EPOLL_PWAIT,
 }
 
+var ALLOW_ERROR_SYSCALLS = []int{
+	syscall.SYS_CLONE, 293,
+}
+
 var ALLOW_NETWORK_SYSCALLS = []int{
-	syscall.SYS_SOCKET, syscall.SYS_CONNECT, syscall.SYS_BIND, syscall.SYS_LISTEN, syscall.SYS_ACCEPT, syscall.SYS_SENDTO, syscall.SYS_RECVFROM,
-	syscall.SYS_GETSOCKNAME, syscall.SYS_RECVMSG, syscall.SYS_GETPEERNAME, syscall.SYS_SETSOCKOPT, syscall.SYS_PPOLL, syscall.SYS_UNAME,
-	syscall.SYS_SENDMMSG, syscall.SYS_GETSOCKOPT,
-	syscall.SYS_FSTATAT, syscall.SYS_IOCTL, syscall.SYS_LSEEK,
-	syscall.SYS_FSTAT, syscall.SYS_FCNTL, syscall.SYS_FSTATFS,
+	syscall.SYS_SOCKET, syscall.SYS_CONNECT, syscall.SYS_BIND, syscall.SYS_LISTEN, syscall.SYS_ACCEPT,
+	syscall.SYS_SENDTO, syscall.SYS_RECVFROM,
+	syscall.SYS_GETSOCKNAME, syscall.SYS_SETSOCKOPT, syscall.SYS_GETSOCKOPT,
+	syscall.SYS_SENDMMSG, syscall.SYS_RECVMSG,
+	syscall.SYS_GETPEERNAME, syscall.SYS_PPOLL, syscall.SYS_UNAME,
+	syscall.SYS_FSTATAT, syscall.SYS_LSEEK,
+	syscall.SYS_FSTATFS,
 }