diff --git a/cli/command/context/import.go b/cli/command/context/import.go index 27d5432e48..c09f8f89f5 100644 --- a/cli/command/context/import.go +++ b/cli/command/context/import.go @@ -14,7 +14,7 @@ import ( func newImportCommand(dockerCli command.Cli) *cobra.Command { cmd := &cobra.Command{ Use: "import CONTEXT FILE|-", - Short: "Import a context from a tar file", + Short: "Import a context from a tar or zip file", Args: cli.ExactArgs(2), RunE: func(cmd *cobra.Command, args []string) error { return RunImport(dockerCli, args[0], args[1]) @@ -28,6 +28,7 @@ func RunImport(dockerCli command.Cli, name string, source string) error { if err := checkContextNameForCreation(dockerCli.ContextStore(), name); err != nil { return err } + var reader io.Reader if source == "-" { reader = dockerCli.In() @@ -43,6 +44,7 @@ func RunImport(dockerCli command.Cli, name string, source string) error { if err := store.Import(name, dockerCli.ContextStore(), reader); err != nil { return err } + fmt.Fprintln(dockerCli.Out(), name) fmt.Fprintf(dockerCli.Err(), "Successfully imported context %q\n", name) return nil diff --git a/cli/context/store/io_utils.go b/cli/context/store/io_utils.go new file mode 100644 index 0000000000..6f854c8e57 --- /dev/null +++ b/cli/context/store/io_utils.go @@ -0,0 +1,29 @@ +package store + +import ( + "errors" + "io" +) + +// LimitedReader is a fork of io.LimitedReader to override Read. +type LimitedReader struct { + R io.Reader + N int64 // max bytes remaining +} + +// Read is a fork of io.LimitedReader.Read that returns an error when limit exceeded. +func (l *LimitedReader) Read(p []byte) (n int, err error) { + if l.N < 0 { + return 0, errors.New("read exceeds the defined limit") + } + if l.N == 0 { + return 0, io.EOF + } + // have to cap N + 1 otherwise we won't hit limit err + if int64(len(p)) > l.N+1 { + p = p[0 : l.N+1] + } + n, err = l.R.Read(p) + l.N -= int64(n) + return n, err +} diff --git a/cli/context/store/io_utils_test.go b/cli/context/store/io_utils_test.go new file mode 100644 index 0000000000..3840ae5f90 --- /dev/null +++ b/cli/context/store/io_utils_test.go @@ -0,0 +1,24 @@ +package store + +import ( + "io/ioutil" + "strings" + "testing" + + "gotest.tools/assert" +) + +func TestLimitReaderReadAll(t *testing.T) { + r := strings.NewReader("Reader") + + _, err := ioutil.ReadAll(r) + assert.NilError(t, err) + + r = strings.NewReader("Test") + _, err = ioutil.ReadAll(&LimitedReader{R: r, N: 4}) + assert.NilError(t, err) + + r = strings.NewReader("Test") + _, err = ioutil.ReadAll(&LimitedReader{R: r, N: 2}) + assert.Error(t, err, "read exceeds the defined limit") +} diff --git a/cli/context/store/store.go b/cli/context/store/store.go index 9e332e7c91..85be802c66 100644 --- a/cli/context/store/store.go +++ b/cli/context/store/store.go @@ -2,12 +2,16 @@ package store import ( "archive/tar" + "archive/zip" + "bufio" + "bytes" _ "crypto/sha256" // ensure ids can be computed "encoding/json" "errors" "fmt" "io" "io/ioutil" + "net/http" "path" "path/filepath" "strings" @@ -259,12 +263,44 @@ func Export(name string, s Reader) io.ReadCloser { return reader } +const ( + maxAllowedFileSizeToImport int64 = 10 << 20 + zipType string = "application/zip" +) + +func getImportContentType(r *bufio.Reader) (string, error) { + head, err := r.Peek(512) + if err != nil && err != io.EOF { + return "", err + } + + return http.DetectContentType(head), nil +} + // Import imports an exported context into a store func Import(name string, s Writer, reader io.Reader) error { - tr := tar.NewReader(reader) + // Buffered reader will not advance the buffer, needed to determine content type + r := bufio.NewReader(reader) + + importContentType, err := getImportContentType(r) + if err != nil { + return err + } + switch importContentType { + case zipType: + return importZip(name, s, r) + default: + // Assume it's a TAR (TAR does not have a "magic number") + return importTar(name, s, r) + } +} + +func importTar(name string, s Writer, reader io.Reader) error { + tr := tar.NewReader(&LimitedReader{R: reader, N: maxAllowedFileSizeToImport}) tlsData := ContextTLSData{ Endpoints: map[string]EndpointTLSData{}, } + for { hdr, err := tr.Next() if err == io.EOF { @@ -282,37 +318,112 @@ func Import(name string, s Writer, reader io.Reader) error { if err != nil { return err } - var meta Metadata - if err := json.Unmarshal(data, &meta); err != nil { + meta, err := parseMetadata(data, name) + if err != nil { return err } - meta.Name = name if err := s.CreateOrUpdate(meta); err != nil { return err } } else if strings.HasPrefix(hdr.Name, "tls/") { - relative := strings.TrimPrefix(hdr.Name, "tls/") - parts := strings.SplitN(relative, "/", 2) - if len(parts) != 2 { - return errors.New("archive format is invalid") - } - endpointName := parts[0] - fileName := parts[1] data, err := ioutil.ReadAll(tr) if err != nil { return err } - if _, ok := tlsData.Endpoints[endpointName]; !ok { - tlsData.Endpoints[endpointName] = EndpointTLSData{ - Files: map[string][]byte{}, - } + if err := importEndpointTLS(&tlsData, hdr.Name, data); err != nil { + return err } - tlsData.Endpoints[endpointName].Files[fileName] = data } } + return s.ResetTLSMaterial(name, &tlsData) } +func importZip(name string, s Writer, reader io.Reader) error { + body, err := ioutil.ReadAll(&LimitedReader{R: reader, N: maxAllowedFileSizeToImport}) + if err != nil { + return err + } + zr, err := zip.NewReader(bytes.NewReader(body), int64(len(body))) + if err != nil { + return err + } + tlsData := ContextTLSData{ + Endpoints: map[string]EndpointTLSData{}, + } + + for _, zf := range zr.File { + fi := zf.FileInfo() + if fi.IsDir() { + // skip this entry, only taking files into account + continue + } + if zf.Name == metaFile { + f, err := zf.Open() + if err != nil { + return err + } + + data, err := ioutil.ReadAll(&LimitedReader{R: f, N: maxAllowedFileSizeToImport}) + defer f.Close() + if err != nil { + return err + } + meta, err := parseMetadata(data, name) + if err != nil { + return err + } + if err := s.CreateOrUpdate(meta); err != nil { + return err + } + } else if strings.HasPrefix(zf.Name, "tls/") { + f, err := zf.Open() + if err != nil { + return err + } + data, err := ioutil.ReadAll(f) + defer f.Close() + if err != nil { + return err + } + err = importEndpointTLS(&tlsData, zf.Name, data) + if err != nil { + return err + } + } + } + + return s.ResetTLSMaterial(name, &tlsData) +} + +func parseMetadata(data []byte, name string) (Metadata, error) { + var meta Metadata + if err := json.Unmarshal(data, &meta); err != nil { + return meta, err + } + meta.Name = name + return meta, nil +} + +func importEndpointTLS(tlsData *ContextTLSData, path string, data []byte) error { + parts := strings.SplitN(strings.TrimPrefix(path, "tls/"), "/", 2) + if len(parts) != 2 { + // TLS endpoints require archived file directory with 2 layers + // i.e. tls/{endpointName}/{fileName} + return errors.New("archive format is invalid") + } + + epName := parts[0] + fileName := parts[1] + if _, ok := tlsData.Endpoints[epName]; !ok { + tlsData.Endpoints[epName] = EndpointTLSData{ + Files: map[string][]byte{}, + } + } + tlsData.Endpoints[epName].Files[fileName] = data + return nil +} + type setContextName interface { setContext(name string) } diff --git a/cli/context/store/store_test.go b/cli/context/store/store_test.go index 6ad11e2398..b1d0fec7fa 100644 --- a/cli/context/store/store_test.go +++ b/cli/context/store/store_test.go @@ -1,9 +1,15 @@ package store import ( + "archive/zip" + "bufio" + "bytes" "crypto/rand" + "encoding/json" + "io" "io/ioutil" "os" + "path" "testing" "gotest.tools/assert" @@ -125,3 +131,66 @@ func TestErrHasCorrectContext(t *testing.T) { assert.ErrorContains(t, err, "no-exists") assert.Check(t, IsErrContextDoesNotExist(err)) } + +func TestDetectImportContentType(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + + buf := new(bytes.Buffer) + r := bufio.NewReader(buf) + ct, err := getImportContentType(r) + assert.NilError(t, err) + assert.Assert(t, zipType != ct) +} + +func TestImportZip(t *testing.T) { + testDir, err := ioutil.TempDir("", t.Name()) + assert.NilError(t, err) + defer os.RemoveAll(testDir) + + zf := path.Join(testDir, "test.zip") + + f, err := os.Create(zf) + defer f.Close() + assert.NilError(t, err) + w := zip.NewWriter(f) + + meta, err := json.Marshal(Metadata{ + Endpoints: map[string]interface{}{ + "ep1": endpoint{Foo: "bar"}, + }, + Metadata: context{Bar: "baz"}, + Name: "source", + }) + assert.NilError(t, err) + var files = []struct { + Name, Body string + }{ + {"meta.json", string(meta)}, + {path.Join("tls", "docker", "ca.pem"), string([]byte("ca.pem"))}, + } + + for _, file := range files { + f, err := w.Create(file.Name) + assert.NilError(t, err) + _, err = f.Write([]byte(file.Body)) + assert.NilError(t, err) + } + + err = w.Close() + assert.NilError(t, err) + + source, err := os.Open(zf) + assert.NilError(t, err) + ct, err := getImportContentType(bufio.NewReader(source)) + assert.NilError(t, err) + assert.Equal(t, zipType, ct) + + source, _ = os.Open(zf) + defer source.Close() + var r io.Reader = source + s := New(testDir, testCfg) + err = Import("zipTest", s, r) + assert.NilError(t, err) +}