diff --git a/cli/connhelper/commandconn/commandconn.go b/cli/connhelper/commandconn/commandconn.go new file mode 100644 index 0000000000..7e03741fad --- /dev/null +++ b/cli/connhelper/commandconn/commandconn.go @@ -0,0 +1,281 @@ +// Package commandconn provides a net.Conn implementation that can be used for +// proxying (or emulating) stream via a custom command. +// +// For example, to provide an http.Client that can connect to a Docker daemon +// running in a Docker container ("DIND"): +// +// httpClient := &http.Client{ +// Transport: &http.Transport{ +// DialContext: func(ctx context.Context, _network, _addr string) (net.Conn, error) { +// return commandconn.New(ctx, "docker", "exec", "-it", containerID, "docker", "system", "dial-stdio") +// }, +// }, +// } +package commandconn + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "os" + "os/exec" + "runtime" + "strings" + "sync" + "syscall" + "time" + + "github.com/pkg/errors" + "github.com/sirupsen/logrus" +) + +// New returns net.Conn +func New(ctx context.Context, cmd string, args ...string) (net.Conn, error) { + var ( + c commandConn + err error + ) + c.cmd = exec.CommandContext(ctx, cmd, args...) + // we assume that args never contains sensitive information + logrus.Debugf("commandconn: starting %s with %v", cmd, args) + c.cmd.Env = os.Environ() + setPdeathsig(c.cmd) + c.stdin, err = c.cmd.StdinPipe() + if err != nil { + return nil, err + } + c.stdout, err = c.cmd.StdoutPipe() + if err != nil { + return nil, err + } + c.cmd.Stderr = &stderrWriter{ + stderrMu: &c.stderrMu, + stderr: &c.stderr, + debugPrefix: fmt.Sprintf("commandconn (%s):", cmd), + } + c.localAddr = dummyAddr{network: "dummy", s: "dummy-0"} + c.remoteAddr = dummyAddr{network: "dummy", s: "dummy-1"} + return &c, c.cmd.Start() +} + +// commandConn implements net.Conn +type commandConn struct { + cmd *exec.Cmd + cmdExited bool + cmdWaitErr error + cmdMutex sync.Mutex + stdin io.WriteCloser + stdout io.ReadCloser + stderrMu sync.Mutex + stderr bytes.Buffer + stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed + stdinClosed bool + stdoutClosed bool + localAddr net.Addr + remoteAddr net.Addr +} + +// killIfStdioClosed kills the cmd if both stdin and stdout are closed. +func (c *commandConn) killIfStdioClosed() error { + c.stdioClosedMu.Lock() + stdioClosed := c.stdoutClosed && c.stdinClosed + c.stdioClosedMu.Unlock() + if !stdioClosed { + return nil + } + return c.kill() +} + +// killAndWait tries sending SIGTERM to the process before sending SIGKILL. +func killAndWait(cmd *exec.Cmd) error { + var werr error + if runtime.GOOS != "windows" { + werrCh := make(chan error) + go func() { werrCh <- cmd.Wait() }() + cmd.Process.Signal(syscall.SIGTERM) + select { + case werr = <-werrCh: + case <-time.After(3 * time.Second): + cmd.Process.Kill() + werr = <-werrCh + } + } else { + cmd.Process.Kill() + werr = cmd.Wait() + } + return werr +} + +// kill returns nil if the command terminated, regardless to the exit status. +func (c *commandConn) kill() error { + var werr error + c.cmdMutex.Lock() + if c.cmdExited { + werr = c.cmdWaitErr + } else { + werr = killAndWait(c.cmd) + c.cmdWaitErr = werr + c.cmdExited = true + } + c.cmdMutex.Unlock() + if werr == nil { + return nil + } + wExitErr, ok := werr.(*exec.ExitError) + if ok { + if wExitErr.ProcessState.Exited() { + return nil + } + } + return errors.Wrapf(werr, "commandconn: failed to wait") +} + +func (c *commandConn) onEOF(eof error) error { + // when we got EOF, the command is going to be terminated + var werr error + c.cmdMutex.Lock() + if c.cmdExited { + werr = c.cmdWaitErr + } else { + werrCh := make(chan error) + go func() { werrCh <- c.cmd.Wait() }() + select { + case werr = <-werrCh: + c.cmdWaitErr = werr + c.cmdExited = true + case <-time.After(10 * time.Second): + c.cmdMutex.Unlock() + c.stderrMu.Lock() + stderr := c.stderr.String() + c.stderrMu.Unlock() + return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr) + } + } + c.cmdMutex.Unlock() + if werr == nil { + return eof + } + c.stderrMu.Lock() + stderr := c.stderr.String() + c.stderrMu.Unlock() + return errors.Errorf("command %v has exited with %v, please make sure the URL is valid, and Docker 18.09 or later is installed on the remote host: stderr=%s", c.cmd.Args, werr, stderr) +} + +func ignorableCloseError(err error) bool { + errS := err.Error() + ss := []string{ + os.ErrClosed.Error(), + } + for _, s := range ss { + if strings.Contains(errS, s) { + return true + } + } + return false +} + +func (c *commandConn) CloseRead() error { + // NOTE: maybe already closed here + if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) { + logrus.Warnf("commandConn.CloseRead: %v", err) + } + c.stdioClosedMu.Lock() + c.stdoutClosed = true + c.stdioClosedMu.Unlock() + if err := c.killIfStdioClosed(); err != nil { + logrus.Warnf("commandConn.CloseRead: %v", err) + } + return nil +} + +func (c *commandConn) Read(p []byte) (int, error) { + n, err := c.stdout.Read(p) + if err == io.EOF { + err = c.onEOF(err) + } + return n, err +} + +func (c *commandConn) CloseWrite() error { + // NOTE: maybe already closed here + if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) { + logrus.Warnf("commandConn.CloseWrite: %v", err) + } + c.stdioClosedMu.Lock() + c.stdinClosed = true + c.stdioClosedMu.Unlock() + if err := c.killIfStdioClosed(); err != nil { + logrus.Warnf("commandConn.CloseWrite: %v", err) + } + return nil +} + +func (c *commandConn) Write(p []byte) (int, error) { + n, err := c.stdin.Write(p) + if err == io.EOF { + err = c.onEOF(err) + } + return n, err +} + +func (c *commandConn) Close() error { + var err error + if err = c.CloseRead(); err != nil { + logrus.Warnf("commandConn.Close: CloseRead: %v", err) + } + if err = c.CloseWrite(); err != nil { + logrus.Warnf("commandConn.Close: CloseWrite: %v", err) + } + return err +} + +func (c *commandConn) LocalAddr() net.Addr { + return c.localAddr +} +func (c *commandConn) RemoteAddr() net.Addr { + return c.remoteAddr +} +func (c *commandConn) SetDeadline(t time.Time) error { + logrus.Debugf("unimplemented call: SetDeadline(%v)", t) + return nil +} +func (c *commandConn) SetReadDeadline(t time.Time) error { + logrus.Debugf("unimplemented call: SetReadDeadline(%v)", t) + return nil +} +func (c *commandConn) SetWriteDeadline(t time.Time) error { + logrus.Debugf("unimplemented call: SetWriteDeadline(%v)", t) + return nil +} + +type dummyAddr struct { + network string + s string +} + +func (d dummyAddr) Network() string { + return d.network +} + +func (d dummyAddr) String() string { + return d.s +} + +type stderrWriter struct { + stderrMu *sync.Mutex + stderr *bytes.Buffer + debugPrefix string +} + +func (w *stderrWriter) Write(p []byte) (int, error) { + logrus.Debugf("%s%s", w.debugPrefix, string(p)) + w.stderrMu.Lock() + if w.stderr.Len() > 4096 { + w.stderr.Reset() + } + n, err := w.stderr.Write(p) + w.stderrMu.Unlock() + return n, err +} diff --git a/cli/connhelper/connhelper_linux.go b/cli/connhelper/commandconn/commandconn_linux.go similarity index 87% rename from cli/connhelper/connhelper_linux.go rename to cli/connhelper/commandconn/commandconn_linux.go index f138f53675..7d8b122e32 100644 --- a/cli/connhelper/connhelper_linux.go +++ b/cli/connhelper/commandconn/commandconn_linux.go @@ -1,4 +1,4 @@ -package connhelper +package commandconn import ( "os/exec" diff --git a/cli/connhelper/connhelper_nolinux.go b/cli/connhelper/commandconn/commandconn_nolinux.go similarity index 79% rename from cli/connhelper/connhelper_nolinux.go rename to cli/connhelper/commandconn/commandconn_nolinux.go index c8350d9d77..ab07166724 100644 --- a/cli/connhelper/connhelper_nolinux.go +++ b/cli/connhelper/commandconn/commandconn_nolinux.go @@ -1,6 +1,6 @@ // +build !linux -package connhelper +package commandconn import ( "os/exec" diff --git a/cli/connhelper/connhelper_unix_test.go b/cli/connhelper/commandconn/commandconn_unix_test.go similarity index 81% rename from cli/connhelper/connhelper_unix_test.go rename to cli/connhelper/commandconn/commandconn_unix_test.go index c57d655e16..764c647359 100644 --- a/cli/connhelper/connhelper_unix_test.go +++ b/cli/connhelper/commandconn/commandconn_unix_test.go @@ -1,6 +1,6 @@ // +build !windows -package connhelper +package commandconn import ( "context" @@ -12,11 +12,11 @@ import ( ) // For https://github.com/docker/cli/pull/1014#issuecomment-409308139 -func TestCommandConnEOFWithError(t *testing.T) { +func TestEOFWithError(t *testing.T) { ctx := context.TODO() cmd := "sh" args := []string{"-c", "echo hello; echo some error >&2; exit 42"} - c, err := newCommandConn(ctx, cmd, args...) + c, err := New(ctx, cmd, args...) assert.NilError(t, err) b := make([]byte, 32) n, err := c.Read(b) @@ -28,11 +28,11 @@ func TestCommandConnEOFWithError(t *testing.T) { assert.ErrorContains(t, err, "42") } -func TestCommandConnEOFWithoutError(t *testing.T) { +func TestEOFWithoutError(t *testing.T) { ctx := context.TODO() cmd := "sh" args := []string{"-c", "echo hello; echo some debug log >&2; exit 0"} - c, err := newCommandConn(ctx, cmd, args...) + c, err := New(ctx, cmd, args...) assert.NilError(t, err) b := make([]byte, 32) n, err := c.Read(b) diff --git a/cli/connhelper/connhelper.go b/cli/connhelper/connhelper.go index 94cc1d515b..da3640db1a 100644 --- a/cli/connhelper/connhelper.go +++ b/cli/connhelper/connhelper.go @@ -2,23 +2,13 @@ package connhelper import ( - "bytes" "context" - "fmt" - "io" "net" "net/url" - "os" - "os/exec" - "runtime" - "strings" - "sync" - "syscall" - "time" + "github.com/docker/cli/cli/connhelper/commandconn" "github.com/docker/cli/cli/connhelper/ssh" "github.com/pkg/errors" - "github.com/sirupsen/logrus" ) // ConnectionHelper allows to connect to a remote host with custom stream provider binary. @@ -29,7 +19,8 @@ type ConnectionHelper struct { // GetConnectionHelper returns Docker-specific connection helper for the given URL. // GetConnectionHelper returns nil without error when no helper is registered for the scheme. -// URL is like "ssh://me@server01". +// +// ssh://@ URL requires Docker 18.09 or later on the remote host. func GetConnectionHelper(daemonURL string) (*ConnectionHelper, error) { u, err := url.Parse(daemonURL) if err != nil { @@ -37,13 +28,13 @@ func GetConnectionHelper(daemonURL string) (*ConnectionHelper, error) { } switch scheme := u.Scheme; scheme { case "ssh": - sshCmd, sshArgs, err := ssh.New(daemonURL) + sp, err := ssh.ParseURL(daemonURL) if err != nil { - return nil, err + return nil, errors.Wrap(err, "ssh host connection is not valid") } return &ConnectionHelper{ Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { - return newCommandConn(ctx, sshCmd, sshArgs...) + return commandconn.New(ctx, "ssh", append(sp.Args(), []string{"--", "docker", "system", "dial-stdio"}...)...) }, Host: "http://docker", }, nil @@ -53,260 +44,12 @@ func GetConnectionHelper(daemonURL string) (*ConnectionHelper, error) { return nil, err } -// GetCommandConnectionHelper returns a ConnectionHelp constructed from an arbitrary command. +// GetCommandConnectionHelper returns Docker-specific connection helper constructed from an arbitrary command. func GetCommandConnectionHelper(cmd string, flags ...string) (*ConnectionHelper, error) { return &ConnectionHelper{ Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { - return newCommandConn(ctx, cmd, flags...) + return commandconn.New(ctx, cmd, flags...) }, Host: "http://docker", }, nil } - -func newCommandConn(ctx context.Context, cmd string, args ...string) (net.Conn, error) { - var ( - c commandConn - err error - ) - c.cmd = exec.CommandContext(ctx, cmd, args...) - // we assume that args never contains sensitive information - logrus.Debugf("connhelper: starting %s with %v", cmd, args) - c.cmd.Env = os.Environ() - setPdeathsig(c.cmd) - c.stdin, err = c.cmd.StdinPipe() - if err != nil { - return nil, err - } - c.stdout, err = c.cmd.StdoutPipe() - if err != nil { - return nil, err - } - c.cmd.Stderr = &stderrWriter{ - stderrMu: &c.stderrMu, - stderr: &c.stderr, - debugPrefix: fmt.Sprintf("connhelper (%s):", cmd), - } - c.localAddr = dummyAddr{network: "dummy", s: "dummy-0"} - c.remoteAddr = dummyAddr{network: "dummy", s: "dummy-1"} - return &c, c.cmd.Start() -} - -// commandConn implements net.Conn -type commandConn struct { - cmd *exec.Cmd - cmdExited bool - cmdWaitErr error - cmdMutex sync.Mutex - stdin io.WriteCloser - stdout io.ReadCloser - stderrMu sync.Mutex - stderr bytes.Buffer - stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed - stdinClosed bool - stdoutClosed bool - localAddr net.Addr - remoteAddr net.Addr -} - -// killIfStdioClosed kills the cmd if both stdin and stdout are closed. -func (c *commandConn) killIfStdioClosed() error { - c.stdioClosedMu.Lock() - stdioClosed := c.stdoutClosed && c.stdinClosed - c.stdioClosedMu.Unlock() - if !stdioClosed { - return nil - } - return c.kill() -} - -// killAndWait tries sending SIGTERM to the process before sending SIGKILL. -func killAndWait(cmd *exec.Cmd) error { - var werr error - if runtime.GOOS != "windows" { - werrCh := make(chan error) - go func() { werrCh <- cmd.Wait() }() - cmd.Process.Signal(syscall.SIGTERM) - select { - case werr = <-werrCh: - case <-time.After(3 * time.Second): - cmd.Process.Kill() - werr = <-werrCh - } - } else { - cmd.Process.Kill() - werr = cmd.Wait() - } - return werr -} - -// kill returns nil if the command terminated, regardless to the exit status. -func (c *commandConn) kill() error { - var werr error - c.cmdMutex.Lock() - if c.cmdExited { - werr = c.cmdWaitErr - } else { - werr = killAndWait(c.cmd) - c.cmdWaitErr = werr - c.cmdExited = true - } - c.cmdMutex.Unlock() - if werr == nil { - return nil - } - wExitErr, ok := werr.(*exec.ExitError) - if ok { - if wExitErr.ProcessState.Exited() { - return nil - } - } - return errors.Wrapf(werr, "connhelper: failed to wait") -} - -func (c *commandConn) onEOF(eof error) error { - // when we got EOF, the command is going to be terminated - var werr error - c.cmdMutex.Lock() - if c.cmdExited { - werr = c.cmdWaitErr - } else { - werrCh := make(chan error) - go func() { werrCh <- c.cmd.Wait() }() - select { - case werr = <-werrCh: - c.cmdWaitErr = werr - c.cmdExited = true - case <-time.After(10 * time.Second): - c.cmdMutex.Unlock() - c.stderrMu.Lock() - stderr := c.stderr.String() - c.stderrMu.Unlock() - return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr) - } - } - c.cmdMutex.Unlock() - if werr == nil { - return eof - } - c.stderrMu.Lock() - stderr := c.stderr.String() - c.stderrMu.Unlock() - return errors.Errorf("command %v has exited with %v, please make sure the URL is valid, and Docker 18.09 or later is installed on the remote host: stderr=%s", c.cmd.Args, werr, stderr) -} - -func ignorableCloseError(err error) bool { - errS := err.Error() - ss := []string{ - os.ErrClosed.Error(), - } - for _, s := range ss { - if strings.Contains(errS, s) { - return true - } - } - return false -} - -func (c *commandConn) CloseRead() error { - // NOTE: maybe already closed here - if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) { - logrus.Warnf("commandConn.CloseRead: %v", err) - } - c.stdioClosedMu.Lock() - c.stdoutClosed = true - c.stdioClosedMu.Unlock() - if err := c.killIfStdioClosed(); err != nil { - logrus.Warnf("commandConn.CloseRead: %v", err) - } - return nil -} - -func (c *commandConn) Read(p []byte) (int, error) { - n, err := c.stdout.Read(p) - if err == io.EOF { - err = c.onEOF(err) - } - return n, err -} - -func (c *commandConn) CloseWrite() error { - // NOTE: maybe already closed here - if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) { - logrus.Warnf("commandConn.CloseWrite: %v", err) - } - c.stdioClosedMu.Lock() - c.stdinClosed = true - c.stdioClosedMu.Unlock() - if err := c.killIfStdioClosed(); err != nil { - logrus.Warnf("commandConn.CloseWrite: %v", err) - } - return nil -} - -func (c *commandConn) Write(p []byte) (int, error) { - n, err := c.stdin.Write(p) - if err == io.EOF { - err = c.onEOF(err) - } - return n, err -} - -func (c *commandConn) Close() error { - var err error - if err = c.CloseRead(); err != nil { - logrus.Warnf("commandConn.Close: CloseRead: %v", err) - } - if err = c.CloseWrite(); err != nil { - logrus.Warnf("commandConn.Close: CloseWrite: %v", err) - } - return err -} - -func (c *commandConn) LocalAddr() net.Addr { - return c.localAddr -} -func (c *commandConn) RemoteAddr() net.Addr { - return c.remoteAddr -} -func (c *commandConn) SetDeadline(t time.Time) error { - logrus.Debugf("unimplemented call: SetDeadline(%v)", t) - return nil -} -func (c *commandConn) SetReadDeadline(t time.Time) error { - logrus.Debugf("unimplemented call: SetReadDeadline(%v)", t) - return nil -} -func (c *commandConn) SetWriteDeadline(t time.Time) error { - logrus.Debugf("unimplemented call: SetWriteDeadline(%v)", t) - return nil -} - -type dummyAddr struct { - network string - s string -} - -func (d dummyAddr) Network() string { - return d.network -} - -func (d dummyAddr) String() string { - return d.s -} - -type stderrWriter struct { - stderrMu *sync.Mutex - stderr *bytes.Buffer - debugPrefix string -} - -func (w *stderrWriter) Write(p []byte) (int, error) { - logrus.Debugf("%s%s", w.debugPrefix, string(p)) - w.stderrMu.Lock() - if w.stderr.Len() > 4096 { - w.stderr.Reset() - } - n, err := w.stderr.Write(p) - w.stderrMu.Unlock() - return n, err -} diff --git a/cli/connhelper/ssh/ssh.go b/cli/connhelper/ssh/ssh.go index f134df1386..06cb983641 100644 --- a/cli/connhelper/ssh/ssh.go +++ b/cli/connhelper/ssh/ssh.go @@ -1,5 +1,4 @@ // Package ssh provides the connection helper for ssh:// URL. -// Requires Docker 18.09 or later on the remote host. package ssh import ( @@ -8,16 +7,8 @@ import ( "github.com/pkg/errors" ) -// New returns cmd and its args -func New(daemonURL string) (string, []string, error) { - sp, err := parseSSHURL(daemonURL) - if err != nil { - return "", nil, errors.Wrap(err, "SSH host connection is not valid") - } - return "ssh", append(sp.Args(), []string{"--", "docker", "system", "dial-stdio"}...), nil -} - -func parseSSHURL(daemonURL string) (*sshSpec, error) { +// ParseURL parses URL +func ParseURL(daemonURL string) (*Spec, error) { u, err := url.Parse(daemonURL) if err != nil { return nil, err @@ -26,19 +17,19 @@ func parseSSHURL(daemonURL string) (*sshSpec, error) { return nil, errors.Errorf("expected scheme ssh, got %q", u.Scheme) } - var sp sshSpec + var sp Spec if u.User != nil { - sp.user = u.User.Username() + sp.User = u.User.Username() if _, ok := u.User.Password(); ok { return nil, errors.New("plain-text password is not supported") } } - sp.host = u.Hostname() - if sp.host == "" { + sp.Host = u.Hostname() + if sp.Host == "" { return nil, errors.Errorf("no host specified") } - sp.port = u.Port() + sp.Port = u.Port() if u.Path != "" { return nil, errors.Errorf("extra path after the host: %q", u.Path) } @@ -51,20 +42,22 @@ func parseSSHURL(daemonURL string) (*sshSpec, error) { return &sp, err } -type sshSpec struct { - user string - host string - port string +// Spec of SSH URL +type Spec struct { + User string + Host string + Port string } -func (sp *sshSpec) Args() []string { +// Args returns args except "ssh" itself and "-- ..." +func (sp *Spec) Args() []string { var args []string - if sp.user != "" { - args = append(args, "-l", sp.user) + if sp.User != "" { + args = append(args, "-l", sp.User) } - if sp.port != "" { - args = append(args, "-p", sp.port) + if sp.Port != "" { + args = append(args, "-p", sp.Port) } - args = append(args, sp.host) + args = append(args, sp.Host) return args } diff --git a/cli/connhelper/ssh/ssh_test.go b/cli/connhelper/ssh/ssh_test.go index 8e1d33d895..60478fc0af 100644 --- a/cli/connhelper/ssh/ssh_test.go +++ b/cli/connhelper/ssh/ssh_test.go @@ -7,7 +7,7 @@ import ( is "gotest.tools/assert/cmp" ) -func TestParseSSHURL(t *testing.T) { +func TestParseURL(t *testing.T) { testCases := []struct { url string expectedArgs []string @@ -53,7 +53,7 @@ func TestParseSSHURL(t *testing.T) { }, } for _, tc := range testCases { - sp, err := parseSSHURL(tc.url) + sp, err := ParseURL(tc.url) if tc.expectedError == "" { assert.NilError(t, err) assert.Check(t, is.DeepEqual(tc.expectedArgs, sp.Args())) diff --git a/e2e/cli-plugins/dial_test.go b/e2e/cli-plugins/dial_test.go index 5c3d14992e..a01feb0345 100644 --- a/e2e/cli-plugins/dial_test.go +++ b/e2e/cli-plugins/dial_test.go @@ -21,6 +21,6 @@ func TestDialStdio(t *testing.T) { cmd := icmd.Command(helloworld, "--config=blah", "--tls", "--log-level", "debug", "helloworld", "--who=foo") res := icmd.RunCmd(cmd, icmd.WithEnv(manager.ReexecEnvvar+"=/bin/true")) res.Assert(t, icmd.Success) - assert.Assert(t, is.Contains(res.Stderr(), `msg="connhelper: starting /bin/true with [--config=blah --tls --log-level debug system dial-stdio]"`)) + assert.Assert(t, is.Contains(res.Stderr(), `msg="commandconn: starting /bin/true with [--config=blah --tls --log-level debug system dial-stdio]"`)) assert.Assert(t, is.Equal(res.Stdout(), "Hello foo!\n")) }