From c2c7503d498db3fba5054430fe4589a0bc0b7739 Mon Sep 17 00:00:00 2001 From: Aleksander Piotrowski Date: Fri, 21 Jun 2019 22:11:48 +0200 Subject: [PATCH] Convert ports before parsing. Refactor code to allow mixed notation with -p flag. Signed-off-by: Aleksander Piotrowski Signed-off-by: Sebastiaan van Stijn --- cli/command/container/opts.go | 56 ++++++++++++++++-------------- cli/command/container/opts_test.go | 33 +++++++++++++++--- 2 files changed, 58 insertions(+), 31 deletions(-) diff --git a/cli/command/container/opts.go b/cli/command/container/opts.go index 5ae7e8cc01..c03c0ae698 100644 --- a/cli/command/container/opts.go +++ b/cli/command/container/opts.go @@ -379,23 +379,20 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con } publishOpts := copts.publish.GetAll() - var ports map[nat.Port]struct{} - var portBindings map[nat.Port][]nat.PortBinding + var ( + ports map[nat.Port]struct{} + portBindings map[nat.Port][]nat.PortBinding + convertedOpts []string + ) - ports, portBindings, err = nat.ParsePortSpecs(publishOpts) - - // If simple port parsing fails try to parse as long format + convertedOpts, err = convertToStandardNotation(publishOpts) if err != nil { - publishOpts, err = parsePortOpts(publishOpts) - if err != nil { - return nil, err - } + return nil, err + } - ports, portBindings, err = nat.ParsePortSpecs(publishOpts) - - if err != nil { - return nil, err - } + ports, portBindings, err = nat.ParsePortSpecs(convertedOpts) + if err != nil { + return nil, err } // Merge in exposed ports to the map of published ports @@ -403,10 +400,11 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con if strings.Contains(e, ":") { return nil, errors.Errorf("invalid port format for --expose: %s", e) } - //support two formats for expose, original format /[] or /[] + // support two formats for expose, original format /[] + // or /[] proto, port := nat.SplitProtoPort(e) - //parse the start and end port and create a sequence of ports to expose - //if expose a port, the start and end port are the same + // parse the start and end port and create a sequence of ports to expose + // if expose a port, the start and end port are the same start, end, err := nat.ParsePortRange(port) if err != nil { return nil, errors.Errorf("invalid range format for --expose: %s, error: %s", e, err) @@ -796,19 +794,23 @@ func parseNetworkAttachmentOpt(ep opts.NetworkAttachmentOpts) (*networktypes.End return epConfig, nil } -func parsePortOpts(publishOpts []string) ([]string, error) { +func convertToStandardNotation(ports []string) ([]string, error) { optsList := []string{} - for _, publish := range publishOpts { - params := map[string]string{"protocol": "tcp"} - for _, param := range strings.Split(publish, ",") { - opt := strings.Split(param, "=") - if len(opt) < 2 { - return optsList, errors.Errorf("invalid publish opts format (should be name=value but got '%s')", param) - } + for _, publish := range ports { + if strings.Contains(publish, "=") { + params := map[string]string{"protocol": "tcp"} + for _, param := range strings.Split(publish, ",") { + opt := strings.Split(param, "=") + if len(opt) < 2 { + return optsList, errors.Errorf("invalid publish opts format (should be name=value but got '%s')", param) + } - params[opt[0]] = opt[1] + params[opt[0]] = opt[1] + } + optsList = append(optsList, fmt.Sprintf("%s:%s/%s", params["published"], params["target"], params["protocol"])) + } else { + optsList = append(optsList, publish) } - optsList = append(optsList, fmt.Sprintf("%s:%s/%s", params["published"], params["target"], params["protocol"])) } return optsList, nil } diff --git a/cli/command/container/opts_test.go b/cli/command/container/opts_test.go index 46a4ba5192..c2776d1db0 100644 --- a/cli/command/container/opts_test.go +++ b/cli/command/container/opts_test.go @@ -873,8 +873,33 @@ func TestParseSystemPaths(t *testing.T) { } } -func TestParsePortOpts(t *testing.T) { - parsed, err := parsePortOpts([]string{"published=1500,target=200", "target=80,published=90"}) - assert.NilError(t, err) - assert.DeepEqual(t, []string{"1500:200/tcp", "90:80/tcp"}, parsed) +func TestConvertToStandardNotation(t *testing.T) { + valid := map[string][]string{ + "20:10/tcp": {"target=10,published=20"}, + "40:30": {"40:30"}, + "20:20 80:4444": {"20:20", "80:4444"}, + "1500:2500/tcp 1400:1300": {"target=2500,published=1500", "1400:1300"}, + "1500:200/tcp 90:80/tcp": {"published=1500,target=200", "target=80,published=90"}, + } + + invalid := [][]string{ + {"published=1500,target:444"}, + {"published=1500,444"}, + {"published=1500,target,444"}, + } + + for key, ports := range valid { + convertedPorts, err := convertToStandardNotation(ports) + + if err != nil { + assert.NilError(t, err) + } + assert.DeepEqual(t, strings.Split(key, " "), convertedPorts) + } + + for _, ports := range invalid { + if _, err := convertToStandardNotation(ports); err == nil { + t.Fatalf("ConvertToStandardNotation(`%q`) should have failed conversion", ports) + } + } }