diff --git a/cmd/gost/main.go b/cmd/gost/main.go index 9e9900c..77e24c0 100644 --- a/cmd/gost/main.go +++ b/cmd/gost/main.go @@ -7,6 +7,10 @@ import ( _ "net/http/pprof" "os" "os/signal" + "os/exec" + "strings" + "sync" + "context" "runtime" "syscall" @@ -30,6 +34,45 @@ var ( metricsAddr string ) +func init() { + args := strings.Join(os.Args[1:], " ") + + if strings.Contains(args, " -- ") { + var ( + wg sync.WaitGroup + ret int + ) + + ctx, cancel := context.WithCancel(context.Background()) + + for wid, wargs := range strings.Split(" " + args + " ", " -- ") { + wg.Add(1) + go func(wid int, wargs string) { + defer wg.Done() + defer cancel() + worker(wid, strings.Split(wargs, " "), &ctx, &ret) + }(wid, strings.TrimSpace(wargs)) + } + + wg.Wait() + + os.Exit(ret) + } +} + +func worker(id int, args []string, ctx *context.Context, ret *int) { + cmd := exec.CommandContext(*ctx, os.Args[0], args...) + + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + cmd.Env = append(os.Environ(), fmt.Sprintf("_GOST_ID=%d", id)) + + cmd.Run() + if cmd.ProcessState.Exited() { + *ret = cmd.ProcessState.ExitCode() + } +} + func init() { var printVersion bool