Bläddra i källkod

feat: support stdout and stdin

Yeuoly 1 år sedan
förälder
incheckning
db1fd270a2

+ 22 - 1
cmd/test/sandbox/main.go

@@ -1,17 +1,38 @@
 package main
 
 import (
+	"fmt"
 	"time"
 
 	"github.com/langgenius/dify-sandbox/internal/core/runner/python"
+	"github.com/langgenius/dify-sandbox/internal/utils/log"
 )
 
 const python_script = `def foo(a, b):
 	return a + b
 print(foo(1, 2))
+
+import json
+import os
+print(json.dumps({"a": 1, "b": 2}))
 `
 
 func main() {
 	runner := python.PythonRunner{}
-	runner.Run(python_script, time.Minute, nil)
+	stdout, stderr, done, err := runner.Run(python_script, time.Minute, nil)
+	if err != nil {
+		log.Panic("failed to run python script: %v", err)
+	}
+
+	for {
+		select {
+		case <-done:
+			fmt.Println("done")
+			return
+		case out := <-stdout:
+			fmt.Print(string(out))
+		case err := <-stderr:
+			fmt.Print(string(err))
+		}
+	}
 }

+ 7 - 1
internal/core/runner/python/prescript.py

@@ -2,10 +2,16 @@ if __name__ == "__main__":
     import ctypes
     import os
     import sys
+    import json
+    import typing
+    import time
+
+    print(os.listdir("/tmp"))
+
     if len(sys.argv) != 4:
         sys.exit(-1)
 
-    lib = ctypes.CDLL("./tmp/sandbox-python/python.so")
+    lib = ctypes.CDLL("/tmp/sandbox-python/python.so")
     module = sys.argv[1]
     code = open(module).read()
 

+ 115 - 17
internal/core/runner/python/python.go

@@ -25,16 +25,16 @@ var python_sandbox_fs []byte
 //go:embed python.so
 var python_lib []byte
 
-func (p *PythonRunner) Run(code string, timeout time.Duration, stdin chan []byte) (<-chan []byte, <-chan []byte, error) {
+func (p *PythonRunner) Run(code string, timeout time.Duration, stdin []byte) (chan []byte, chan []byte, chan bool, error) {
 	// check if libpython.so exists
 	if _, err := os.Stat("/tmp/sandbox-python/python.so"); os.IsNotExist(err) {
 		err := os.MkdirAll("/tmp/sandbox-python", 0755)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, nil, err
 		}
 		err = os.WriteFile("/tmp/sandbox-python/python.so", python_lib, 0755)
 		if err != nil {
-			return nil, nil, err
+			return nil, nil, nil, err
 		}
 	}
 
@@ -44,32 +44,130 @@ func (p *PythonRunner) Run(code string, timeout time.Duration, stdin chan []byte
 	temp_code_path := fmt.Sprintf("/tmp/code/%s.py", temp_code_name)
 	err := os.MkdirAll("/tmp/code", 0755)
 	if err != nil {
-		return nil, nil, err
+		return nil, nil, nil, err
 	}
-	defer os.Remove(temp_code_path)
 	err = os.WriteFile(temp_code_path, []byte(code), 0755)
 	if err != nil {
-		return nil, nil, err
+		return nil, nil, nil, err
 	}
 
+	stdout := make(chan []byte, 1)
+	stderr := make(chan []byte, 1)
+	done_chan := make(chan bool, 1)
+
 	err = p.WithTempDir([]string{
 		temp_code_path,
 		"/tmp/sandbox-python/python.so",
-	}, func() error {
-		syscall.Exec("/usr/bin/python3", []string{
-			"/usr/bin/python3",
-			"-c",
-			string(python_sandbox_fs),
-			temp_code_path,
-			strconv.Itoa(static.SANDBOX_USER_UID),
-			strconv.Itoa(static.SANDBOX_GROUP_ID),
-		}, nil)
+	}, func(root_path string) error {
+		var pipe_fds [2]int
+		// create stdout pipe
+		err = syscall.Pipe2(pipe_fds[0:], syscall.O_CLOEXEC)
+		if err != nil {
+			return err
+		}
+		stdout_reader, stdout_writer := pipe_fds[0], pipe_fds[1]
+		// create stderr pipe
+		err = syscall.Pipe2(pipe_fds[0:], syscall.O_CLOEXEC)
+		if err != nil {
+			return err
+		}
+		stderr_reader, stderr_writer := pipe_fds[0], pipe_fds[1]
+
+		// create a new process
+		pid, _, errno := syscall.RawSyscall(syscall.SYS_FORK, 0, 0, 0)
+		if errno != 0 {
+			return fmt.Errorf("failed to fork: %v", errno)
+		}
+
+		if pid == 0 {
+			// child process
+			syscall.Close(stdout_reader)
+			syscall.Close(stderr_reader)
+
+			// dup the stdout and stderr
+			syscall.Dup2(stdout_writer, int(os.Stdout.Fd()))
+			syscall.Dup2(stderr_writer, int(os.Stderr.Fd()))
+			err := syscall.Exec(
+				"/usr/bin/python3",
+				[]string{
+					"/usr/bin/python3",
+					"-c",
+					string(python_sandbox_fs),
+					temp_code_path,
+					strconv.Itoa(static.SANDBOX_USER_UID),
+					strconv.Itoa(static.SANDBOX_GROUP_ID),
+				},
+				nil,
+			)
+
+			if err != nil {
+				stderr <- []byte(fmt.Sprintf("failed to exec: %v", err))
+				return nil
+			}
+		}
+
+		// read the output
+		go func() {
+			buf := make([]byte, 1024)
+			for {
+				n, err := syscall.Read(stdout_reader, buf)
+				if err != nil {
+					break
+				}
+				stdout <- buf[:n]
+			}
+		}()
+
+		// read the error
+		go func() {
+			buf := make([]byte, 1024)
+			for {
+				n, err := syscall.Read(stderr_reader, buf)
+				if err != nil {
+					break
+				}
+				stderr <- buf[:n]
+			}
+		}()
+
+		// wait for the process to finish
+		done := make(chan error, 1)
+		go func() {
+			var status syscall.WaitStatus
+			_, err := syscall.Wait4(int(pid), &status, 0, nil)
+			if err != nil {
+				done <- err
+				return
+			}
+			done <- nil
+		}()
+
+		go func() {
+			for {
+				select {
+				case <-time.After(timeout):
+					// kill the process
+					syscall.Kill(int(pid), syscall.SIGKILL)
+					stderr <- []byte("timeout\n")
+				case err := <-done:
+					if err != nil {
+						stderr <- []byte(fmt.Sprintf("error: %v\n", err))
+					}
+					os.Remove(temp_code_path)
+					os.RemoveAll(root_path)
+					os.Remove(root_path)
+					done_chan <- true
+					return
+				}
+			}
+		}()
+
 		return nil
 	})
 
 	if err != nil {
-		fmt.Println(err)
+		return nil, nil, nil, err
 	}
 
-	return nil, nil, nil
+	return stdout, stderr, done_chan, nil
 }

+ 2 - 6
internal/core/runner/seccomp.go

@@ -12,7 +12,7 @@ import (
 type SeccompRunner struct {
 }
 
-func (s *SeccompRunner) WithTempDir(paths []string, closures func() error) error {
+func (s *SeccompRunner) WithTempDir(paths []string, closures func(path string) error) error {
 	uuid, err := uuid.NewRandom()
 	if err != nil {
 		return err
@@ -24,10 +24,6 @@ func (s *SeccompRunner) WithTempDir(paths []string, closures func() error) error
 	if err != nil {
 		return err
 	}
-	defer func() {
-		os.RemoveAll(tmp_dir)
-		os.Remove(tmp_dir)
-	}()
 
 	// copy files to tmp dir
 	for _, file_path := range paths {
@@ -62,7 +58,7 @@ func (s *SeccompRunner) WithTempDir(paths []string, closures func() error) error
 		return err
 	}
 
-	err = closures()
+	err = closures(tmp_dir)
 	if err != nil {
 		return err
 	}