diff --git a/cli/compose/schema/schema.go b/cli/compose/schema/schema.go index b4861556a3..129e116d19 100644 --- a/cli/compose/schema/schema.go +++ b/cli/compose/schema/schema.go @@ -6,9 +6,11 @@ package schema import ( "embed" "fmt" + "math/big" "strings" "time" + "github.com/docker/go-connections/nat" "github.com/pkg/errors" "github.com/xeipuuv/gojsonschema" ) @@ -20,9 +22,18 @@ const ( type portsFormatChecker struct{} -func (checker portsFormatChecker) IsFormat(_ any) bool { - // TODO: implement this - return true +func (checker portsFormatChecker) IsFormat(input any) bool { + var portSpec string + + switch p := input.(type) { + case string: + portSpec = p + case *big.Rat: + portSpec = strings.Split(p.String(), "/")[0] + } + + _, err := nat.ParsePortSpec(portSpec) + return err == nil } type durationFormatChecker struct{} @@ -37,7 +48,6 @@ func (checker durationFormatChecker) IsFormat(input any) bool { } func init() { - gojsonschema.FormatCheckers.Add("expose", portsFormatChecker{}) gojsonschema.FormatCheckers.Add("ports", portsFormatChecker{}) gojsonschema.FormatCheckers.Add("duration", durationFormatChecker{}) } diff --git a/cli/compose/schema/schema_test.go b/cli/compose/schema/schema_test.go index 2f92111305..33a992a362 100644 --- a/cli/compose/schema/schema_test.go +++ b/cli/compose/schema/schema_test.go @@ -29,6 +29,111 @@ func TestValidate(t *testing.T) { assert.ErrorContains(t, Validate(config, "12345"), "unsupported Compose file version: 12345") } +func TestValidatePorts(t *testing.T) { + testcases := []struct { + ports any + hasError bool + }{ + { + ports: []int{8000}, + hasError: false, + }, + { + ports: []string{"8000:8000"}, + hasError: false, + }, + { + ports: []string{"8001-8005"}, + hasError: false, + }, + { + ports: []string{"8001-8005:8001-8005"}, + hasError: false, + }, + { + ports: []string{"8000"}, + hasError: false, + }, + { + ports: []string{"8000-9000:80"}, + hasError: false, + }, + { + ports: []string{"[::1]:8080:8000"}, + hasError: false, + }, + { + ports: []string{"[::1]:8080-8085:8000"}, + hasError: false, + }, + { + ports: []string{"127.0.0.1:8000:8000"}, + hasError: false, + }, + { + ports: []string{"127.0.0.1:8000-8005:8000-8005"}, + hasError: false, + }, + { + ports: []string{"127.0.0.1:8000:8000/udp"}, + hasError: false, + }, + { + ports: []string{"8000:8000/udp"}, + hasError: false, + }, + { + ports: []string{"8000:8000/http"}, + hasError: true, + }, + { + ports: []string{"-1"}, + hasError: true, + }, + { + ports: []string{"65536"}, + hasError: true, + }, + { + ports: []string{"-1:65536/http"}, + hasError: true, + }, + { + ports: []string{"invalid"}, + hasError: true, + }, + { + ports: []string{"12345678:8000:8000/tcp"}, + hasError: true, + }, + { + ports: []string{"8005-8000:8005-8000"}, + hasError: true, + }, + { + ports: []string{"8006-8000:8005-8000"}, + hasError: true, + }, + } + + for _, tc := range testcases { + config := dict{ + "version": "3.0", + "services": dict{ + "foo": dict{ + "image": "busybox", + "ports": tc.ports, + }, + }, + } + if tc.hasError { + assert.ErrorContains(t, Validate(config, "3"), "services.foo.ports.0 Does not match format 'ports'") + } else { + assert.NilError(t, Validate(config, "3")) + } + } +} + func TestValidateUndefinedTopLevelOption(t *testing.T) { config := dict{ "version": "3.0",