From a5ebe2282aabaf983c116525af4a7c5eeedf2c6e Mon Sep 17 00:00:00 2001 From: Laura Brehm Date: Mon, 24 Apr 2023 09:57:16 +0100 Subject: [PATCH] commandconn: don't return error if command closed successfully --- commandconn: fix race on `Close()` During normal operation, if a `Read()` or `Write()` call results in an EOF, we call `onEOF()` to handle the terminating command, and store its exit value. However, if a Read/Write call was blocked while `Close()` is called, the in/out pipes are immediately closed, which causes an EOF to be returned. Here, we shouldn't call `onEOF()`, since the reason we got an EOF is that we're already terminating the connection. This also prevents a race between two calls to the command's `Wait()`, in the `Close()` call and `onEOF()`. --- Add CLI init timeout to SSH connections --- connhelper: add 30s ssh default dialer timeout (same as non-ssh dialer) Signed-off-by: Laura Brehm --- cli/command/cli.go | 10 +- cli/connhelper/commandconn/commandconn.go | 224 +++++++++--------- .../commandconn/commandconn_unix_test.go | 168 +++++++++++++ cli/connhelper/connhelper.go | 9 + cli/connhelper/connhelper_test.go | 31 +++ .../docker/docker/pkg/process/doc.go | 3 + .../docker/docker/pkg/process/process_unix.go | 81 +++++++ .../docker/pkg/process/process_windows.go | 45 ++++ vendor/modules.txt | 1 + 9 files changed, 454 insertions(+), 118 deletions(-) create mode 100644 cli/connhelper/connhelper_test.go create mode 100644 vendor/github.com/docker/docker/pkg/process/doc.go create mode 100644 vendor/github.com/docker/docker/pkg/process/process_unix.go create mode 100644 vendor/github.com/docker/docker/pkg/process/process_windows.go diff --git a/cli/command/cli.go b/cli/command/cli.go index 4d8b9dc406..e18e18f909 100644 --- a/cli/command/cli.go +++ b/cli/command/cli.go @@ -8,7 +8,6 @@ import ( "path/filepath" "runtime" "strconv" - "strings" "sync" "time" @@ -327,13 +326,8 @@ func (cli *DockerCli) getInitTimeout() time.Duration { func (cli *DockerCli) 
initializeFromClient() { ctx := context.Background() - if !strings.HasPrefix(cli.dockerEndpoint.Host, "ssh://") { - // @FIXME context.WithTimeout doesn't work with connhelper / ssh connections - // time="2020-04-10T10:16:26Z" level=warning msg="commandConn.CloseWrite: commandconn: failed to wait: signal: killed" - var cancel func() - ctx, cancel = context.WithTimeout(ctx, cli.getInitTimeout()) - defer cancel() - } + ctx, cancel := context.WithTimeout(ctx, cli.getInitTimeout()) + defer cancel() ping, err := cli.client.Ping(ctx) if err != nil { diff --git a/cli/connhelper/commandconn/commandconn.go b/cli/connhelper/commandconn/commandconn.go index 95d864e499..2643ca30db 100644 --- a/cli/connhelper/commandconn/commandconn.go +++ b/cli/connhelper/commandconn/commandconn.go @@ -24,6 +24,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -64,81 +65,68 @@ func New(_ context.Context, cmd string, args ...string) (net.Conn, error) { // commandConn implements net.Conn type commandConn struct { - cmd *exec.Cmd - cmdExited bool - cmdWaitErr error - cmdMutex sync.Mutex - stdin io.WriteCloser - stdout io.ReadCloser - stderrMu sync.Mutex - stderr bytes.Buffer - stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed - stdinClosed bool - stdoutClosed bool - localAddr net.Addr - remoteAddr net.Addr + cmdMutex sync.Mutex // for cmd, cmdWaitErr + cmd *exec.Cmd + cmdWaitErr error + cmdExited atomic.Bool + stdin io.WriteCloser + stdout io.ReadCloser + stderrMu sync.Mutex // for stderr + stderr bytes.Buffer + stdinClosed atomic.Bool + stdoutClosed atomic.Bool + closing atomic.Bool + localAddr net.Addr + remoteAddr net.Addr } -// killIfStdioClosed kills the cmd if both stdin and stdout are closed. -func (c *commandConn) killIfStdioClosed() error { - c.stdioClosedMu.Lock() - stdioClosed := c.stdoutClosed && c.stdinClosed - c.stdioClosedMu.Unlock() - if !stdioClosed { - return nil +// kill terminates the process. 
On Windows it kills the process directly, +// whereas on other platforms, a SIGTERM is sent, before forcefully terminating +// the process after 3 seconds. +func (c *commandConn) kill() { + if c.cmdExited.Load() { + return } - return c.kill() -} - -// killAndWait tries sending SIGTERM to the process before sending SIGKILL. -func killAndWait(cmd *exec.Cmd) error { + c.cmdMutex.Lock() var werr error if runtime.GOOS != "windows" { werrCh := make(chan error) - go func() { werrCh <- cmd.Wait() }() - cmd.Process.Signal(syscall.SIGTERM) + go func() { werrCh <- c.cmd.Wait() }() + _ = c.cmd.Process.Signal(syscall.SIGTERM) select { case werr = <-werrCh: case <-time.After(3 * time.Second): - cmd.Process.Kill() + _ = c.cmd.Process.Kill() werr = <-werrCh } } else { - cmd.Process.Kill() - werr = cmd.Wait() - } - return werr -} - -// kill returns nil if the command terminated, regardless to the exit status. -func (c *commandConn) kill() error { - var werr error - c.cmdMutex.Lock() - if c.cmdExited { - werr = c.cmdWaitErr - } else { - werr = killAndWait(c.cmd) - c.cmdWaitErr = werr - c.cmdExited = true + _ = c.cmd.Process.Kill() + werr = c.cmd.Wait() } + c.cmdWaitErr = werr c.cmdMutex.Unlock() - if werr == nil { - return nil - } - wExitErr, ok := werr.(*exec.ExitError) - if ok { - if wExitErr.ProcessState.Exited() { - return nil - } - } - return errors.Wrapf(werr, "commandconn: failed to wait") + c.cmdExited.Store(true) } -func (c *commandConn) onEOF(eof error) error { - // when we got EOF, the command is going to be terminated - var werr error +// handleEOF handles io.EOF errors while reading or writing from the underlying +// command pipes. +// +// When we've received an EOF we expect that the command will +// be terminated soon. As such, we call Wait() on the command +// and return EOF or the error depending on whether the command +// exited with an error. 
+// +// If Wait() does not return within 10s, an error is returned +func (c *commandConn) handleEOF(err error) error { + if err != io.EOF { + return err + } + c.cmdMutex.Lock() - if c.cmdExited { + defer c.cmdMutex.Unlock() + + var werr error + if c.cmdExited.Load() { werr = c.cmdWaitErr } else { werrCh := make(chan error) @@ -146,18 +134,17 @@ func (c *commandConn) onEOF(eof error) error { select { case werr = <-werrCh: c.cmdWaitErr = werr - c.cmdExited = true + c.cmdExited.Store(true) case <-time.After(10 * time.Second): - c.cmdMutex.Unlock() c.stderrMu.Lock() stderr := c.stderr.String() c.stderrMu.Unlock() - return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr) + return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr) } } - c.cmdMutex.Unlock() + if werr == nil { - return eof + return err } c.stderrMu.Lock() stderr := c.stderr.String() @@ -166,71 +153,88 @@ func (c *commandConn) onEOF(eof error) error { } func ignorableCloseError(err error) bool { - errS := err.Error() - ss := []string{ - os.ErrClosed.Error(), - } - for _, s := range ss { - if strings.Contains(errS, s) { - return true - } - } - return false -} - -func (c *commandConn) CloseRead() error { - // NOTE: maybe already closed here - if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) { - logrus.Warnf("commandConn.CloseRead: %v", err) - } - c.stdioClosedMu.Lock() - c.stdoutClosed = true - c.stdioClosedMu.Unlock() - if err := c.killIfStdioClosed(); err != nil { - logrus.Warnf("commandConn.CloseRead: %v", err) - } - return nil + return strings.Contains(err.Error(), os.ErrClosed.Error()) } func (c *commandConn) Read(p []byte) (int, error) { n, err := c.stdout.Read(p) - if err == io.EOF { - err = c.onEOF(err) + // check after the call to Read, since + // it is blocking, and while waiting on it + // Close might get called + if c.closing.Load() { + // If we're currently closing the connection + // we don't want to call 
handleEOF, but we do want + // to return an io.EOF + return 0, io.EOF } - return n, err -} -func (c *commandConn) CloseWrite() error { - // NOTE: maybe already closed here - if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) { - logrus.Warnf("commandConn.CloseWrite: %v", err) - } - c.stdioClosedMu.Lock() - c.stdinClosed = true - c.stdioClosedMu.Unlock() - if err := c.killIfStdioClosed(); err != nil { - logrus.Warnf("commandConn.CloseWrite: %v", err) - } - return nil + return n, c.handleEOF(err) } func (c *commandConn) Write(p []byte) (int, error) { n, err := c.stdin.Write(p) - if err == io.EOF { - err = c.onEOF(err) + // check after the call to Write, since + // it is blocking, and while waiting on it + // Close might get called + if c.closing.Load() { + // If we're currently closing the connection + // we don't want to call handleEOF, but we do want + // to return an io.EOF + return 0, io.EOF } - return n, err + + return n, c.handleEOF(err) } +// CloseRead allows commandConn to implement halfCloser +func (c *commandConn) CloseRead() error { + // NOTE: maybe already closed here + if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) { + return err + } + c.stdoutClosed.Store(true) + + if c.stdinClosed.Load() { + c.kill() + } + + return nil +} + +// CloseWrite allows commandConn to implement halfCloser +func (c *commandConn) CloseWrite() error { + // NOTE: maybe already closed here + if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) { + return err + } + c.stdinClosed.Store(true) + + if c.stdoutClosed.Load() { + c.kill() + } + return nil +} + +// Close is the net.Conn func that gets called +// by the transport when a dial is cancelled +// due to its context timing out. Any blocked +// Read or Write calls will be unblocked and +// return errors. It will block until the underlying +// command has terminated. 
func (c *commandConn) Close() error { - var err error - if err = c.CloseRead(); err != nil { + c.closing.Store(true) + defer c.closing.Store(false) + + if err := c.CloseRead(); err != nil { logrus.Warnf("commandConn.Close: CloseRead: %v", err) + return err } - if err = c.CloseWrite(); err != nil { + if err := c.CloseWrite(); err != nil { logrus.Warnf("commandConn.Close: CloseWrite: %v", err) + return err } - return err + + return nil } func (c *commandConn) LocalAddr() net.Addr { diff --git a/cli/connhelper/commandconn/commandconn_unix_test.go b/cli/connhelper/commandconn/commandconn_unix_test.go index e16833c685..cc96e344ea 100644 --- a/cli/connhelper/commandconn/commandconn_unix_test.go +++ b/cli/connhelper/commandconn/commandconn_unix_test.go @@ -6,7 +6,9 @@ import ( "context" "io" "testing" + "time" + "github.com/docker/docker/pkg/process" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" ) @@ -42,3 +44,169 @@ func TestEOFWithoutError(t *testing.T) { assert.Check(t, is.Equal(0, n)) assert.Check(t, is.Equal(io.EOF, err)) } + +func TestCloseRunningCommand(t *testing.T) { + cmd := "sh" + args := []string{"-c", "while true; sleep 1; done"} + + done := make(chan struct{}) + defer close(done) + + go func() { + c, err := New(context.TODO(), cmd, args...) 
+ assert.NilError(t, err) + cmdConn := c.(*commandConn) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + n, err := c.Write([]byte("hello")) + assert.Check(t, is.Equal(len("hello"), n)) + assert.NilError(t, err) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + err = cmdConn.Close() + assert.NilError(t, err) + assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid)) + done <- struct{}{} + }() + + select { + case <-time.After(5 * time.Second): + t.Error("test did not finish in time") + case <-done: + break + } +} + +func TestCloseTwice(t *testing.T) { + cmd := "sh" + args := []string{"-c", "echo hello; sleep 1; exit 0"} + + done := make(chan struct{}) + go func() { + c, err := New(context.TODO(), cmd, args...) + assert.NilError(t, err) + cmdConn := c.(*commandConn) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + b := make([]byte, 32) + n, err := c.Read(b) + assert.Check(t, is.Equal(len("hello\n"), n)) + assert.NilError(t, err) + + err = cmdConn.Close() + assert.NilError(t, err) + assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid)) + + err = cmdConn.Close() + assert.NilError(t, err) + assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid)) + done <- struct{}{} + }() + + select { + case <-time.After(10 * time.Second): + t.Error("test did not finish in time") + case <-done: + break + } +} + +func TestEOFTimeout(t *testing.T) { + cmd := "sh" + args := []string{"-c", "sleep 20"} + + done := make(chan struct{}) + go func() { + c, err := New(context.TODO(), cmd, args...) 
+ assert.NilError(t, err) + cmdConn := c.(*commandConn) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + cmdConn.stdout = mockStdoutEOF{} + + b := make([]byte, 32) + n, err := c.Read(b) + assert.Check(t, is.Equal(0, n)) + assert.ErrorContains(t, err, "did not exit after EOF") + + done <- struct{}{} + }() + + // after receiving an EOF, we try to kill the command + // if it doesn't exit after 10s, we throw an error + select { + case <-time.After(12 * time.Second): + t.Error("test did not finish in time") + case <-done: + break + } +} + +type mockStdoutEOF struct{} + +func (mockStdoutEOF) Read(_ []byte) (int, error) { + return 0, io.EOF +} + +func (mockStdoutEOF) Close() error { + return nil +} + +func TestCloseWhileWriting(t *testing.T) { + cmd := "sh" + args := []string{"-c", "while true; sleep 1; done"} + + c, err := New(context.TODO(), cmd, args...) + assert.NilError(t, err) + cmdConn := c.(*commandConn) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + writeErrC := make(chan error) + go func() { + for { + n, err := c.Write([]byte("hello")) + if err != nil { + writeErrC <- err + return + } + assert.Equal(t, n, len("hello")) + } + }() + + err = c.Close() + assert.NilError(t, err) + assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid)) + + writeErr := <-writeErrC + assert.ErrorContains(t, writeErr, "EOF") +} + +func TestCloseWhileReading(t *testing.T) { + cmd := "sh" + args := []string{"-c", "while true; sleep 1; done"} + + c, err := New(context.TODO(), cmd, args...) 
+ assert.NilError(t, err) + cmdConn := c.(*commandConn) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + readErrC := make(chan error) + go func() { + for { + b := make([]byte, 32) + n, err := c.Read(b) + if err != nil { + readErrC <- err + return + } + assert.Check(t, is.Equal(0, n)) + } + }() + + err = cmdConn.Close() + assert.NilError(t, err) + assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid)) + + readErr := <-readErrC + assert.ErrorContains(t, readErr, "EOF") +} diff --git a/cli/connhelper/connhelper.go b/cli/connhelper/connhelper.go index 397149c3e2..b98d97c25d 100644 --- a/cli/connhelper/connhelper.go +++ b/cli/connhelper/connhelper.go @@ -5,6 +5,7 @@ import ( "context" "net" "net/url" + "strings" "github.com/docker/cli/cli/connhelper/commandconn" "github.com/docker/cli/cli/connhelper/ssh" @@ -51,6 +52,7 @@ func getConnectionHelper(daemonURL string, sshFlags []string) (*ConnectionHelper if sp.Path != "" { args = append(args, "--host", "unix://"+sp.Path) } + sshFlags = addSSHTimeout(sshFlags) args = append(args, "system", "dial-stdio") return commandconn.New(ctx, "ssh", append(sshFlags, sp.Args(args...)...)...) 
}, @@ -71,3 +73,10 @@ func GetCommandConnectionHelper(cmd string, flags ...string) (*ConnectionHelper, Host: "http://docker.example.com", }, nil } + +func addSSHTimeout(sshFlags []string) []string { + if !strings.Contains(strings.Join(sshFlags, ""), "ConnectTimeout") { + sshFlags = append(sshFlags, "-o ConnectTimeout=30") + } + return sshFlags +} diff --git a/cli/connhelper/connhelper_test.go b/cli/connhelper/connhelper_test.go new file mode 100644 index 0000000000..14384f5c85 --- /dev/null +++ b/cli/connhelper/connhelper_test.go @@ -0,0 +1,31 @@ +package connhelper + +import ( + "testing" + + "gotest.tools/v3/assert" +) + +func TestSSHFlags(t *testing.T) { + testCases := []struct { + in []string + out []string + }{ + { + in: []string{}, + out: []string{"-o ConnectTimeout=30"}, + }, + { + in: []string{"option", "-o anotherOption"}, + out: []string{"option", "-o anotherOption", "-o ConnectTimeout=30"}, + }, + { + in: []string{"-o ConnectTimeout=5", "anotherOption"}, + out: []string{"-o ConnectTimeout=5", "anotherOption"}, + }, + } + + for _, tc := range testCases { + assert.DeepEqual(t, addSSHTimeout(tc.in), tc.out) + } +} diff --git a/vendor/github.com/docker/docker/pkg/process/doc.go b/vendor/github.com/docker/docker/pkg/process/doc.go new file mode 100644 index 0000000000..dae536d7db --- /dev/null +++ b/vendor/github.com/docker/docker/pkg/process/doc.go @@ -0,0 +1,3 @@ +// Package process provides a set of basic functions to manage individual +// processes. +package process diff --git a/vendor/github.com/docker/docker/pkg/process/process_unix.go b/vendor/github.com/docker/docker/pkg/process/process_unix.go new file mode 100644 index 0000000000..baa1693a24 --- /dev/null +++ b/vendor/github.com/docker/docker/pkg/process/process_unix.go @@ -0,0 +1,81 @@ +//go:build !windows + +package process + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "runtime" + "strconv" + + "golang.org/x/sys/unix" +) + +// Alive returns true if process with a given pid is running. 
It only considers +// positive PIDs; 0 (all processes in the current process group), -1 (all processes +// with a PID larger than 1), and negative (-n, all processes in process group +// "n") values for pid are never considered to be alive. +func Alive(pid int) bool { + if pid < 1 { + return false + } + switch runtime.GOOS { + case "darwin": + // OS X does not have a proc filesystem. Use kill -0 pid to judge if the + // process exists. From KILL(2): https://www.freebsd.org/cgi/man.cgi?query=kill&sektion=2&manpath=OpenDarwin+7.2.1 + // + // Sig may be one of the signals specified in sigaction(2) or it may + // be 0, in which case error checking is performed but no signal is + // actually sent. This can be used to check the validity of pid. + err := unix.Kill(pid, 0) + + // Either the PID was found (no error) or we get an EPERM, which means + // the PID exists, but we don't have permissions to signal it. + return err == nil || err == unix.EPERM + default: + _, err := os.Stat(filepath.Join("/proc", strconv.Itoa(pid))) + return err == nil + } +} + +// Kill force-stops a process. It only considers positive PIDs; 0 (all processes +// in the current process group), -1 (all processes with a PID larger than 1), +// and negative (-n, all processes in process group "n") values for pid are +// ignored. Refer to [KILL(2)] for details. +// +// [KILL(2)]: https://man7.org/linux/man-pages/man2/kill.2.html +func Kill(pid int) error { + if pid < 1 { + return fmt.Errorf("invalid PID (%d): only positive PIDs are allowed", pid) + } + err := unix.Kill(pid, unix.SIGKILL) + if err != nil && err != unix.ESRCH { + return err + } + return nil +} + +// Zombie return true if process has a state with "Z". It only considers positive +// PIDs; 0 (all processes in the current process group), -1 (all processes with +// a PID larger than 1), and negative (-n, all processes in process group "n") +// values for pid are ignored. Refer to [PROC(5)] for details. 
+// +// [PROC(5)]: https://man7.org/linux/man-pages/man5/proc.5.html +func Zombie(pid int) (bool, error) { + if pid < 1 { + return false, nil + } + data, err := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid)) + if err != nil { + if os.IsNotExist(err) { + return false, nil + } + return false, err + } + if cols := bytes.SplitN(data, []byte(" "), 4); len(cols) >= 3 && string(cols[2]) == "Z" { + return true, nil + } + return false, nil +} diff --git a/vendor/github.com/docker/docker/pkg/process/process_windows.go b/vendor/github.com/docker/docker/pkg/process/process_windows.go new file mode 100644 index 0000000000..2dd57e8254 --- /dev/null +++ b/vendor/github.com/docker/docker/pkg/process/process_windows.go @@ -0,0 +1,45 @@ +package process + +import ( + "os" + + "golang.org/x/sys/windows" +) + +// Alive returns true if process with a given pid is running. +func Alive(pid int) bool { + h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, uint32(pid)) + if err != nil { + return false + } + var c uint32 + err = windows.GetExitCodeProcess(h, &c) + _ = windows.CloseHandle(h) + if err != nil { + // From the GetExitCodeProcess function (processthreadsapi.h) API docs: + // https://learn.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-getexitcodeprocess + // + // The GetExitCodeProcess function returns a valid error code defined by the + // application only after the thread terminates. Therefore, an application should + // not use STILL_ACTIVE (259) as an error code (STILL_ACTIVE is a macro for + // STATUS_PENDING (minwinbase.h)). If a thread returns STILL_ACTIVE (259) as + // an error code, then applications that test for that value could interpret it + // to mean that the thread is still running, and continue to test for the + // completion of the thread after the thread has terminated, which could put + // the application into an infinite loop. 
+ return c == uint32(windows.STATUS_PENDING) + } + return true +} + +// Kill force-stops a process. +func Kill(pid int) error { + p, err := os.FindProcess(pid) + if err == nil { + err = p.Kill() + if err != nil && err != os.ErrProcessDone { + return err + } + } + return nil +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 6047a1a7fb..a306c03abc 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -69,6 +69,7 @@ github.com/docker/docker/pkg/ioutils github.com/docker/docker/pkg/jsonmessage github.com/docker/docker/pkg/longpath github.com/docker/docker/pkg/pools +github.com/docker/docker/pkg/process github.com/docker/docker/pkg/progress github.com/docker/docker/pkg/stdcopy github.com/docker/docker/pkg/streamformatter