From 2fa69baf8d96864f6c4f856babd4e50ec2e36a78 Mon Sep 17 00:00:00 2001 From: Tibor Vass Date: Wed, 30 Oct 2019 19:07:18 +0000 Subject: [PATCH] cp: allow trailing slash in non-existant destination Signed-off-by: Tibor Vass (cherry picked from commit 26dbc3226c7044063bc1ddc0a0fd73026b6644e6) Signed-off-by: Tibor Vass --- cli/command/utils.go | 2 +- cli/command/utils_test.go | 40 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/cli/command/utils.go b/cli/command/utils.go index 21e702eb3f..713c5a3263 100644 --- a/cli/command/utils.go +++ b/cli/command/utils.go @@ -130,7 +130,7 @@ func AddPlatformFlag(flags *pflag.FlagSet, target *string) { // ValidateOutputPath validates the output paths of the `export` and `save` commands. func ValidateOutputPath(path string) error { - dir := filepath.Dir(path) + dir := filepath.Dir(filepath.Clean(path)) if dir != "" && dir != "." { if _, err := os.Stat(dir); os.IsNotExist(err) { return errors.Errorf("invalid output path: directory %q does not exist", dir) diff --git a/cli/command/utils_test.go b/cli/command/utils_test.go index 0452a7c4cf..a4257200e1 100644 --- a/cli/command/utils_test.go +++ b/cli/command/utils_test.go @@ -1,8 +1,12 @@ package command import ( + "io/ioutil" + "os" + "path/filepath" "testing" + "github.com/pkg/errors" "gotest.tools/assert" ) @@ -31,3 +35,39 @@ func TestStringSliceReplaceAt(t *testing.T) { assert.Assert(t, !ok) assert.DeepEqual(t, []string{"foo"}, out) } + +func TestValidateOutputPath(t *testing.T) { + basedir, err := ioutil.TempDir("", "TestValidateOutputPath") + assert.NilError(t, err) + defer os.RemoveAll(basedir) + dir := filepath.Join(basedir, "dir") + notexist := filepath.Join(basedir, "notexist") + err = os.MkdirAll(dir, 0755) + assert.NilError(t, err) + file := filepath.Join(dir, "file") + err = ioutil.WriteFile(file, []byte("hi"), 0644) + assert.NilError(t, err) + var testcases = []struct { + path string + err error + }{ + {basedir, nil}, + {file, nil}, + {dir, nil}, + {dir + string(os.PathSeparator), nil}, + {notexist, nil}, + {notexist + string(os.PathSeparator), nil}, + {filepath.Join(notexist, "file"), errors.New("does not exist")}, + } + + for _, testcase := range testcases { + t.Run(testcase.path, func(t *testing.T) { + err := ValidateOutputPath(testcase.path) + if testcase.err == nil { + assert.NilError(t, err) + } else { + assert.ErrorContains(t, err, testcase.err.Error()) + } + }) + } +}