diff --git a/cli/command/cli.go b/cli/command/cli.go index 4ac881f2dc..f26b808ac2 100644 --- a/cli/command/cli.go +++ b/cli/command/cli.go @@ -17,6 +17,7 @@ import ( "github.com/docker/docker/client" "github.com/docker/go-connections/sockets" "github.com/docker/go-connections/tlsconfig" + "github.com/docker/notary" "github.com/docker/notary/passphrase" "github.com/pkg/errors" "github.com/spf13/cobra" @@ -111,44 +112,53 @@ func (cli *DockerCli) Initialize(opts *cliflags.ClientOptions) error { var err error cli.client, err = NewAPIClientFromFlags(opts.Common, cli.configFile) if tlsconfig.IsErrEncryptedKey(err) { - var ( - passwd string - giveup bool - ) passRetriever := passphrase.PromptRetrieverWithInOut(cli.In(), cli.Out(), nil) - - for attempts := 0; tlsconfig.IsErrEncryptedKey(err); attempts++ { - // some code and comments borrowed from notary/trustmanager/keystore.go - passwd, giveup, err = passRetriever("private", "encrypted TLS private", false, attempts) - // Check if the passphrase retriever got an error or if it is telling us to give up - if giveup || err != nil { - return errors.Wrap(err, "private key is encrypted, but could not get passphrase") - } - - opts.Common.TLSOptions.Passphrase = passwd - cli.client, err = NewAPIClientFromFlags(opts.Common, cli.configFile) + newClient := func(password string) (client.APIClient, error) { + opts.Common.TLSOptions.Passphrase = password + return NewAPIClientFromFlags(opts.Common, cli.configFile) } + cli.client, err = getClientWithPassword(passRetriever, newClient) } - if err != nil { return err } + cli.initializeFromClient() + return nil +} +func (cli *DockerCli) initializeFromClient() { cli.defaultVersion = cli.client.ClientVersion() - if ping, err := cli.client.Ping(context.Background()); err == nil { - cli.server = ServerInfo{ - HasExperimental: ping.Experimental, - OSType: ping.OSType, - } - - cli.client.NegotiateAPIVersionPing(ping) - } else { + ping, err := cli.client.Ping(context.Background()) + if err != nil { // Default to true if we fail to connect to daemon cli.server = ServerInfo{HasExperimental: true} + + if ping.APIVersion != "" { + cli.client.NegotiateAPIVersionPing(ping) + } + return } - return nil + cli.server = ServerInfo{ + HasExperimental: ping.Experimental, + OSType: ping.OSType, + } + cli.client.NegotiateAPIVersionPing(ping) +} + +func getClientWithPassword(passRetriever notary.PassRetriever, newClient func(password string) (client.APIClient, error)) (client.APIClient, error) { + for attempts := 0; ; attempts++ { + passwd, giveup, err := passRetriever("private", "encrypted TLS private", false, attempts) + if giveup || err != nil { + return nil, errors.Wrap(err, "private key is encrypted, but could not get passphrase") + } + + apiclient, err := newClient(passwd) + if !tlsconfig.IsErrEncryptedKey(err) { + return apiclient, err + } + } } // ServerInfo stores details about the supported features and platform of the diff --git a/cli/command/cli_test.go b/cli/command/cli_test.go index ca73a05e38..16ba7c4ff9 100644 --- a/cli/command/cli_test.go +++ b/cli/command/cli_test.go @@ -4,12 +4,18 @@ import ( "os" "testing" + "crypto/x509" + "github.com/docker/cli/cli/config/configfile" "github.com/docker/cli/cli/flags" + "github.com/docker/cli/internal/test/testutil" "github.com/docker/docker/api" + "github.com/docker/docker/api/types" "github.com/docker/docker/client" + "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/net/context" ) func TestNewAPIClientFromFlags(t *testing.T) { @@ -43,7 +49,7 @@ func TestNewAPIClientFromFlagsWithAPIVersionFromEnv(t *testing.T) { assert.Equal(t, customVersion, apiclient.ClientVersion()) } -// TODO: move to gotestyourself +// TODO: use gotestyourself/env.Patch func patchEnvVariable(t *testing.T, key, value string) func() { oldValue, ok := os.LookupEnv(key) require.NoError(t, os.Setenv(key, value)) @@ -55,3 +61,138 @@ func patchEnvVariable(t *testing.T, key, value string) func() { require.NoError(t, os.Setenv(key, oldValue)) } } + +type fakeClient struct { + client.Client + pingFunc func() (types.Ping, error) + version string + negotiated bool +} + +func (c *fakeClient) Ping(_ context.Context) (types.Ping, error) { + return c.pingFunc() +} + +func (c *fakeClient) ClientVersion() string { + return c.version +} + +func (c *fakeClient) NegotiateAPIVersionPing(types.Ping) { + c.negotiated = true +} + +func TestInitializeFromClient(t *testing.T) { + defaultVersion := "v1.55" + + var testcases = []struct { + doc string + pingFunc func() (types.Ping, error) + expectedServer ServerInfo + negotiated bool + }{ + { + doc: "successful ping", + pingFunc: func() (types.Ping, error) { + return types.Ping{Experimental: true, OSType: "linux", APIVersion: "v1.30"}, nil + }, + expectedServer: ServerInfo{HasExperimental: true, OSType: "linux"}, + negotiated: true, + }, + { + doc: "failed ping, no API version", + pingFunc: func() (types.Ping, error) { + return types.Ping{}, errors.New("failed") + }, + expectedServer: ServerInfo{HasExperimental: true}, + }, + { + doc: "failed ping, with API version", + pingFunc: func() (types.Ping, error) { + return types.Ping{APIVersion: "v1.33"}, errors.New("failed") + }, + expectedServer: ServerInfo{HasExperimental: true}, + negotiated: true, + }, + } + + for _, testcase := range testcases { + t.Run(testcase.doc, func(t *testing.T) { + apiclient := &fakeClient{ + pingFunc: testcase.pingFunc, + version: defaultVersion, + } + + cli := &DockerCli{client: apiclient} + cli.initializeFromClient() + assert.Equal(t, defaultVersion, cli.defaultVersion) + assert.Equal(t, testcase.expectedServer, cli.server) + assert.Equal(t, testcase.negotiated, apiclient.negotiated) + }) + } +} + +func TestGetClientWithPassword(t *testing.T) { + expected := "password" + + var testcases = []struct { + doc string + password string + retrieverErr error + retrieverGiveup bool + newClientErr error + expectedErr string + }{ + { + doc: "successful connect", + password: expected, + }, + { + doc: "password retriever exhausted", + retrieverGiveup: true, + retrieverErr: errors.New("failed"), + expectedErr: "private key is encrypted, but could not get passphrase", + }, + { + doc: "password retriever error", + retrieverErr: errors.New("failed"), + expectedErr: "failed", + }, + { + doc: "newClient error", + newClientErr: errors.New("failed to connect"), + expectedErr: "failed to connect", + }, + } + + for _, testcase := range testcases { + t.Run(testcase.doc, func(t *testing.T) { + passRetriever := func(_, _ string, _ bool, attempts int) (passphrase string, giveup bool, err error) { + // Always return an invalid pass first to test iteration + switch attempts { + case 0: + return "something else", false, nil + default: + return testcase.password, testcase.retrieverGiveup, testcase.retrieverErr + } + } + + newClient := func(currentPassword string) (client.APIClient, error) { + if testcase.newClientErr != nil { + return nil, testcase.newClientErr + } + if currentPassword == expected { + return &client.Client{}, nil + } + return &client.Client{}, x509.IncorrectPasswordError + } + + _, err := getClientWithPassword(passRetriever, newClient) + if testcase.expectedErr != "" { + testutil.ErrorContains(t, err, testcase.expectedErr) + return + } + + assert.NoError(t, err) + }) + } +}