commandconn: don't return error if command closed successfully

---
commandconn: fix race on `Close()`

During normal operation, if a `Read()` or `Write()` call results
in an EOF, we call `onEOF()` to handle the terminating command,
and store it's exit value.

However, if a Read/Write call was blocked while `Close()` is called
the in/out pipes are immediately closed which causes an EOF to be
returned. Here, we shouldn't call `onEOF()`, since the reason why
we got an EOF is because we're already terminating the connection.
This also prevents a race between two calls to the commands `Wait()`,
in the `Close()` call and `onEOF()`

---
Add CLI init timeout to SSH connections

---
connhelper: add 30s ssh default dialer timeout

(same as non-ssh dialer)

Signed-off-by: Laura Brehm <laurabrehm@hey.com>
This commit is contained in:
Laura Brehm 2023-04-24 09:57:16 +01:00
parent 20923dfbc7
commit a5ebe2282a
No known key found for this signature in database
GPG Key ID: 526E3FC49260D47A
9 changed files with 454 additions and 118 deletions

View File

@ -8,7 +8,6 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv" "strconv"
"strings"
"sync" "sync"
"time" "time"
@ -327,13 +326,8 @@ func (cli *DockerCli) getInitTimeout() time.Duration {
func (cli *DockerCli) initializeFromClient() { func (cli *DockerCli) initializeFromClient() {
ctx := context.Background() ctx := context.Background()
if !strings.HasPrefix(cli.dockerEndpoint.Host, "ssh://") { ctx, cancel := context.WithTimeout(ctx, cli.getInitTimeout())
// @FIXME context.WithTimeout doesn't work with connhelper / ssh connections defer cancel()
// time="2020-04-10T10:16:26Z" level=warning msg="commandConn.CloseWrite: commandconn: failed to wait: signal: killed"
var cancel func()
ctx, cancel = context.WithTimeout(ctx, cli.getInitTimeout())
defer cancel()
}
ping, err := cli.client.Ping(ctx) ping, err := cli.client.Ping(ctx)
if err != nil { if err != nil {

View File

@ -24,6 +24,7 @@ import (
"runtime" "runtime"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"syscall" "syscall"
"time" "time"
@ -64,81 +65,68 @@ func New(_ context.Context, cmd string, args ...string) (net.Conn, error) {
// commandConn implements net.Conn // commandConn implements net.Conn
type commandConn struct { type commandConn struct {
cmd *exec.Cmd cmdMutex sync.Mutex // for cmd, cmdWaitErr
cmdExited bool cmd *exec.Cmd
cmdWaitErr error cmdWaitErr error
cmdMutex sync.Mutex cmdExited atomic.Bool
stdin io.WriteCloser stdin io.WriteCloser
stdout io.ReadCloser stdout io.ReadCloser
stderrMu sync.Mutex stderrMu sync.Mutex // for stderr
stderr bytes.Buffer stderr bytes.Buffer
stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed stdinClosed atomic.Bool
stdinClosed bool stdoutClosed atomic.Bool
stdoutClosed bool closing atomic.Bool
localAddr net.Addr localAddr net.Addr
remoteAddr net.Addr remoteAddr net.Addr
} }
// killIfStdioClosed kills the cmd if both stdin and stdout are closed. // kill terminates the process. On Windows it kills the process directly,
func (c *commandConn) killIfStdioClosed() error { // whereas on other platforms, a SIGTERM is sent, before forcefully terminating
c.stdioClosedMu.Lock() // the process after 3 seconds.
stdioClosed := c.stdoutClosed && c.stdinClosed func (c *commandConn) kill() {
c.stdioClosedMu.Unlock() if c.cmdExited.Load() {
if !stdioClosed { return
return nil
} }
return c.kill() c.cmdMutex.Lock()
}
// killAndWait tries sending SIGTERM to the process before sending SIGKILL.
func killAndWait(cmd *exec.Cmd) error {
var werr error var werr error
if runtime.GOOS != "windows" { if runtime.GOOS != "windows" {
werrCh := make(chan error) werrCh := make(chan error)
go func() { werrCh <- cmd.Wait() }() go func() { werrCh <- c.cmd.Wait() }()
cmd.Process.Signal(syscall.SIGTERM) _ = c.cmd.Process.Signal(syscall.SIGTERM)
select { select {
case werr = <-werrCh: case werr = <-werrCh:
case <-time.After(3 * time.Second): case <-time.After(3 * time.Second):
cmd.Process.Kill() _ = c.cmd.Process.Kill()
werr = <-werrCh werr = <-werrCh
} }
} else { } else {
cmd.Process.Kill() _ = c.cmd.Process.Kill()
werr = cmd.Wait() werr = c.cmd.Wait()
}
return werr
}
// kill returns nil if the command terminated, regardless to the exit status.
func (c *commandConn) kill() error {
var werr error
c.cmdMutex.Lock()
if c.cmdExited {
werr = c.cmdWaitErr
} else {
werr = killAndWait(c.cmd)
c.cmdWaitErr = werr
c.cmdExited = true
} }
c.cmdWaitErr = werr
c.cmdMutex.Unlock() c.cmdMutex.Unlock()
if werr == nil { c.cmdExited.Store(true)
return nil
}
wExitErr, ok := werr.(*exec.ExitError)
if ok {
if wExitErr.ProcessState.Exited() {
return nil
}
}
return errors.Wrapf(werr, "commandconn: failed to wait")
} }
func (c *commandConn) onEOF(eof error) error { // handleEOF handles io.EOF errors while reading or writing from the underlying
// when we got EOF, the command is going to be terminated // command pipes.
var werr error //
// When we've received an EOF we expect that the command will
// be terminated soon. As such, we call Wait() on the command
// and return EOF or the error depending on whether the command
// exited with an error.
//
// If Wait() does not return within 10s, an error is returned
func (c *commandConn) handleEOF(err error) error {
if err != io.EOF {
return err
}
c.cmdMutex.Lock() c.cmdMutex.Lock()
if c.cmdExited { defer c.cmdMutex.Unlock()
var werr error
if c.cmdExited.Load() {
werr = c.cmdWaitErr werr = c.cmdWaitErr
} else { } else {
werrCh := make(chan error) werrCh := make(chan error)
@ -146,18 +134,17 @@ func (c *commandConn) onEOF(eof error) error {
select { select {
case werr = <-werrCh: case werr = <-werrCh:
c.cmdWaitErr = werr c.cmdWaitErr = werr
c.cmdExited = true c.cmdExited.Store(true)
case <-time.After(10 * time.Second): case <-time.After(10 * time.Second):
c.cmdMutex.Unlock()
c.stderrMu.Lock() c.stderrMu.Lock()
stderr := c.stderr.String() stderr := c.stderr.String()
c.stderrMu.Unlock() c.stderrMu.Unlock()
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr) return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, err, stderr)
} }
} }
c.cmdMutex.Unlock()
if werr == nil { if werr == nil {
return eof return err
} }
c.stderrMu.Lock() c.stderrMu.Lock()
stderr := c.stderr.String() stderr := c.stderr.String()
@ -166,71 +153,88 @@ func (c *commandConn) onEOF(eof error) error {
} }
func ignorableCloseError(err error) bool { func ignorableCloseError(err error) bool {
errS := err.Error() return strings.Contains(err.Error(), os.ErrClosed.Error())
ss := []string{
os.ErrClosed.Error(),
}
for _, s := range ss {
if strings.Contains(errS, s) {
return true
}
}
return false
}
func (c *commandConn) CloseRead() error {
// NOTE: maybe already closed here
if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
logrus.Warnf("commandConn.CloseRead: %v", err)
}
c.stdioClosedMu.Lock()
c.stdoutClosed = true
c.stdioClosedMu.Unlock()
if err := c.killIfStdioClosed(); err != nil {
logrus.Warnf("commandConn.CloseRead: %v", err)
}
return nil
} }
func (c *commandConn) Read(p []byte) (int, error) { func (c *commandConn) Read(p []byte) (int, error) {
n, err := c.stdout.Read(p) n, err := c.stdout.Read(p)
if err == io.EOF { // check after the call to Read, since
err = c.onEOF(err) // it is blocking, and while waiting on it
// Close might get called
if c.closing.Load() {
// If we're currently closing the connection
// we don't want to call onEOF, but we do want
// to return an io.EOF
return 0, io.EOF
} }
return n, err
}
func (c *commandConn) CloseWrite() error { return n, c.handleEOF(err)
// NOTE: maybe already closed here
if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
logrus.Warnf("commandConn.CloseWrite: %v", err)
}
c.stdioClosedMu.Lock()
c.stdinClosed = true
c.stdioClosedMu.Unlock()
if err := c.killIfStdioClosed(); err != nil {
logrus.Warnf("commandConn.CloseWrite: %v", err)
}
return nil
} }
func (c *commandConn) Write(p []byte) (int, error) { func (c *commandConn) Write(p []byte) (int, error) {
n, err := c.stdin.Write(p) n, err := c.stdin.Write(p)
if err == io.EOF { // check after the call to Write, since
err = c.onEOF(err) // it is blocking, and while waiting on it
// Close might get called
if c.closing.Load() {
// If we're currently closing the connection
// we don't want to call onEOF, but we do want
// to return an io.EOF
return 0, io.EOF
} }
return n, err
return n, c.handleEOF(err)
} }
// CloseRead allows commandConn to implement halfCloser
func (c *commandConn) CloseRead() error {
// NOTE: maybe already closed here
if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
return err
}
c.stdoutClosed.Store(true)
if c.stdinClosed.Load() {
c.kill()
}
return nil
}
// CloseWrite allows commandConn to implement halfCloser
func (c *commandConn) CloseWrite() error {
// NOTE: maybe already closed here
if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
return err
}
c.stdinClosed.Store(true)
if c.stdoutClosed.Load() {
c.kill()
}
return nil
}
// Close is the net.Conn func that gets called
// by the transport when a dial is cancelled
// due to it's context timing out. Any blocked
// Read or Write calls will be unblocked and
// return errors. It will block until the underlying
// command has terminated.
func (c *commandConn) Close() error { func (c *commandConn) Close() error {
var err error c.closing.Store(true)
if err = c.CloseRead(); err != nil { defer c.closing.Store(false)
if err := c.CloseRead(); err != nil {
logrus.Warnf("commandConn.Close: CloseRead: %v", err) logrus.Warnf("commandConn.Close: CloseRead: %v", err)
return err
} }
if err = c.CloseWrite(); err != nil { if err := c.CloseWrite(); err != nil {
logrus.Warnf("commandConn.Close: CloseWrite: %v", err) logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
return err
} }
return err
return nil
} }
func (c *commandConn) LocalAddr() net.Addr { func (c *commandConn) LocalAddr() net.Addr {

View File

@ -6,7 +6,9 @@ import (
"context" "context"
"io" "io"
"testing" "testing"
"time"
"github.com/docker/docker/pkg/process"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
is "gotest.tools/v3/assert/cmp" is "gotest.tools/v3/assert/cmp"
) )
@ -42,3 +44,169 @@ func TestEOFWithoutError(t *testing.T) {
assert.Check(t, is.Equal(0, n)) assert.Check(t, is.Equal(0, n))
assert.Check(t, is.Equal(io.EOF, err)) assert.Check(t, is.Equal(io.EOF, err))
} }
func TestCloseRunningCommand(t *testing.T) {
cmd := "sh"
args := []string{"-c", "while true; sleep 1; done"}
done := make(chan struct{})
defer close(done)
go func() {
c, err := New(context.TODO(), cmd, args...)
assert.NilError(t, err)
cmdConn := c.(*commandConn)
assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid))
n, err := c.Write([]byte("hello"))
assert.Check(t, is.Equal(len("hello"), n))
assert.NilError(t, err)
assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid))
err = cmdConn.Close()
assert.NilError(t, err)
assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid))
done <- struct{}{}
}()
select {
case <-time.After(5 * time.Second):
t.Error("test did not finish in time")
case <-done:
break
}
}
func TestCloseTwice(t *testing.T) {
cmd := "sh"
args := []string{"-c", "echo hello; sleep 1; exit 0"}
done := make(chan struct{})
go func() {
c, err := New(context.TODO(), cmd, args...)
assert.NilError(t, err)
cmdConn := c.(*commandConn)
assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid))
b := make([]byte, 32)
n, err := c.Read(b)
assert.Check(t, is.Equal(len("hello\n"), n))
assert.NilError(t, err)
err = cmdConn.Close()
assert.NilError(t, err)
assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid))
err = cmdConn.Close()
assert.NilError(t, err)
assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid))
done <- struct{}{}
}()
select {
case <-time.After(10 * time.Second):
t.Error("test did not finish in time")
case <-done:
break
}
}
func TestEOFTimeout(t *testing.T) {
cmd := "sh"
args := []string{"-c", "sleep 20"}
done := make(chan struct{})
go func() {
c, err := New(context.TODO(), cmd, args...)
assert.NilError(t, err)
cmdConn := c.(*commandConn)
assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid))
cmdConn.stdout = mockStdoutEOF{}
b := make([]byte, 32)
n, err := c.Read(b)
assert.Check(t, is.Equal(0, n))
assert.ErrorContains(t, err, "did not exit after EOF")
done <- struct{}{}
}()
// after receiving an EOF, we try to kill the command
// if it doesn't exit after 10s, we throw an error
select {
case <-time.After(12 * time.Second):
t.Error("test did not finish in time")
case <-done:
break
}
}
type mockStdoutEOF struct{}
func (mockStdoutEOF) Read(_ []byte) (int, error) {
return 0, io.EOF
}
func (mockStdoutEOF) Close() error {
return nil
}
func TestCloseWhileWriting(t *testing.T) {
cmd := "sh"
args := []string{"-c", "while true; sleep 1; done"}
c, err := New(context.TODO(), cmd, args...)
assert.NilError(t, err)
cmdConn := c.(*commandConn)
assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid))
writeErrC := make(chan error)
go func() {
for {
n, err := c.Write([]byte("hello"))
if err != nil {
writeErrC <- err
return
}
assert.Equal(t, n, len("hello"))
}
}()
err = c.Close()
assert.NilError(t, err)
assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid))
writeErr := <-writeErrC
assert.ErrorContains(t, writeErr, "EOF")
}
func TestCloseWhileReading(t *testing.T) {
cmd := "sh"
args := []string{"-c", "while true; sleep 1; done"}
c, err := New(context.TODO(), cmd, args...)
assert.NilError(t, err)
cmdConn := c.(*commandConn)
assert.Check(t, process.Alive(cmdConn.cmd.Process.Pid))
readErrC := make(chan error)
go func() {
for {
b := make([]byte, 32)
n, err := c.Read(b)
if err != nil {
readErrC <- err
return
}
assert.Check(t, is.Equal(0, n))
}
}()
err = cmdConn.Close()
assert.NilError(t, err)
assert.Check(t, !process.Alive(cmdConn.cmd.Process.Pid))
readErr := <-readErrC
assert.ErrorContains(t, readErr, "EOF")
}

View File

@ -5,6 +5,7 @@ import (
"context" "context"
"net" "net"
"net/url" "net/url"
"strings"
"github.com/docker/cli/cli/connhelper/commandconn" "github.com/docker/cli/cli/connhelper/commandconn"
"github.com/docker/cli/cli/connhelper/ssh" "github.com/docker/cli/cli/connhelper/ssh"
@ -51,6 +52,7 @@ func getConnectionHelper(daemonURL string, sshFlags []string) (*ConnectionHelper
if sp.Path != "" { if sp.Path != "" {
args = append(args, "--host", "unix://"+sp.Path) args = append(args, "--host", "unix://"+sp.Path)
} }
sshFlags = addSSHTimeout(sshFlags)
args = append(args, "system", "dial-stdio") args = append(args, "system", "dial-stdio")
return commandconn.New(ctx, "ssh", append(sshFlags, sp.Args(args...)...)...) return commandconn.New(ctx, "ssh", append(sshFlags, sp.Args(args...)...)...)
}, },
@ -71,3 +73,10 @@ func GetCommandConnectionHelper(cmd string, flags ...string) (*ConnectionHelper,
Host: "http://docker.example.com", Host: "http://docker.example.com",
}, nil }, nil
} }
func addSSHTimeout(sshFlags []string) []string {
if !strings.Contains(strings.Join(sshFlags, ""), "ConnectTimeout") {
sshFlags = append(sshFlags, "-o ConnectTimeout=30")
}
return sshFlags
}

View File

@ -0,0 +1,31 @@
package connhelper
import (
"testing"
"gotest.tools/v3/assert"
)
func TestSSHFlags(t *testing.T) {
testCases := []struct {
in []string
out []string
}{
{
in: []string{},
out: []string{"-o ConnectTimeout=30"},
},
{
in: []string{"option", "-o anotherOption"},
out: []string{"option", "-o anotherOption", "-o ConnectTimeout=30"},
},
{
in: []string{"-o ConnectTimeout=5", "anotherOption"},
out: []string{"-o ConnectTimeout=5", "anotherOption"},
},
}
for _, tc := range testCases {
assert.DeepEqual(t, addSSHTimeout(tc.in), tc.out)
}
}

3
vendor/github.com/docker/docker/pkg/process/doc.go generated vendored Normal file
View File

@ -0,0 +1,3 @@
// Package process provides a set of basic functions to manage individual
// processes.
package process

View File

@ -0,0 +1,81 @@
//go:build !windows
package process
import (
"bytes"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
"golang.org/x/sys/unix"
)
// Alive returns true if process with a given pid is running. It only considers
// positive PIDs; 0 (all processes in the current process group), -1 (all processes
// with a PID larger than 1), and negative (-n, all processes in process group
// "n") values for pid are never considered to be alive.
func Alive(pid int) bool {
if pid < 1 {
return false
}
switch runtime.GOOS {
case "darwin":
// OS X does not have a proc filesystem. Use kill -0 pid to judge if the
// process exists. From KILL(2): https://www.freebsd.org/cgi/man.cgi?query=kill&sektion=2&manpath=OpenDarwin+7.2.1
//
// Sig may be one of the signals specified in sigaction(2) or it may
// be 0, in which case error checking is performed but no signal is
// actually sent. This can be used to check the validity of pid.
err := unix.Kill(pid, 0)
// Either the PID was found (no error) or we get an EPERM, which means
// the PID exists, but we don't have permissions to signal it.
return err == nil || err == unix.EPERM
default:
_, err := os.Stat(filepath.Join("/proc", strconv.Itoa(pid)))
return err == nil
}
}
// Kill force-stops a process. It only considers positive PIDs; 0 (all processes
// in the current process group), -1 (all processes with a PID larger than 1),
// and negative (-n, all processes in process group "n") values for pid are
// ignored. Refer to [KILL(2)] for details.
//
// [KILL(2)]: https://man7.org/linux/man-pages/man2/kill.2.html
func Kill(pid int) error {
if pid < 1 {
return fmt.Errorf("invalid PID (%d): only positive PIDs are allowed", pid)
}
err := unix.Kill(pid, unix.SIGKILL)
if err != nil && err != unix.ESRCH {
return err
}
return nil
}
// Zombie return true if process has a state with "Z". It only considers positive
// PIDs; 0 (all processes in the current process group), -1 (all processes with
// a PID larger than 1), and negative (-n, all processes in process group "n")
// values for pid are ignored. Refer to [PROC(5)] for details.
//
// [PROC(5)]: https://man7.org/linux/man-pages/man5/proc.5.html
func Zombie(pid int) (bool, error) {
if pid < 1 {
return false, nil
}
data, err := os.ReadFile(fmt.Sprintf("/proc/%d/stat", pid))
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, err
}
if cols := bytes.SplitN(data, []byte(" "), 4); len(cols) >= 3 && string(cols[2]) == "Z" {
return true, nil
}
return false, nil
}

View File

@ -0,0 +1,45 @@
package process
import (
"os"
"golang.org/x/sys/windows"
)
// Alive returns true if process with a given pid is running.
func Alive(pid int) bool {
h, err := windows.OpenProcess(windows.PROCESS_QUERY_LIMITED_INFORMATION, false, uint32(pid))
if err != nil {
return false
}
var c uint32
err = windows.GetExitCodeProcess(h, &c)
_ = windows.CloseHandle(h)
if err != nil {
// From the GetExitCodeProcess function (processthreadsapi.h) API docs:
// https://learn.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-getexitcodeprocess
//
// The GetExitCodeProcess function returns a valid error code defined by the
// application only after the thread terminates. Therefore, an application should
// not use STILL_ACTIVE (259) as an error code (STILL_ACTIVE is a macro for
// STATUS_PENDING (minwinbase.h)). If a thread returns STILL_ACTIVE (259) as
// an error code, then applications that test for that value could interpret it
// to mean that the thread is still running, and continue to test for the
// completion of the thread after the thread has terminated, which could put
// the application into an infinite loop.
return c == uint32(windows.STATUS_PENDING)
}
return true
}
// Kill force-stops a process.
func Kill(pid int) error {
p, err := os.FindProcess(pid)
if err == nil {
err = p.Kill()
if err != nil && err != os.ErrProcessDone {
return err
}
}
return nil
}

1
vendor/modules.txt vendored
View File

@ -69,6 +69,7 @@ github.com/docker/docker/pkg/ioutils
github.com/docker/docker/pkg/jsonmessage github.com/docker/docker/pkg/jsonmessage
github.com/docker/docker/pkg/longpath github.com/docker/docker/pkg/longpath
github.com/docker/docker/pkg/pools github.com/docker/docker/pkg/pools
github.com/docker/docker/pkg/process
github.com/docker/docker/pkg/progress github.com/docker/docker/pkg/progress
github.com/docker/docker/pkg/stdcopy github.com/docker/docker/pkg/stdcopy
github.com/docker/docker/pkg/streamformatter github.com/docker/docker/pkg/streamformatter