Yeuoly пре 1 година
родитељ
комит
9eb344219b

cmd/test/fuzz/main.go → cmd/test/fuzz_nodejs/main.go


+ 94 - 0
cmd/test/fuzz_python/main.go

@@ -0,0 +1,94 @@
+package main
+
+import (
+	"fmt"
+	"os"
+	"os/exec"
+	"strconv"
+	"strings"
+	"sync"
+
+	"github.com/langgenius/dify-sandbox/internal/static/python_syscall"
+)
+
+const (
+	SYSCALL_NUMS = 400
+)
+
+func run(allowed_syscalls []int) {
+	// os.Chdir("/tmp/123")
+
+	nums := []string{}
+	for _, syscall := range allowed_syscalls {
+		nums = append(nums, strconv.Itoa(syscall))
+	}
+	os.Setenv("ALLOWED_SYSCALLS", strings.Join(nums, ","))
+	_, err := exec.Command("python3", ".test.py").Output()
+	if err == nil {
+	} else {
+		fmt.Println("failed")
+	}
+}
+
+func find_syscall(syscall int, syscalls []int) int {
+	for i, s := range syscalls {
+		if s == syscall {
+			return i
+		}
+	}
+	return -1
+}
+
+func main() {
+	original := python_syscall.ALLOW_SYSCALLS
+	original = append(original, python_syscall.ALLOW_NETWORK_SYSCALLS...)
+
+	// generate task list
+	list := make([][]int, SYSCALL_NUMS)
+	for i := 0; i < SYSCALL_NUMS; i++ {
+		list[i] = make([]int, len(original))
+		copy(list[i], original)
+		// add i
+		if find_syscall(i, original) == -1 {
+			list[i] = append(list[i], i)
+		}
+
+		for j := 22; j < 23; j++ {
+			if find_syscall(j, list[i]) == -1 {
+				list[i] = append(list[i], j)
+			}
+		}
+
+		for j := 124; j < 125; j++ {
+			if find_syscall(j, list[i]) == -1 {
+				list[i] = append(list[i], j)
+			}
+		}
+	}
+
+	lock := sync.Mutex{}
+	wg := sync.WaitGroup{}
+	i := 0
+
+	// run 4 tasks concurrently
+	for j := 0; j < 4; j++ {
+		wg.Add(1)
+		go func() {
+			defer wg.Done()
+			for {
+				lock.Lock()
+				if i >= len(list) {
+					lock.Unlock()
+					return
+				}
+				task := list[i]
+				i++
+				lock.Unlock()
+				run(task)
+			}
+		}()
+	}
+
+	// wait for all tasks to finish
+	wg.Wait()
+}

+ 8 - 69
internal/core/lib/nodejs/add_seccomp.go

@@ -1,18 +1,13 @@
 package nodejs
 
 import (
-	"bytes"
-	"encoding/binary"
-	"errors"
 	"os"
 	"strconv"
 	"strings"
 	"syscall"
-	"unsafe"
 
 	"github.com/langgenius/dify-sandbox/internal/core/lib"
 	"github.com/langgenius/dify-sandbox/internal/static/nodejs_syscall"
-	sg "github.com/seccomp/libseccomp-golang"
 )
 
 //var allow_syscalls = []int{}
@@ -29,10 +24,8 @@ func InitSeccomp(uid int, gid int, enable_network bool) error {
 
 	lib.SetNoNewPrivs()
 
-	ctx, err := sg.NewFilter(sg.ActKillProcess)
-	if err != nil {
-		return err
-	}
+	allowed_syscalls := []int{}
+	allowed_not_kill_syscalls := []int{}
 
 	allowed_syscall := os.Getenv("ALLOWED_SYSCALLS")
 	if allowed_syscall != "" {
@@ -40,77 +33,23 @@ func InitSeccomp(uid int, gid int, enable_network bool) error {
 		for num := range nums {
 			syscall, err := strconv.Atoi(nums[num])
 			if err != nil {
-				return err
-			}
-			err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
-			if err != nil {
-				return err
+				continue
 			}
+			allowed_syscalls = append(allowed_syscalls, syscall)
 		}
 	} else {
-		for _, syscall := range nodejs_syscall.ALLOW_SYSCALLS {
-			err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
-			if err != nil {
-				return err
-			}
-		}
-
-		for _, syscall := range nodejs_syscall.ALLOW_ERROR_SYSCALLS {
-			err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActErrno)
-			if err != nil {
-				return err
-			}
-		}
+		allowed_syscalls = append(allowed_syscalls, nodejs_syscall.ALLOW_SYSCALLS...)
+		allowed_not_kill_syscalls = append(allowed_not_kill_syscalls, nodejs_syscall.ALLOW_ERROR_SYSCALLS...)
 
 		if enable_network {
-			for _, syscall := range nodejs_syscall.ALLOW_NETWORK_SYSCALLS {
-				err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
-				if err != nil {
-					return err
-				}
-			}
+			allowed_syscalls = append(allowed_syscalls, nodejs_syscall.ALLOW_NETWORK_SYSCALLS...)
 		}
 	}
 
-	reader, writer, err := os.Pipe()
+	err = lib.Seccomp(allowed_syscalls, allowed_not_kill_syscalls)
 	if err != nil {
 		return err
 	}
-	defer reader.Close()
-	defer writer.Close()
-
-	file := os.NewFile(uintptr(writer.Fd()), "pipe")
-	ctx.ExportBPF(file)
-
-	// read from pipe
-	data := make([]byte, 4096)
-	n, err := reader.Read(data)
-	if err != nil {
-		return err
-	}
-	// load bpf
-	sock_filters := make([]syscall.SockFilter, n/8)
-	bytesBuffer := bytes.NewBuffer(data)
-	err = binary.Read(bytesBuffer, binary.LittleEndian, &sock_filters)
-	if err != nil {
-		return err
-	}
-
-	bpf := syscall.SockFprog{
-		Len:    uint16(len(sock_filters)),
-		Filter: &sock_filters[0],
-	}
-
-	_, _, err2 := syscall.Syscall(
-		syscall.SYS_SECCOMP,
-		uintptr(lib.SeccompSetModeFilter),
-		uintptr(lib.SeccompFilterFlagTSYNC),
-		uintptr(unsafe.Pointer(&bpf)),
-	)
-
-	if err2 != 0 {
-		return errors.New("seccomp error")
-	}
 
 	// setuid
 	err = syscall.Setuid(uid)

+ 17 - 55
internal/core/lib/python/add_seccomp.go

@@ -1,16 +1,13 @@
 package python
 
 import (
-	"bytes"
-	"encoding/binary"
-	"errors"
 	"os"
+	"strconv"
+	"strings"
 	"syscall"
-	"unsafe"
 
 	"github.com/langgenius/dify-sandbox/internal/core/lib"
 	"github.com/langgenius/dify-sandbox/internal/static/python_syscall"
-	sg "github.com/seccomp/libseccomp-golang"
 )
 
 //var allow_syscalls = []int{}
@@ -27,66 +24,31 @@ func InitSeccomp(uid int, gid int, enable_network bool) error {
 
 	lib.SetNoNewPrivs()
 
-	ctx, err := sg.NewFilter(sg.ActKillProcess)
-	if err != nil {
-		return err
-	}
-
-	for _, syscall := range python_syscall.ALLOW_SYSCALLS {
-		err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
-		if err != nil {
-			return err
-		}
-	}
+	allowed_syscalls := []int{}
+	allowed_not_kill_syscalls := []int{}
 
-	if enable_network {
-		for _, syscall := range python_syscall.ALLOW_NETWORK_SYSCALLS {
-			err = ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
+	allowed_syscall := os.Getenv("ALLOWED_SYSCALLS")
+	if allowed_syscall != "" {
+		nums := strings.Split(allowed_syscall, ",")
+		for num := range nums {
+			syscall, err := strconv.Atoi(nums[num])
 			if err != nil {
-				return err
+				continue
 			}
+			allowed_syscalls = append(allowed_syscalls, syscall)
 		}
-	}
+	} else {
+		allowed_syscalls = append(allowed_syscalls, python_syscall.ALLOW_SYSCALLS...)
 
-	reader, writer, err := os.Pipe()
-	if err != nil {
-		return err
+		if enable_network {
+			allowed_syscalls = append(allowed_syscalls, python_syscall.ALLOW_NETWORK_SYSCALLS...)
+		}
 	}
-	defer reader.Close()
-	defer writer.Close()
 
-	file := os.NewFile(uintptr(writer.Fd()), "pipe")
-	ctx.ExportBPF(file)
-
-	// read from pipe
-	data := make([]byte, 4096)
-	n, err := reader.Read(data)
+	err = lib.Seccomp(allowed_syscalls, allowed_not_kill_syscalls)
 	if err != nil {
 		return err
 	}
-	// load bpf
-	sock_filters := make([]syscall.SockFilter, n/8)
-	bytesBuffer := bytes.NewBuffer(data)
-	err = binary.Read(bytesBuffer, binary.LittleEndian, &sock_filters)
-	if err != nil {
-		return err
-	}
-
-	bpf := syscall.SockFprog{
-		Len:    uint16(len(sock_filters)),
-		Filter: &sock_filters[0],
-	}
-
-	_, _, err2 := syscall.Syscall(
-		syscall.SYS_SECCOMP,
-		uintptr(lib.SeccompSetModeFilter),
-		uintptr(lib.SeccompFilterFlagTSYNC),
-		uintptr(unsafe.Pointer(&bpf)),
-	)
-
-	if err2 != 0 {
-		return errors.New("seccomp error")
-	}
 
 	// setuid
 	err = syscall.Setuid(uid)

+ 68 - 0
internal/core/lib/seccomp.go

@@ -0,0 +1,68 @@
+package lib
+
+import (
+	"bytes"
+	"encoding/binary"
+	"os"
+	"syscall"
+	"unsafe"
+
+	sg "github.com/seccomp/libseccomp-golang"
+)
+
+func Seccomp(allowed_syscalls []int, allowed_not_kill_syscalls []int) error {
+	ctx, err := sg.NewFilter(sg.ActKillProcess)
+	if err != nil {
+		return err
+	}
+
+	reader, writer, err := os.Pipe()
+	if err != nil {
+		return err
+	}
+	defer reader.Close()
+	defer writer.Close()
+
+	for _, syscall := range allowed_syscalls {
+		ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActAllow)
+	}
+
+	for _, syscall := range allowed_not_kill_syscalls {
+		ctx.AddRule(sg.ScmpSyscall(syscall), sg.ActErrno)
+	}
+
+	file := os.NewFile(uintptr(writer.Fd()), "pipe")
+	ctx.ExportBPF(file)
+
+	// read from pipe
+	data := make([]byte, 4096)
+	n, err := reader.Read(data)
+	if err != nil {
+		return err
+	}
+	// load bpf
+	sock_filters := make([]syscall.SockFilter, n/8)
+	bytesBuffer := bytes.NewBuffer(data)
+	err = binary.Read(bytesBuffer, binary.LittleEndian, &sock_filters)
+	if err != nil {
+		return err
+	}
+
+	bpf := syscall.SockFprog{
+		Len:    uint16(len(sock_filters)),
+		Filter: &sock_filters[0],
+	}
+
+	_, _, err2 := syscall.Syscall(
+		SYS_SECCOMP,
+		uintptr(SeccompSetModeFilter),
+		uintptr(SeccompFilterFlagTSYNC),
+		uintptr(unsafe.Pointer(&bpf)),
+	)
+
+	if err2 != 0 {
+		return err2
+	}
+
+	return nil
+}

+ 7 - 0
internal/core/lib/seccomp_syscall_amd64.go

@@ -0,0 +1,7 @@
+//go:build linux && amd64
+
+package lib
+
+const (
+	SYS_SECCOMP = 317
+)

+ 9 - 0
internal/core/lib/seccomp_syscall_arm64.go

@@ -0,0 +1,9 @@
+//go:build linux && arm64
+
+package lib
+
+import "syscall"
+
+const (
+	SYS_SECCOMP = syscall.SYS_SECCOMP
+)

+ 2 - 2
internal/static/python_syscall/syscalls_arm64.go

@@ -20,7 +20,7 @@ var ALLOW_SYSCALLS = []int{
 	syscall.SYS_GETPID, syscall.SYS_GETPPID, syscall.SYS_GETTID,
 	syscall.SYS_EXIT, syscall.SYS_EXIT_GROUP,
 	syscall.SYS_TGKILL, syscall.SYS_RT_SIGACTION,
-	syscall.SYS_IOCTL,
+	syscall.SYS_IOCTL, syscall.SYS_SCHED_YIELD,
 	// time
 	syscall.SYS_CLOCK_GETTIME, syscall.SYS_GETTIMEOFDAY, syscall.SYS_NANOSLEEP,
 	syscall.SYS_EPOLL_CTL, syscall.SYS_CLOCK_NANOSLEEP, syscall.SYS_PSELECT6,
@@ -35,5 +35,5 @@ var ALLOW_NETWORK_SYSCALLS = []int{
 	syscall.SYS_RECVFROM, syscall.SYS_RECVMSG, syscall.SYS_GETSOCKOPT,
 	syscall.SYS_GETSOCKNAME, syscall.SYS_GETPEERNAME, syscall.SYS_SETSOCKOPT,
 	syscall.SYS_PPOLL, syscall.SYS_UNAME, syscall.SYS_SENDMMSG,
-	syscall.SYS_FSTATAT, syscall.SYS_FSTAT, syscall.SYS_FSTATFS,
+	syscall.SYS_FSTATAT, syscall.SYS_FSTAT, syscall.SYS_FSTATFS, syscall.SYS_EPOLL_PWAIT,
 }