mirror of https://github.com/docker/cli.git
connhelper: export functions for other projects
Exposed functions are planned to be used by `buildctl`: https://github.com/moby/buildkit/issues/769 Signed-off-by: Akihiro Suda <suda.akihiro@lab.ntt.co.jp>
This commit is contained in:
parent
ee94f72e2c
commit
dbe7afbd04
|
@ -0,0 +1,281 @@
|
||||||
|
// Package commandconn provides a net.Conn implementation that can be used for
|
||||||
|
// proxying (or emulating) stream via a custom command.
|
||||||
|
//
|
||||||
|
// For example, to provide an http.Client that can connect to a Docker daemon
|
||||||
|
// running in a Docker container ("DIND"):
|
||||||
|
//
|
||||||
|
// httpClient := &http.Client{
|
||||||
|
// Transport: &http.Transport{
|
||||||
|
// DialContext: func(ctx context.Context, _network, _addr string) (net.Conn, error) {
|
||||||
|
// return commandconn.New(ctx, "docker", "exec", "-it", containerID, "docker", "system", "dial-stdio")
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// }
|
||||||
|
package commandconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// New returns net.Conn
|
||||||
|
func New(ctx context.Context, cmd string, args ...string) (net.Conn, error) {
|
||||||
|
var (
|
||||||
|
c commandConn
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
c.cmd = exec.CommandContext(ctx, cmd, args...)
|
||||||
|
// we assume that args never contains sensitive information
|
||||||
|
logrus.Debugf("commandconn: starting %s with %v", cmd, args)
|
||||||
|
c.cmd.Env = os.Environ()
|
||||||
|
setPdeathsig(c.cmd)
|
||||||
|
c.stdin, err = c.cmd.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.stdout, err = c.cmd.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
c.cmd.Stderr = &stderrWriter{
|
||||||
|
stderrMu: &c.stderrMu,
|
||||||
|
stderr: &c.stderr,
|
||||||
|
debugPrefix: fmt.Sprintf("commandconn (%s):", cmd),
|
||||||
|
}
|
||||||
|
c.localAddr = dummyAddr{network: "dummy", s: "dummy-0"}
|
||||||
|
c.remoteAddr = dummyAddr{network: "dummy", s: "dummy-1"}
|
||||||
|
return &c, c.cmd.Start()
|
||||||
|
}
|
||||||
|
|
||||||
|
// commandConn implements net.Conn
|
||||||
|
type commandConn struct {
|
||||||
|
cmd *exec.Cmd
|
||||||
|
cmdExited bool
|
||||||
|
cmdWaitErr error
|
||||||
|
cmdMutex sync.Mutex
|
||||||
|
stdin io.WriteCloser
|
||||||
|
stdout io.ReadCloser
|
||||||
|
stderrMu sync.Mutex
|
||||||
|
stderr bytes.Buffer
|
||||||
|
stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed
|
||||||
|
stdinClosed bool
|
||||||
|
stdoutClosed bool
|
||||||
|
localAddr net.Addr
|
||||||
|
remoteAddr net.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// killIfStdioClosed kills the cmd if both stdin and stdout are closed.
|
||||||
|
func (c *commandConn) killIfStdioClosed() error {
|
||||||
|
c.stdioClosedMu.Lock()
|
||||||
|
stdioClosed := c.stdoutClosed && c.stdinClosed
|
||||||
|
c.stdioClosedMu.Unlock()
|
||||||
|
if !stdioClosed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return c.kill()
|
||||||
|
}
|
||||||
|
|
||||||
|
// killAndWait tries sending SIGTERM to the process before sending SIGKILL.
|
||||||
|
func killAndWait(cmd *exec.Cmd) error {
|
||||||
|
var werr error
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
werrCh := make(chan error)
|
||||||
|
go func() { werrCh <- cmd.Wait() }()
|
||||||
|
cmd.Process.Signal(syscall.SIGTERM)
|
||||||
|
select {
|
||||||
|
case werr = <-werrCh:
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
cmd.Process.Kill()
|
||||||
|
werr = <-werrCh
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
cmd.Process.Kill()
|
||||||
|
werr = cmd.Wait()
|
||||||
|
}
|
||||||
|
return werr
|
||||||
|
}
|
||||||
|
|
||||||
|
// kill returns nil if the command terminated, regardless to the exit status.
|
||||||
|
func (c *commandConn) kill() error {
|
||||||
|
var werr error
|
||||||
|
c.cmdMutex.Lock()
|
||||||
|
if c.cmdExited {
|
||||||
|
werr = c.cmdWaitErr
|
||||||
|
} else {
|
||||||
|
werr = killAndWait(c.cmd)
|
||||||
|
c.cmdWaitErr = werr
|
||||||
|
c.cmdExited = true
|
||||||
|
}
|
||||||
|
c.cmdMutex.Unlock()
|
||||||
|
if werr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
wExitErr, ok := werr.(*exec.ExitError)
|
||||||
|
if ok {
|
||||||
|
if wExitErr.ProcessState.Exited() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errors.Wrapf(werr, "commandconn: failed to wait")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commandConn) onEOF(eof error) error {
|
||||||
|
// when we got EOF, the command is going to be terminated
|
||||||
|
var werr error
|
||||||
|
c.cmdMutex.Lock()
|
||||||
|
if c.cmdExited {
|
||||||
|
werr = c.cmdWaitErr
|
||||||
|
} else {
|
||||||
|
werrCh := make(chan error)
|
||||||
|
go func() { werrCh <- c.cmd.Wait() }()
|
||||||
|
select {
|
||||||
|
case werr = <-werrCh:
|
||||||
|
c.cmdWaitErr = werr
|
||||||
|
c.cmdExited = true
|
||||||
|
case <-time.After(10 * time.Second):
|
||||||
|
c.cmdMutex.Unlock()
|
||||||
|
c.stderrMu.Lock()
|
||||||
|
stderr := c.stderr.String()
|
||||||
|
c.stderrMu.Unlock()
|
||||||
|
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.cmdMutex.Unlock()
|
||||||
|
if werr == nil {
|
||||||
|
return eof
|
||||||
|
}
|
||||||
|
c.stderrMu.Lock()
|
||||||
|
stderr := c.stderr.String()
|
||||||
|
c.stderrMu.Unlock()
|
||||||
|
return errors.Errorf("command %v has exited with %v, please make sure the URL is valid, and Docker 18.09 or later is installed on the remote host: stderr=%s", c.cmd.Args, werr, stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ignorableCloseError(err error) bool {
|
||||||
|
errS := err.Error()
|
||||||
|
ss := []string{
|
||||||
|
os.ErrClosed.Error(),
|
||||||
|
}
|
||||||
|
for _, s := range ss {
|
||||||
|
if strings.Contains(errS, s) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commandConn) CloseRead() error {
|
||||||
|
// NOTE: maybe already closed here
|
||||||
|
if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
|
||||||
|
logrus.Warnf("commandConn.CloseRead: %v", err)
|
||||||
|
}
|
||||||
|
c.stdioClosedMu.Lock()
|
||||||
|
c.stdoutClosed = true
|
||||||
|
c.stdioClosedMu.Unlock()
|
||||||
|
if err := c.killIfStdioClosed(); err != nil {
|
||||||
|
logrus.Warnf("commandConn.CloseRead: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commandConn) Read(p []byte) (int, error) {
|
||||||
|
n, err := c.stdout.Read(p)
|
||||||
|
if err == io.EOF {
|
||||||
|
err = c.onEOF(err)
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commandConn) CloseWrite() error {
|
||||||
|
// NOTE: maybe already closed here
|
||||||
|
if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
|
||||||
|
logrus.Warnf("commandConn.CloseWrite: %v", err)
|
||||||
|
}
|
||||||
|
c.stdioClosedMu.Lock()
|
||||||
|
c.stdinClosed = true
|
||||||
|
c.stdioClosedMu.Unlock()
|
||||||
|
if err := c.killIfStdioClosed(); err != nil {
|
||||||
|
logrus.Warnf("commandConn.CloseWrite: %v", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commandConn) Write(p []byte) (int, error) {
|
||||||
|
n, err := c.stdin.Write(p)
|
||||||
|
if err == io.EOF {
|
||||||
|
err = c.onEOF(err)
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commandConn) Close() error {
|
||||||
|
var err error
|
||||||
|
if err = c.CloseRead(); err != nil {
|
||||||
|
logrus.Warnf("commandConn.Close: CloseRead: %v", err)
|
||||||
|
}
|
||||||
|
if err = c.CloseWrite(); err != nil {
|
||||||
|
logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *commandConn) LocalAddr() net.Addr {
|
||||||
|
return c.localAddr
|
||||||
|
}
|
||||||
|
func (c *commandConn) RemoteAddr() net.Addr {
|
||||||
|
return c.remoteAddr
|
||||||
|
}
|
||||||
|
func (c *commandConn) SetDeadline(t time.Time) error {
|
||||||
|
logrus.Debugf("unimplemented call: SetDeadline(%v)", t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *commandConn) SetReadDeadline(t time.Time) error {
|
||||||
|
logrus.Debugf("unimplemented call: SetReadDeadline(%v)", t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (c *commandConn) SetWriteDeadline(t time.Time) error {
|
||||||
|
logrus.Debugf("unimplemented call: SetWriteDeadline(%v)", t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type dummyAddr struct {
|
||||||
|
network string
|
||||||
|
s string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d dummyAddr) Network() string {
|
||||||
|
return d.network
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d dummyAddr) String() string {
|
||||||
|
return d.s
|
||||||
|
}
|
||||||
|
|
||||||
|
type stderrWriter struct {
|
||||||
|
stderrMu *sync.Mutex
|
||||||
|
stderr *bytes.Buffer
|
||||||
|
debugPrefix string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *stderrWriter) Write(p []byte) (int, error) {
|
||||||
|
logrus.Debugf("%s%s", w.debugPrefix, string(p))
|
||||||
|
w.stderrMu.Lock()
|
||||||
|
if w.stderr.Len() > 4096 {
|
||||||
|
w.stderr.Reset()
|
||||||
|
}
|
||||||
|
n, err := w.stderr.Write(p)
|
||||||
|
w.stderrMu.Unlock()
|
||||||
|
return n, err
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package connhelper
|
package commandconn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os/exec"
|
"os/exec"
|
|
@ -1,6 +1,6 @@
|
||||||
// +build !linux
|
// +build !linux
|
||||||
|
|
||||||
package connhelper
|
package commandconn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"os/exec"
|
"os/exec"
|
|
@ -1,6 +1,6 @@
|
||||||
// +build !windows
|
// +build !windows
|
||||||
|
|
||||||
package connhelper
|
package commandconn
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -12,11 +12,11 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// For https://github.com/docker/cli/pull/1014#issuecomment-409308139
|
// For https://github.com/docker/cli/pull/1014#issuecomment-409308139
|
||||||
func TestCommandConnEOFWithError(t *testing.T) {
|
func TestEOFWithError(t *testing.T) {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
cmd := "sh"
|
cmd := "sh"
|
||||||
args := []string{"-c", "echo hello; echo some error >&2; exit 42"}
|
args := []string{"-c", "echo hello; echo some error >&2; exit 42"}
|
||||||
c, err := newCommandConn(ctx, cmd, args...)
|
c, err := New(ctx, cmd, args...)
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
b := make([]byte, 32)
|
b := make([]byte, 32)
|
||||||
n, err := c.Read(b)
|
n, err := c.Read(b)
|
||||||
|
@ -28,11 +28,11 @@ func TestCommandConnEOFWithError(t *testing.T) {
|
||||||
assert.ErrorContains(t, err, "42")
|
assert.ErrorContains(t, err, "42")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCommandConnEOFWithoutError(t *testing.T) {
|
func TestEOFWithoutError(t *testing.T) {
|
||||||
ctx := context.TODO()
|
ctx := context.TODO()
|
||||||
cmd := "sh"
|
cmd := "sh"
|
||||||
args := []string{"-c", "echo hello; echo some debug log >&2; exit 0"}
|
args := []string{"-c", "echo hello; echo some debug log >&2; exit 0"}
|
||||||
c, err := newCommandConn(ctx, cmd, args...)
|
c, err := New(ctx, cmd, args...)
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
b := make([]byte, 32)
|
b := make([]byte, 32)
|
||||||
n, err := c.Read(b)
|
n, err := c.Read(b)
|
|
@ -2,23 +2,13 @@
|
||||||
package connhelper
|
package connhelper
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"os/exec"
|
|
||||||
"runtime"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"syscall"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
|
"github.com/docker/cli/cli/connhelper/commandconn"
|
||||||
"github.com/docker/cli/cli/connhelper/ssh"
|
"github.com/docker/cli/cli/connhelper/ssh"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionHelper allows to connect to a remote host with custom stream provider binary.
|
// ConnectionHelper allows to connect to a remote host with custom stream provider binary.
|
||||||
|
@ -29,7 +19,8 @@ type ConnectionHelper struct {
|
||||||
|
|
||||||
// GetConnectionHelper returns Docker-specific connection helper for the given URL.
|
// GetConnectionHelper returns Docker-specific connection helper for the given URL.
|
||||||
// GetConnectionHelper returns nil without error when no helper is registered for the scheme.
|
// GetConnectionHelper returns nil without error when no helper is registered for the scheme.
|
||||||
// URL is like "ssh://me@server01".
|
//
|
||||||
|
// ssh://<user>@<host> URL requires Docker 18.09 or later on the remote host.
|
||||||
func GetConnectionHelper(daemonURL string) (*ConnectionHelper, error) {
|
func GetConnectionHelper(daemonURL string) (*ConnectionHelper, error) {
|
||||||
u, err := url.Parse(daemonURL)
|
u, err := url.Parse(daemonURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -37,13 +28,13 @@ func GetConnectionHelper(daemonURL string) (*ConnectionHelper, error) {
|
||||||
}
|
}
|
||||||
switch scheme := u.Scheme; scheme {
|
switch scheme := u.Scheme; scheme {
|
||||||
case "ssh":
|
case "ssh":
|
||||||
sshCmd, sshArgs, err := ssh.New(daemonURL)
|
sp, err := ssh.ParseURL(daemonURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, errors.Wrap(err, "ssh host connection is not valid")
|
||||||
}
|
}
|
||||||
return &ConnectionHelper{
|
return &ConnectionHelper{
|
||||||
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
return newCommandConn(ctx, sshCmd, sshArgs...)
|
return commandconn.New(ctx, "ssh", append(sp.Args(), []string{"--", "docker", "system", "dial-stdio"}...)...)
|
||||||
},
|
},
|
||||||
Host: "http://docker",
|
Host: "http://docker",
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -53,260 +44,12 @@ func GetConnectionHelper(daemonURL string) (*ConnectionHelper, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetCommandConnectionHelper returns a ConnectionHelp constructed from an arbitrary command.
|
// GetCommandConnectionHelper returns Docker-specific connection helper constructed from an arbitrary command.
|
||||||
func GetCommandConnectionHelper(cmd string, flags ...string) (*ConnectionHelper, error) {
|
func GetCommandConnectionHelper(cmd string, flags ...string) (*ConnectionHelper, error) {
|
||||||
return &ConnectionHelper{
|
return &ConnectionHelper{
|
||||||
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
return newCommandConn(ctx, cmd, flags...)
|
return commandconn.New(ctx, cmd, flags...)
|
||||||
},
|
},
|
||||||
Host: "http://docker",
|
Host: "http://docker",
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newCommandConn(ctx context.Context, cmd string, args ...string) (net.Conn, error) {
|
|
||||||
var (
|
|
||||||
c commandConn
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
c.cmd = exec.CommandContext(ctx, cmd, args...)
|
|
||||||
// we assume that args never contains sensitive information
|
|
||||||
logrus.Debugf("connhelper: starting %s with %v", cmd, args)
|
|
||||||
c.cmd.Env = os.Environ()
|
|
||||||
setPdeathsig(c.cmd)
|
|
||||||
c.stdin, err = c.cmd.StdinPipe()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.stdout, err = c.cmd.StdoutPipe()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
c.cmd.Stderr = &stderrWriter{
|
|
||||||
stderrMu: &c.stderrMu,
|
|
||||||
stderr: &c.stderr,
|
|
||||||
debugPrefix: fmt.Sprintf("connhelper (%s):", cmd),
|
|
||||||
}
|
|
||||||
c.localAddr = dummyAddr{network: "dummy", s: "dummy-0"}
|
|
||||||
c.remoteAddr = dummyAddr{network: "dummy", s: "dummy-1"}
|
|
||||||
return &c, c.cmd.Start()
|
|
||||||
}
|
|
||||||
|
|
||||||
// commandConn implements net.Conn
|
|
||||||
type commandConn struct {
|
|
||||||
cmd *exec.Cmd
|
|
||||||
cmdExited bool
|
|
||||||
cmdWaitErr error
|
|
||||||
cmdMutex sync.Mutex
|
|
||||||
stdin io.WriteCloser
|
|
||||||
stdout io.ReadCloser
|
|
||||||
stderrMu sync.Mutex
|
|
||||||
stderr bytes.Buffer
|
|
||||||
stdioClosedMu sync.Mutex // for stdinClosed and stdoutClosed
|
|
||||||
stdinClosed bool
|
|
||||||
stdoutClosed bool
|
|
||||||
localAddr net.Addr
|
|
||||||
remoteAddr net.Addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// killIfStdioClosed kills the cmd if both stdin and stdout are closed.
|
|
||||||
func (c *commandConn) killIfStdioClosed() error {
|
|
||||||
c.stdioClosedMu.Lock()
|
|
||||||
stdioClosed := c.stdoutClosed && c.stdinClosed
|
|
||||||
c.stdioClosedMu.Unlock()
|
|
||||||
if !stdioClosed {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return c.kill()
|
|
||||||
}
|
|
||||||
|
|
||||||
// killAndWait tries sending SIGTERM to the process before sending SIGKILL.
|
|
||||||
func killAndWait(cmd *exec.Cmd) error {
|
|
||||||
var werr error
|
|
||||||
if runtime.GOOS != "windows" {
|
|
||||||
werrCh := make(chan error)
|
|
||||||
go func() { werrCh <- cmd.Wait() }()
|
|
||||||
cmd.Process.Signal(syscall.SIGTERM)
|
|
||||||
select {
|
|
||||||
case werr = <-werrCh:
|
|
||||||
case <-time.After(3 * time.Second):
|
|
||||||
cmd.Process.Kill()
|
|
||||||
werr = <-werrCh
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
cmd.Process.Kill()
|
|
||||||
werr = cmd.Wait()
|
|
||||||
}
|
|
||||||
return werr
|
|
||||||
}
|
|
||||||
|
|
||||||
// kill returns nil if the command terminated, regardless to the exit status.
|
|
||||||
func (c *commandConn) kill() error {
|
|
||||||
var werr error
|
|
||||||
c.cmdMutex.Lock()
|
|
||||||
if c.cmdExited {
|
|
||||||
werr = c.cmdWaitErr
|
|
||||||
} else {
|
|
||||||
werr = killAndWait(c.cmd)
|
|
||||||
c.cmdWaitErr = werr
|
|
||||||
c.cmdExited = true
|
|
||||||
}
|
|
||||||
c.cmdMutex.Unlock()
|
|
||||||
if werr == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
wExitErr, ok := werr.(*exec.ExitError)
|
|
||||||
if ok {
|
|
||||||
if wExitErr.ProcessState.Exited() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return errors.Wrapf(werr, "connhelper: failed to wait")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *commandConn) onEOF(eof error) error {
|
|
||||||
// when we got EOF, the command is going to be terminated
|
|
||||||
var werr error
|
|
||||||
c.cmdMutex.Lock()
|
|
||||||
if c.cmdExited {
|
|
||||||
werr = c.cmdWaitErr
|
|
||||||
} else {
|
|
||||||
werrCh := make(chan error)
|
|
||||||
go func() { werrCh <- c.cmd.Wait() }()
|
|
||||||
select {
|
|
||||||
case werr = <-werrCh:
|
|
||||||
c.cmdWaitErr = werr
|
|
||||||
c.cmdExited = true
|
|
||||||
case <-time.After(10 * time.Second):
|
|
||||||
c.cmdMutex.Unlock()
|
|
||||||
c.stderrMu.Lock()
|
|
||||||
stderr := c.stderr.String()
|
|
||||||
c.stderrMu.Unlock()
|
|
||||||
return errors.Errorf("command %v did not exit after %v: stderr=%q", c.cmd.Args, eof, stderr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.cmdMutex.Unlock()
|
|
||||||
if werr == nil {
|
|
||||||
return eof
|
|
||||||
}
|
|
||||||
c.stderrMu.Lock()
|
|
||||||
stderr := c.stderr.String()
|
|
||||||
c.stderrMu.Unlock()
|
|
||||||
return errors.Errorf("command %v has exited with %v, please make sure the URL is valid, and Docker 18.09 or later is installed on the remote host: stderr=%s", c.cmd.Args, werr, stderr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func ignorableCloseError(err error) bool {
|
|
||||||
errS := err.Error()
|
|
||||||
ss := []string{
|
|
||||||
os.ErrClosed.Error(),
|
|
||||||
}
|
|
||||||
for _, s := range ss {
|
|
||||||
if strings.Contains(errS, s) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *commandConn) CloseRead() error {
|
|
||||||
// NOTE: maybe already closed here
|
|
||||||
if err := c.stdout.Close(); err != nil && !ignorableCloseError(err) {
|
|
||||||
logrus.Warnf("commandConn.CloseRead: %v", err)
|
|
||||||
}
|
|
||||||
c.stdioClosedMu.Lock()
|
|
||||||
c.stdoutClosed = true
|
|
||||||
c.stdioClosedMu.Unlock()
|
|
||||||
if err := c.killIfStdioClosed(); err != nil {
|
|
||||||
logrus.Warnf("commandConn.CloseRead: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *commandConn) Read(p []byte) (int, error) {
|
|
||||||
n, err := c.stdout.Read(p)
|
|
||||||
if err == io.EOF {
|
|
||||||
err = c.onEOF(err)
|
|
||||||
}
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *commandConn) CloseWrite() error {
|
|
||||||
// NOTE: maybe already closed here
|
|
||||||
if err := c.stdin.Close(); err != nil && !ignorableCloseError(err) {
|
|
||||||
logrus.Warnf("commandConn.CloseWrite: %v", err)
|
|
||||||
}
|
|
||||||
c.stdioClosedMu.Lock()
|
|
||||||
c.stdinClosed = true
|
|
||||||
c.stdioClosedMu.Unlock()
|
|
||||||
if err := c.killIfStdioClosed(); err != nil {
|
|
||||||
logrus.Warnf("commandConn.CloseWrite: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *commandConn) Write(p []byte) (int, error) {
|
|
||||||
n, err := c.stdin.Write(p)
|
|
||||||
if err == io.EOF {
|
|
||||||
err = c.onEOF(err)
|
|
||||||
}
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *commandConn) Close() error {
|
|
||||||
var err error
|
|
||||||
if err = c.CloseRead(); err != nil {
|
|
||||||
logrus.Warnf("commandConn.Close: CloseRead: %v", err)
|
|
||||||
}
|
|
||||||
if err = c.CloseWrite(); err != nil {
|
|
||||||
logrus.Warnf("commandConn.Close: CloseWrite: %v", err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *commandConn) LocalAddr() net.Addr {
|
|
||||||
return c.localAddr
|
|
||||||
}
|
|
||||||
func (c *commandConn) RemoteAddr() net.Addr {
|
|
||||||
return c.remoteAddr
|
|
||||||
}
|
|
||||||
func (c *commandConn) SetDeadline(t time.Time) error {
|
|
||||||
logrus.Debugf("unimplemented call: SetDeadline(%v)", t)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (c *commandConn) SetReadDeadline(t time.Time) error {
|
|
||||||
logrus.Debugf("unimplemented call: SetReadDeadline(%v)", t)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (c *commandConn) SetWriteDeadline(t time.Time) error {
|
|
||||||
logrus.Debugf("unimplemented call: SetWriteDeadline(%v)", t)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type dummyAddr struct {
|
|
||||||
network string
|
|
||||||
s string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d dummyAddr) Network() string {
|
|
||||||
return d.network
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d dummyAddr) String() string {
|
|
||||||
return d.s
|
|
||||||
}
|
|
||||||
|
|
||||||
type stderrWriter struct {
|
|
||||||
stderrMu *sync.Mutex
|
|
||||||
stderr *bytes.Buffer
|
|
||||||
debugPrefix string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *stderrWriter) Write(p []byte) (int, error) {
|
|
||||||
logrus.Debugf("%s%s", w.debugPrefix, string(p))
|
|
||||||
w.stderrMu.Lock()
|
|
||||||
if w.stderr.Len() > 4096 {
|
|
||||||
w.stderr.Reset()
|
|
||||||
}
|
|
||||||
n, err := w.stderr.Write(p)
|
|
||||||
w.stderrMu.Unlock()
|
|
||||||
return n, err
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
// Package ssh provides the connection helper for ssh:// URL.
|
// Package ssh provides the connection helper for ssh:// URL.
|
||||||
// Requires Docker 18.09 or later on the remote host.
|
|
||||||
package ssh
|
package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -8,16 +7,8 @@ import (
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
// New returns cmd and its args
|
// ParseURL parses URL
|
||||||
func New(daemonURL string) (string, []string, error) {
|
func ParseURL(daemonURL string) (*Spec, error) {
|
||||||
sp, err := parseSSHURL(daemonURL)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, errors.Wrap(err, "SSH host connection is not valid")
|
|
||||||
}
|
|
||||||
return "ssh", append(sp.Args(), []string{"--", "docker", "system", "dial-stdio"}...), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseSSHURL(daemonURL string) (*sshSpec, error) {
|
|
||||||
u, err := url.Parse(daemonURL)
|
u, err := url.Parse(daemonURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -26,19 +17,19 @@ func parseSSHURL(daemonURL string) (*sshSpec, error) {
|
||||||
return nil, errors.Errorf("expected scheme ssh, got %q", u.Scheme)
|
return nil, errors.Errorf("expected scheme ssh, got %q", u.Scheme)
|
||||||
}
|
}
|
||||||
|
|
||||||
var sp sshSpec
|
var sp Spec
|
||||||
|
|
||||||
if u.User != nil {
|
if u.User != nil {
|
||||||
sp.user = u.User.Username()
|
sp.User = u.User.Username()
|
||||||
if _, ok := u.User.Password(); ok {
|
if _, ok := u.User.Password(); ok {
|
||||||
return nil, errors.New("plain-text password is not supported")
|
return nil, errors.New("plain-text password is not supported")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sp.host = u.Hostname()
|
sp.Host = u.Hostname()
|
||||||
if sp.host == "" {
|
if sp.Host == "" {
|
||||||
return nil, errors.Errorf("no host specified")
|
return nil, errors.Errorf("no host specified")
|
||||||
}
|
}
|
||||||
sp.port = u.Port()
|
sp.Port = u.Port()
|
||||||
if u.Path != "" {
|
if u.Path != "" {
|
||||||
return nil, errors.Errorf("extra path after the host: %q", u.Path)
|
return nil, errors.Errorf("extra path after the host: %q", u.Path)
|
||||||
}
|
}
|
||||||
|
@ -51,20 +42,22 @@ func parseSSHURL(daemonURL string) (*sshSpec, error) {
|
||||||
return &sp, err
|
return &sp, err
|
||||||
}
|
}
|
||||||
|
|
||||||
type sshSpec struct {
|
// Spec of SSH URL
|
||||||
user string
|
type Spec struct {
|
||||||
host string
|
User string
|
||||||
port string
|
Host string
|
||||||
|
Port string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sp *sshSpec) Args() []string {
|
// Args returns args except "ssh" itself and "-- ..."
|
||||||
|
func (sp *Spec) Args() []string {
|
||||||
var args []string
|
var args []string
|
||||||
if sp.user != "" {
|
if sp.User != "" {
|
||||||
args = append(args, "-l", sp.user)
|
args = append(args, "-l", sp.User)
|
||||||
}
|
}
|
||||||
if sp.port != "" {
|
if sp.Port != "" {
|
||||||
args = append(args, "-p", sp.port)
|
args = append(args, "-p", sp.Port)
|
||||||
}
|
}
|
||||||
args = append(args, sp.host)
|
args = append(args, sp.Host)
|
||||||
return args
|
return args
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,7 +7,7 @@ import (
|
||||||
is "gotest.tools/assert/cmp"
|
is "gotest.tools/assert/cmp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestParseSSHURL(t *testing.T) {
|
func TestParseURL(t *testing.T) {
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
url string
|
url string
|
||||||
expectedArgs []string
|
expectedArgs []string
|
||||||
|
@ -53,7 +53,7 @@ func TestParseSSHURL(t *testing.T) {
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
sp, err := parseSSHURL(tc.url)
|
sp, err := ParseURL(tc.url)
|
||||||
if tc.expectedError == "" {
|
if tc.expectedError == "" {
|
||||||
assert.NilError(t, err)
|
assert.NilError(t, err)
|
||||||
assert.Check(t, is.DeepEqual(tc.expectedArgs, sp.Args()))
|
assert.Check(t, is.DeepEqual(tc.expectedArgs, sp.Args()))
|
||||||
|
|
|
@ -21,6 +21,6 @@ func TestDialStdio(t *testing.T) {
|
||||||
cmd := icmd.Command(helloworld, "--config=blah", "--tls", "--log-level", "debug", "helloworld", "--who=foo")
|
cmd := icmd.Command(helloworld, "--config=blah", "--tls", "--log-level", "debug", "helloworld", "--who=foo")
|
||||||
res := icmd.RunCmd(cmd, icmd.WithEnv(manager.ReexecEnvvar+"=/bin/true"))
|
res := icmd.RunCmd(cmd, icmd.WithEnv(manager.ReexecEnvvar+"=/bin/true"))
|
||||||
res.Assert(t, icmd.Success)
|
res.Assert(t, icmd.Success)
|
||||||
assert.Assert(t, is.Contains(res.Stderr(), `msg="connhelper: starting /bin/true with [--config=blah --tls --log-level debug system dial-stdio]"`))
|
assert.Assert(t, is.Contains(res.Stderr(), `msg="commandconn: starting /bin/true with [--config=blah --tls --log-level debug system dial-stdio]"`))
|
||||||
assert.Assert(t, is.Equal(res.Stdout(), "Hello foo!\n"))
|
assert.Assert(t, is.Equal(res.Stdout(), "Hello foo!\n"))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue