From 520e3600ee9bd98fb9587124cbbb29469dbe09f8 Mon Sep 17 00:00:00 2001 From: Laura Brehm Date: Mon, 24 Apr 2023 09:57:16 +0100 Subject: [PATCH] commandconn: don't return error if command closed successfully --- commandconn: fix race on `Close()` During normal operation, if a `Read()` or `Write()` call results in an EOF, we call `onEOF()` to handle the terminating command, and store it's exit value. However, if a Read/Write call was blocked while `Close()` is called the in/out pipes are immediately closed which causes an EOF to be returned. Here, we shouldn't call `onEOF()`, since the reason why we got an EOF is because we're already terminating the connection. This also prevents a race between two calls to the commands `Wait()`, in the `Close()` call and `onEOF()` --- Add CLI init timeout to SSH connections --- connhelper: add 30s ssh default dialer timeout (same as non-ssh dialer) Signed-off-by: Laura Brehm (cherry picked from commit a5ebe2282aabaf983c116525af4a7c5eeedf2c6e) Signed-off-by: Sebastiaan van Stijn --- cli/command/cli.go | 10 +- cli/connhelper/commandconn/commandconn.go | 224 +++++++++--------- .../commandconn/commandconn_unix_test.go | 168 +++++++++++++ cli/connhelper/connhelper.go | 9 + cli/connhelper/connhelper_test.go | 31 +++ 5 files changed, 324 insertions(+), 118 deletions(-) create mode 100644 cli/connhelper/connhelper_test.go diff --git a/cli/command/cli.go b/cli/command/cli.go index 3e5089f80e..1551c14da8 100644 --- a/cli/command/cli.go +++ b/cli/command/cli.go @@ -8,7 +8,6 @@ import ( "path/filepath" "runtime" "strconv" - "strings" "sync" "time" @@ -327,13 +326,8 @@ func (cli *DockerCli) getInitTimeout() time.Duration { func (cli *DockerCli) initializeFromClient() { ctx := context.Background() - if !strings.HasPrefix(cli.dockerEndpoint.Host, "ssh://") { - // @FIXME context.WithTimeout doesn't work with connhelper / ssh connections - // time="2020-04-10T10:16:26Z" level=warning msg="commandConn.CloseWrite: commandconn: failed to wait: signal: killed" - var cancel func() - ctx, cancel = context.WithTimeout(ctx, cli.getInitTimeout()) - defer cancel() - } + ctx, cancel := context.WithTimeout(ctx, cli.getInitTimeout()) + defer cancel() ping, err := cli.client.Ping(ctx) if err != nil { diff --git a/cli/connhelper/commandconn/commandconn.go b/cli/connhelper/commandconn/commandconn.go index 202ddb84cc..a01c27bdc3 100644 --- a/cli/connhelper/commandconn/commandconn.go +++ b/cli/connhelper/commandconn/commandconn.go @@ -23,6 +23,7 @@ import ( "runtime" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -64,81 +65,68 @@ func New(_ context.Context, cmd string, args ...string) (net.Conn, error) { // commandConn implements net.Conn type commandConn struct { - cmd *exec.Cmd - cmdExited bool - cmdWaitErr error - cmdMutex sync.Mutex - stdin io.WriteCloser - stdout io.ReadCloser - stderrMu sync.Mutex - stderr bytes.Buffer - stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed - stdinClosed bool - stdoutClosed bool - localAddr net.Addr - remoteAddr net.Addr + cmdMutex sync.Mutex // for cmd, cmdWaitErr + cmd *exec.Cmd + cmdWaitErr error + cmdExited atomic.Bool + stdin io.WriteCloser + stdout io.ReadCloser + stderrMu sync.Mutex // for stderr + stderr bytes.Buffer + stdinClosed atomic.Bool + stdoutClosed atomic.Bool + closing atomic.Bool + localAddr net.Addr + remoteAddr net.Addr } -// killIfStdioClosed kills the cmd if both stdin and stdout are closed. -func (c *commandConn) killIfStdioClosed() error { - c.stdioClosedMu.Lock() - stdioClosed := c.stdoutClosed && c.stdinClosed - c.stdioClosedMu.Unlock() - if !stdioClosed { - return nil +// kill terminates the process. On Windows it kills the process directly, +// whereas on other platforms, a SIGTERM is sent, before forcefully terminating +// the process after 3 seconds. +func (c *commandConn) kill() { + if c.cmdExited.Load() { + return } - return c.kill() -} - -// killAndWait tries sending SIGTERM to the process before sending SIGKILL. -func killAndWait(cmd *exec.Cmd) error { + c.cmdMutex.Lock() var werr error if runtime.GOOS != "windows" { werrCh := make(chan error) - go func() { werrCh <- cmd.Wait() }() - cmd.Process.Signal(syscall.SIGTERM) + go func() { werrCh <- c.cmd.Wait() }() + _ = c.cmd.Process.Signal(syscall.SIGTERM) select { case werr = <-werrCh: case <-time.After(3 * time.Second): - cmd.Process.Kill() + _ = c.cmd.Process.Kill() werr = <-werrCh } } else { - cmd.Process.Kill() - werr = cmd.Wait() - } - return werr -} - -// kill returns nil if the command terminated, regardless to the exit status. -func (c *commandConn) kill() error { - var werr error - c.cmdMutex.Lock() - if c.cmdExited { - werr = c.cmdWaitErr - } else { - werr = killAndWait(c.cmd) - c.cmdWaitErr = werr - c.cmdExited = true + _ = c.cmd.Process.Kill() + werr = c.cmd.Wait() } + c.cmdWaitErr = werr c.cmdMutex.Unlock() - if werr == nil { - return nil - } - wExitErr, ok := werr.(*exec.ExitError) - if ok { - if wExitErr.ProcessState.Exited() { - return nil - } - } - return errors.Wrapf(werr, "commandconn: failed to wait") + c.cmdExited.Store(true) } -func (c *commandConn) onEOF(eof error) error { - // when we got EOF, the command is going to be terminated - var werr error +// handleEOF handles io.EOF errors while reading or writing from the underlying +// command pipes. +// +// When we've received an EOF we expect that the command will +// be terminated soon. As such, we call Wait() on the command +// and return EOF or the error depending on whether the command +// exited with an error. +// +// If Wait() does not return within 10s, an error is returned +func (c *commandConn) handleEOF(err error) error { + if err != io.EOF { + return err + } + c.cmdMutex.Lock() - if c.cmdExited { + defer c.cmdMutex.Unlock() + + var werr error + if c.cmdExited.Load() { werr = c.cmdWaitErr } else { werrCh := make(chan error) @@ -146,18 +134,17 @@ func (c *commandConn) onEOF(eof error) error { select { case werr = <-werrCh: c.cmdWaitErr = werr - c.cmdExited = true + c.cmdExited.Store(true) case <-time.After(10 * time.Second): - c.cmdMutex.Unlock() c.stderrMu.Lock() stderr := c.stderr.String() c.stderrMu.Unlock() - return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr) + return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr) } } - c.cmdMutex.Unlock() + if werr == nil { - return eof + return err } c.stderrMu.Lock() stderr := c.stderr.String() @@ -166,71 +153,88 @@ func (c *commandConn) onEOF(eof error) error { } func ignorableCloseError(err error) bool { - errS := err.Error() - ss := []string{ - os.ErrClosed.Error(), - } - for _, s := range ss { - if strings.Contains(errS, s) { - return true - } - } - return false -} - -func (c *commandConn) CloseRead() error { - // NOTE: maybe already closed here - if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) { - logrus.Warnf("commandConn.CloseRead: %v", err) - } - c.stdioClosedMu.Lock() - c.stdoutClosed = true - c.stdioClosedMu.Unlock() - if err := c.killIfStdioClosed(); err != nil { - logrus.Warnf("commandConn.CloseRead: %v", err) - } - return nil + return strings.Contains(err.Error(), os.ErrClosed.Error()) } func (c *commandConn) Read(p []byte) (int, error) { n, err := c.stdout.Read(p) - if err == io.EOF { - err = c.onEOF(err) + // check after the call to Read, since + // it is blocking, and while waiting on it + // Close might get called + if c.closing.Load() { + // If we're currently closing the connection + // we don't want to call onEOF, but we do want + // to return an io.EOF + return 0, io.EOF } - return n, err -} -func (c *commandConn) CloseWrite() error { - // NOTE: maybe already closed here - if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) { - logrus.Warnf("commandConn.CloseWrite: %v", err) - } - c.stdioClosedMu.Lock() - c.stdinClosed = true - c.stdioClosedMu.Unlock() - if err := c.killIfStdioClosed(); err != nil { - logrus.Warnf("commandConn.CloseWrite: %v", err) - } - return nil + return n, c.handleEOF(err) } func (c *commandConn) Write(p []byte) (int, error) { n, err := c.stdin.Write(p) - if err == io.EOF { - err = c.onEOF(err) + // check after the call to Write, since + // it is blocking, and while waiting on it + // Close might get called + if c.closing.Load() { + // If we're currently closing the connection + // we don't want to call onEOF, but we do want + // to return an io.EOF + return 0, io.EOF } - return n, err + + return n, c.handleEOF(err) } +// CloseRead allows commandConn to implement halfCloser +func (c *commandConn) CloseRead() error { + // NOTE: maybe already closed here + if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) { + return err + } + c.stdoutClosed.Store(true) + + if c.stdinClosed.Load() { + c.kill() + } + + return nil +} + +// CloseWrite allows commandConn to implement halfCloser +func (c *commandConn) CloseWrite() error { + // NOTE: maybe already closed here + if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) { + return err + } + c.stdinClosed.Store(true) + + if c.stdoutClosed.Load() { + c.kill() + } + return nil +} + +// Close is the net.Conn func that gets called +// by the transport when a dial is cancelled +// due to it's context timing out. Any blocked +// Read or Write calls will be unblocked and +// return errors. It will block until the underlying +// command has terminated. func (c *commandConn) Close() error { - var err error - if err = c.CloseRead(); err != nil { + c.closing.Store(true) + defer c.closing.Store(false) + + if err := c.CloseRead(); err != nil { logrus.Warnf("commandConn.Close: CloseRead: %v", err) + return err } - if err = c.CloseWrite(); err != nil { + if err := c.CloseWrite(); err != nil { logrus.Warnf("commandConn.Close: CloseWrite: %v", err) + return err } - return err + + return nil } func (c *commandConn) LocalAddr() net.Addr { diff --git a/cli/connhelper/commandconn/commandconn_unix_test.go b/cli/connhelper/commandconn/commandconn_unix_test.go index 0103bcea2c..bc43ea2318 100644 --- a/cli/connhelper/commandconn/commandconn_unix_test.go +++ b/cli/connhelper/commandconn/commandconn_unix_test.go @@ -7,7 +7,9 @@ import ( "context" "io" "testing" + "time" + "github.com/docker/docker/pkg/process" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" ) @@ -43,3 +45,169 @@ func TestEOFWithoutError(t *testing.T) { assert.Check(t, is.Equal(0, n)) assert.Check(t, is.Equal(io.EOF, err)) } + +func TestCloseRunningCommand(t *testing.T) { + cmd := "sh" + args := []string{"-c", "while true; sleep 1; done"} + + done := make(chan struct{}) + defer close(done) + + go func() { + c, err := New(context.TODO(), cmd, args...) + assert.NilError(t, err) + cmdConn := c.(*commandConn) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + n, err := c.Write([]byte("hello")) + assert.Check(t, is.Equal(len("hello"), n)) + assert.NilError(t, err) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + err = cmdConn.Close() + assert.NilError(t, err) + assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid)) + done <- struct{}{} + }() + + select { + case <-time.After(5 * time.Second): + t.Error("test did not finish in time") + case <-done: + break + } +} + +func TestCloseTwice(t *testing.T) { + cmd := "sh" + args := []string{"-c", "echo hello; sleep 1; exit 0"} + + done := make(chan struct{}) + go func() { + c, err := New(context.TODO(), cmd, args...) + assert.NilError(t, err) + cmdConn := c.(*commandConn) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + b := make([]byte, 32) + n, err := c.Read(b) + assert.Check(t, is.Equal(len("hello\n"), n)) + assert.NilError(t, err) + + err = cmdConn.Close() + assert.NilError(t, err) + assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid)) + + err = cmdConn.Close() + assert.NilError(t, err) + assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid)) + done <- struct{}{} + }() + + select { + case <-time.After(10 * time.Second): + t.Error("test did not finish in time") + case <-done: + break + } +} + +func TestEOFTimeout(t *testing.T) { + cmd := "sh" + args := []string{"-c", "sleep 20"} + + done := make(chan struct{}) + go func() { + c, err := New(context.TODO(), cmd, args...) + assert.NilError(t, err) + cmdConn := c.(*commandConn) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + cmdConn.stdout = mockStdoutEOF{} + + b := make([]byte, 32) + n, err := c.Read(b) + assert.Check(t, is.Equal(0, n)) + assert.ErrorContains(t, err, "did not exit after EOF") + + done <- struct{}{} + }() + + // after receiving an EOF, we try to kill the command + // if it doesn't exit after 10s, we throw an error + select { + case <-time.After(12 * time.Second): + t.Error("test did not finish in time") + case <-done: + break + } +} + +type mockStdoutEOF struct{} + +func (mockStdoutEOF) Read(_ []byte) (int, error) { + return 0, io.EOF +} + +func (mockStdoutEOF) Close() error { + return nil +} + +func TestCloseWhileWriting(t *testing.T) { + cmd := "sh" + args := []string{"-c", "while true; sleep 1; done"} + + c, err := New(context.TODO(), cmd, args...) + assert.NilError(t, err) + cmdConn := c.(*commandConn) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + writeErrC := make(chan error) + go func() { + for { + n, err := c.Write([]byte("hello")) + if err != nil { + writeErrC <- err + return + } + assert.Equal(t, n, len("hello")) + } + }() + + err = c.Close() + assert.NilError(t, err) + assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid)) + + writeErr := <-writeErrC + assert.ErrorContains(t, writeErr, "EOF") +} + +func TestCloseWhileReading(t *testing.T) { + cmd := "sh" + args := []string{"-c", "while true; sleep 1; done"} + + c, err := New(context.TODO(), cmd, args...) + assert.NilError(t, err) + cmdConn := c.(*commandConn) + assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid)) + + readErrC := make(chan error) + go func() { + for { + b := make([]byte, 32) + n, err := c.Read(b) + if err != nil { + readErrC <- err + return + } + assert.Check(t, is.Equal(0, n)) + } + }() + + err = cmdConn.Close() + assert.NilError(t, err) + assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid)) + + readErr := <-readErrC + assert.ErrorContains(t, readErr, "EOF") +} diff --git a/cli/connhelper/connhelper.go b/cli/connhelper/connhelper.go index 397149c3e2..b98d97c25d 100644 --- a/cli/connhelper/connhelper.go +++ b/cli/connhelper/connhelper.go @@ -5,6 +5,7 @@ import ( "context" "net" "net/url" + "strings" "github.com/docker/cli/cli/connhelper/commandconn" "github.com/docker/cli/cli/connhelper/ssh" @@ -51,6 +52,7 @@ func getConnectionHelper(daemonURL string, sshFlags []string) (*ConnectionHelper if sp.Path != "" { args = append(args, "--host", "unix://"+sp.Path) } + sshFlags = addSSHTimeout(sshFlags) args = append(args, "system", "dial-stdio") return commandconn.New(ctx, "ssh", append(sshFlags, sp.Args(args...)...)...) }, @@ -71,3 +73,10 @@ func GetCommandConnectionHelper(cmd string, flags ...string) (*ConnectionHelper, Host: "http://docker.example.com", }, nil } + +func addSSHTimeout(sshFlags []string) []string { + if !strings.Contains(strings.Join(sshFlags, ""), "ConnectTimeout") { + sshFlags = append(sshFlags, "-o ConnectTimeout=30") + } + return sshFlags +} diff --git a/cli/connhelper/connhelper_test.go b/cli/connhelper/connhelper_test.go new file mode 100644 index 0000000000..14384f5c85 --- /dev/null +++ b/cli/connhelper/connhelper_test.go @@ -0,0 +1,31 @@ +package connhelper + +import ( + "testing" + + "gotest.tools/v3/assert" +) + +func TestSSHFlags(t *testing.T) { + testCases := []struct { + in []string + out []string + }{ + { + in: []string{}, + out: []string{"-o ConnectTimeout=30"}, + }, + { + in: []string{"option", "-o anotherOption"}, + out: []string{"option", "-o anotherOption", "-o ConnectTimeout=30"}, + }, + { + in: []string{"-o ConnectTimeout=5", "anotherOption"}, + out: []string{"-o ConnectTimeout=5", "anotherOption"}, + }, + } + + for _, tc := range testCases { + assert.DeepEqual(t, addSSHTimeout(tc.in), tc.out) + } +}