diff --git a/cli/connhelper/connhelper.go b/cli/connhelper/connhelper.go index 152d3e2953..309320889f 100644 --- a/cli/connhelper/connhelper.go +++ b/cli/connhelper/connhelper.go @@ -5,6 +5,7 @@ import ( "context" "net" "net/url" + "os" "strings" "github.com/docker/cli/cli/connhelper/commandconn" @@ -12,6 +13,12 @@ import ( "github.com/pkg/errors" ) +const ( + // DockerSSHRemoteBinaryEnv is the environment variable that can be used to + // override the default Docker binary called over SSH + DockerSSHRemoteBinaryEnv = "DOCKER_SSH_REMOTE_BINARY" +) + // ConnectionHelper allows to connect to a remote host with custom stream provider binary. type ConnectionHelper struct { Dialer func(ctx context.Context, network, addr string) (net.Conn, error) @@ -47,9 +54,10 @@ func getConnectionHelper(daemonURL string, sshFlags []string) (*ConnectionHelper } sshFlags = addSSHTimeout(sshFlags) sshFlags = disablePseudoTerminalAllocation(sshFlags) + remoteDockerBinary := dockerSSHRemoteBinary() return &ConnectionHelper{ Dialer: func(ctx context.Context, network, addr string) (net.Conn, error) { - args := []string{"docker"} + args := []string{remoteDockerBinary} if sp.Path != "" { args = append(args, "--host", "unix://"+sp.Path) } @@ -91,3 +99,15 @@ func disablePseudoTerminalAllocation(sshFlags []string) []string { } return append(sshFlags, "-T") } + +// dockerSSHRemoteBinary returns the binary to use when executing Docker +// commands over SSH. It defaults to "docker" if the DOCKER_SSH_REMOTE_BINARY +// environment variable is not set. +func dockerSSHRemoteBinary() string { + value := os.Getenv(DockerSSHRemoteBinaryEnv) + if value == "" { + return "docker" + } + + return value +} diff --git a/cli/connhelper/connhelper_test.go b/cli/connhelper/connhelper_test.go index 0d9aee0fb2..a5de3db5ac 100644 --- a/cli/connhelper/connhelper_test.go +++ b/cli/connhelper/connhelper_test.go @@ -63,3 +63,32 @@ func TestDisablePseudoTerminalAllocation(t *testing.T) { }) } } + +func TestDockerSSHBinaryOverride(t *testing.T) { + testCases := []struct { + name string + env string + expected string + }{ + { + name: "Default", + env: "", + expected: "docker", + }, + { + name: "Override", + env: "other-binary", + expected: "other-binary", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(DockerSSHRemoteBinaryEnv, tc.env) + result := dockerSSHRemoteBinary() + if result != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, result) + } + }) + } +}