mirror of https://github.com/docker/cli.git
Merge pull request #546 from dnephin/fix-version-on-failure
Set APIVersion on the client, even when Ping fails
This commit is contained in:
commit
a41caadef0
|
@ -17,6 +17,7 @@ import (
|
|||
"github.com/docker/docker/client"
|
||||
"github.com/docker/go-connections/sockets"
|
||||
"github.com/docker/go-connections/tlsconfig"
|
||||
"github.com/docker/notary"
|
||||
"github.com/docker/notary/passphrase"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/spf13/cobra"
|
||||
|
@ -111,44 +112,53 @@ func (cli *DockerCli) Initialize(opts *cliflags.ClientOptions) error {
|
|||
var err error
|
||||
cli.client, err = NewAPIClientFromFlags(opts.Common, cli.configFile)
|
||||
if tlsconfig.IsErrEncryptedKey(err) {
|
||||
var (
|
||||
passwd string
|
||||
giveup bool
|
||||
)
|
||||
passRetriever := passphrase.PromptRetrieverWithInOut(cli.In(), cli.Out(), nil)
|
||||
|
||||
for attempts := 0; tlsconfig.IsErrEncryptedKey(err); attempts++ {
|
||||
// some code and comments borrowed from notary/trustmanager/keystore.go
|
||||
passwd, giveup, err = passRetriever("private", "encrypted TLS private", false, attempts)
|
||||
// Check if the passphrase retriever got an error or if it is telling us to give up
|
||||
if giveup || err != nil {
|
||||
return errors.Wrap(err, "private key is encrypted, but could not get passphrase")
|
||||
newClient := func(password string) (client.APIClient, error) {
|
||||
opts.Common.TLSOptions.Passphrase = password
|
||||
return NewAPIClientFromFlags(opts.Common, cli.configFile)
|
||||
}
|
||||
|
||||
opts.Common.TLSOptions.Passphrase = passwd
|
||||
cli.client, err = NewAPIClientFromFlags(opts.Common, cli.configFile)
|
||||
cli.client, err = getClientWithPassword(passRetriever, newClient)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cli.initializeFromClient()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cli *DockerCli) initializeFromClient() {
|
||||
cli.defaultVersion = cli.client.ClientVersion()
|
||||
|
||||
if ping, err := cli.client.Ping(context.Background()); err == nil {
|
||||
ping, err := cli.client.Ping(context.Background())
|
||||
if err != nil {
|
||||
// Default to true if we fail to connect to daemon
|
||||
cli.server = ServerInfo{HasExperimental: true}
|
||||
|
||||
if ping.APIVersion != "" {
|
||||
cli.client.NegotiateAPIVersionPing(ping)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
cli.server = ServerInfo{
|
||||
HasExperimental: ping.Experimental,
|
||||
OSType: ping.OSType,
|
||||
}
|
||||
|
||||
cli.client.NegotiateAPIVersionPing(ping)
|
||||
} else {
|
||||
// Default to true if we fail to connect to daemon
|
||||
cli.server = ServerInfo{HasExperimental: true}
|
||||
}
|
||||
|
||||
func getClientWithPassword(passRetriever notary.PassRetriever, newClient func(password string) (client.APIClient, error)) (client.APIClient, error) {
|
||||
for attempts := 0; ; attempts++ {
|
||||
passwd, giveup, err := passRetriever("private", "encrypted TLS private", false, attempts)
|
||||
if giveup || err != nil {
|
||||
return nil, errors.Wrap(err, "private key is encrypted, but could not get passphrase")
|
||||
}
|
||||
|
||||
return nil
|
||||
apiclient, err := newClient(passwd)
|
||||
if !tlsconfig.IsErrEncryptedKey(err) {
|
||||
return apiclient, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ServerInfo stores details about the supported features and platform of the
|
||||
|
|
|
@ -4,12 +4,18 @@ import (
|
|||
"os"
|
||||
"testing"
|
||||
|
||||
"crypto/x509"
|
||||
|
||||
"github.com/docker/cli/cli/config/configfile"
|
||||
"github.com/docker/cli/cli/flags"
|
||||
"github.com/docker/cli/internal/test/testutil"
|
||||
"github.com/docker/docker/api"
|
||||
"github.com/docker/docker/api/types"
|
||||
"github.com/docker/docker/client"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
func TestNewAPIClientFromFlags(t *testing.T) {
|
||||
|
@ -43,7 +49,7 @@ func TestNewAPIClientFromFlagsWithAPIVersionFromEnv(t *testing.T) {
|
|||
assert.Equal(t, customVersion, apiclient.ClientVersion())
|
||||
}
|
||||
|
||||
// TODO: move to gotestyourself
|
||||
// TODO: use gotestyourself/env.Patch
|
||||
func patchEnvVariable(t *testing.T, key, value string) func() {
|
||||
oldValue, ok := os.LookupEnv(key)
|
||||
require.NoError(t, os.Setenv(key, value))
|
||||
|
@ -55,3 +61,138 @@ func patchEnvVariable(t *testing.T, key, value string) func() {
|
|||
require.NoError(t, os.Setenv(key, oldValue))
|
||||
}
|
||||
}
|
||||
|
||||
type fakeClient struct {
|
||||
client.Client
|
||||
pingFunc func() (types.Ping, error)
|
||||
version string
|
||||
negotiated bool
|
||||
}
|
||||
|
||||
func (c *fakeClient) Ping(_ context.Context) (types.Ping, error) {
|
||||
return c.pingFunc()
|
||||
}
|
||||
|
||||
func (c *fakeClient) ClientVersion() string {
|
||||
return c.version
|
||||
}
|
||||
|
||||
func (c *fakeClient) NegotiateAPIVersionPing(types.Ping) {
|
||||
c.negotiated = true
|
||||
}
|
||||
|
||||
func TestInitializeFromClient(t *testing.T) {
|
||||
defaultVersion := "v1.55"
|
||||
|
||||
var testcases = []struct {
|
||||
doc string
|
||||
pingFunc func() (types.Ping, error)
|
||||
expectedServer ServerInfo
|
||||
negotiated bool
|
||||
}{
|
||||
{
|
||||
doc: "successful ping",
|
||||
pingFunc: func() (types.Ping, error) {
|
||||
return types.Ping{Experimental: true, OSType: "linux", APIVersion: "v1.30"}, nil
|
||||
},
|
||||
expectedServer: ServerInfo{HasExperimental: true, OSType: "linux"},
|
||||
negotiated: true,
|
||||
},
|
||||
{
|
||||
doc: "failed ping, no API version",
|
||||
pingFunc: func() (types.Ping, error) {
|
||||
return types.Ping{}, errors.New("failed")
|
||||
},
|
||||
expectedServer: ServerInfo{HasExperimental: true},
|
||||
},
|
||||
{
|
||||
doc: "failed ping, with API version",
|
||||
pingFunc: func() (types.Ping, error) {
|
||||
return types.Ping{APIVersion: "v1.33"}, errors.New("failed")
|
||||
},
|
||||
expectedServer: ServerInfo{HasExperimental: true},
|
||||
negotiated: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, testcase := range testcases {
|
||||
t.Run(testcase.doc, func(t *testing.T) {
|
||||
apiclient := &fakeClient{
|
||||
pingFunc: testcase.pingFunc,
|
||||
version: defaultVersion,
|
||||
}
|
||||
|
||||
cli := &DockerCli{client: apiclient}
|
||||
cli.initializeFromClient()
|
||||
assert.Equal(t, defaultVersion, cli.defaultVersion)
|
||||
assert.Equal(t, testcase.expectedServer, cli.server)
|
||||
assert.Equal(t, testcase.negotiated, apiclient.negotiated)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetClientWithPassword(t *testing.T) {
|
||||
expected := "password"
|
||||
|
||||
var testcases = []struct {
|
||||
doc string
|
||||
password string
|
||||
retrieverErr error
|
||||
retrieverGiveup bool
|
||||
newClientErr error
|
||||
expectedErr string
|
||||
}{
|
||||
{
|
||||
doc: "successful connect",
|
||||
password: expected,
|
||||
},
|
||||
{
|
||||
doc: "password retriever exhausted",
|
||||
retrieverGiveup: true,
|
||||
retrieverErr: errors.New("failed"),
|
||||
expectedErr: "private key is encrypted, but could not get passphrase",
|
||||
},
|
||||
{
|
||||
doc: "password retriever error",
|
||||
retrieverErr: errors.New("failed"),
|
||||
expectedErr: "failed",
|
||||
},
|
||||
{
|
||||
doc: "newClient error",
|
||||
newClientErr: errors.New("failed to connect"),
|
||||
expectedErr: "failed to connect",
|
||||
},
|
||||
}
|
||||
|
||||
for _, testcase := range testcases {
|
||||
t.Run(testcase.doc, func(t *testing.T) {
|
||||
passRetriever := func(_, _ string, _ bool, attempts int) (passphrase string, giveup bool, err error) {
|
||||
// Always return an invalid pass first to test iteration
|
||||
switch attempts {
|
||||
case 0:
|
||||
return "something else", false, nil
|
||||
default:
|
||||
return testcase.password, testcase.retrieverGiveup, testcase.retrieverErr
|
||||
}
|
||||
}
|
||||
|
||||
newClient := func(currentPassword string) (client.APIClient, error) {
|
||||
if testcase.newClientErr != nil {
|
||||
return nil, testcase.newClientErr
|
||||
}
|
||||
if currentPassword == expected {
|
||||
return &client.Client{}, nil
|
||||
}
|
||||
return &client.Client{}, x509.IncorrectPasswordError
|
||||
}
|
||||
|
||||
_, err := getClientWithPassword(passRetriever, newClient)
|
||||
if testcase.expectedErr != "" {
|
||||
testutil.ErrorContains(t, err, testcase.expectedErr)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue