diff --git a/cli/config/config.go b/cli/config/config.go index 703fa30f48..98147e270a 100644 --- a/cli/config/config.go +++ b/cli/config/config.go @@ -26,15 +26,29 @@ const ( var ( initConfigDir sync.Once configDir string + homeDir string ) +// resetHomeDir is used in testing to resets the "homeDir" package variable to +// force re-lookup of the home directory between tests. +func resetHomeDir() { + homeDir = "" +} + +func getHomeDir() string { + if homeDir == "" { + homeDir = homedir.Get() + } + return homeDir +} + func setConfigDir() { if configDir != "" { return } configDir = os.Getenv("DOCKER_CONFIG") if configDir == "" { - configDir = filepath.Join(homedir.Get(), configFileDir) + configDir = filepath.Join(getHomeDir(), configFileDir) } } @@ -109,11 +123,7 @@ func Load(configDir string) (*configfile.ConfigFile, error) { } // Can't find latest config file so check for the old one - home, err := os.UserHomeDir() - if err != nil { - return configFile, errors.Wrap(err, oldConfigfile) - } - filename = filepath.Join(home, oldConfigfile) + filename = filepath.Join(getHomeDir(), oldConfigfile) if file, err := os.Open(filename); err == nil { defer file.Close() if err := configFile.LegacyLoadFromReader(file); err != nil { diff --git a/cli/config/config_test.go b/cli/config/config_test.go index a861d4da1e..c7787bb5d1 100644 --- a/cli/config/config_test.go +++ b/cli/config/config_test.go @@ -115,6 +115,7 @@ password`: "Invalid Auth config file", email`: "Invalid auth configuration file", } + resetHomeDir() tmpHome, err := ioutil.TempDir("", "config-test") assert.NilError(t, err) defer os.RemoveAll(tmpHome) @@ -131,6 +132,7 @@ email`: "Invalid auth configuration file", } func TestOldValidAuth(t *testing.T) { + resetHomeDir() tmpHome, err := ioutil.TempDir("", "config-test") assert.NilError(t, err) defer os.RemoveAll(tmpHome) @@ -165,6 +167,7 @@ func TestOldValidAuth(t *testing.T) { } func TestOldJSONInvalid(t *testing.T) { + resetHomeDir() tmpHome, err := ioutil.TempDir("", "config-test") assert.NilError(t, err) defer os.RemoveAll(tmpHome) @@ -184,6 +187,7 @@ func TestOldJSONInvalid(t *testing.T) { } func TestOldJSON(t *testing.T) { + resetHomeDir() tmpHome, err := ioutil.TempDir("", "config-test") assert.NilError(t, err) defer os.RemoveAll(tmpHome)