cp: allow trailing slash in non-existant destination

Signed-off-by: Tibor Vass <tibor@docker.com>
This commit is contained in:
Tibor Vass 2019-10-30 19:07:18 +00:00
parent 37f9a88c69
commit 26dbc3226c
2 changed files with 41 additions and 1 deletions

View File

@ -130,7 +130,7 @@ func AddPlatformFlag(flags *pflag.FlagSet, target *string) {
// ValidateOutputPath validates the output paths of the `export` and `save` commands. // ValidateOutputPath validates the output paths of the `export` and `save` commands.
func ValidateOutputPath(path string) error { func ValidateOutputPath(path string) error {
dir := filepath.Dir(path) dir := filepath.Dir(filepath.Clean(path))
if dir != "" && dir != "." { if dir != "" && dir != "." {
if _, err := os.Stat(dir); os.IsNotExist(err) { if _, err := os.Stat(dir); os.IsNotExist(err) {
return errors.Errorf("invalid output path: directory %q does not exist", dir) return errors.Errorf("invalid output path: directory %q does not exist", dir)

View File

@ -1,8 +1,12 @@
package command package command
import ( import (
"io/ioutil"
"os"
"path/filepath"
"testing" "testing"
"github.com/pkg/errors"
"gotest.tools/assert" "gotest.tools/assert"
) )
@ -31,3 +35,39 @@ func TestStringSliceReplaceAt(t *testing.T) {
assert.Assert(t, !ok) assert.Assert(t, !ok)
assert.DeepEqual(t, []string{"foo"}, out) assert.DeepEqual(t, []string{"foo"}, out)
} }
func TestValidateOutputPath(t *testing.T) {
basedir, err := ioutil.TempDir("", "TestValidateOutputPath")
assert.NilError(t, err)
defer os.RemoveAll(basedir)
dir := filepath.Join(basedir, "dir")
notexist := filepath.Join(basedir, "notexist")
err = os.MkdirAll(dir, 0755)
assert.NilError(t, err)
file := filepath.Join(dir, "file")
err = ioutil.WriteFile(file, []byte("hi"), 0644)
assert.NilError(t, err)
var testcases = []struct {
path string
err error
}{
{basedir, nil},
{file, nil},
{dir, nil},
{dir + string(os.PathSeparator), nil},
{notexist, nil},
{notexist + string(os.PathSeparator), nil},
{filepath.Join(notexist, "file"), errors.New("does not exist")},
}
for _, testcase := range testcases {
t.Run(testcase.path, func(t *testing.T) {
err := ValidateOutputPath(testcase.path)
if testcase.err == nil {
assert.NilError(t, err)
} else {
assert.ErrorContains(t, err, testcase.err.Error())
}
})
}
}