diff --git a/cli/compose/interpolation/interpolation.go b/cli/compose/interpolation/interpolation.go index 4cb65d706d..f05c5602cc 100644 --- a/cli/compose/interpolation/interpolation.go +++ b/cli/compose/interpolation/interpolation.go @@ -3,6 +3,8 @@ package interpolation import ( "os" + "strings" + "github.com/docker/cli/cli/compose/template" "github.com/pkg/errors" ) @@ -13,6 +15,8 @@ type Options struct { SectionName string // LookupValue from a key LookupValue LookupValue + // TypeCastMapping maps key paths to functions to cast to a type + TypeCastMapping map[Path]Cast } // LookupValue is a function which maps from variable names to values. @@ -21,6 +25,9 @@ type Options struct { // and the absence of a value. type LookupValue func(key string) (string, bool) +// Cast a value to a new type, or return an error if the value can't be cast +type Cast func(value string) (interface{}, error) + // Interpolate replaces variables in a string with the values from a mapping func Interpolate(config map[string]interface{}, opts Options) (map[string]interface{}, error) { out := map[string]interface{}{} @@ -28,6 +35,9 @@ func Interpolate(config map[string]interface{}, opts Options) (map[string]interf if opts.LookupValue == nil { opts.LookupValue = os.LookupEnv } + if opts.TypeCastMapping == nil { + opts.TypeCastMapping = make(map[Path]Cast) + } for key, item := range config { if item == nil { @@ -38,7 +48,7 @@ func Interpolate(config map[string]interface{}, opts Options) (map[string]interf if !ok { return nil, errors.Errorf("Invalid type for %s : %T instead of %T", key, item, out) } - interpolatedItem, err := interpolateSectionItem(key, mapItem, opts) + interpolatedItem, err := interpolateSectionItem(NewPath(key), mapItem, opts) if err != nil { return nil, err } @@ -49,23 +59,23 @@ func Interpolate(config map[string]interface{}, opts Options) (map[string]interf } func interpolateSectionItem( - sectionkey string, + path Path, item map[string]interface{}, opts Options, ) (map[string]interface{}, error) { out := map[string]interface{}{} for key, value := range item { - interpolatedValue, err := recursiveInterpolate(value, opts) + interpolatedValue, err := recursiveInterpolate(value, path.Next(key), opts) switch err := err.(type) { case nil: case *template.InvalidTemplateError: return nil, errors.Errorf( "Invalid interpolation format for %#v option in %s %#v: %#v. You may need to escape any $ with another $.", - key, opts.SectionName, sectionkey, err.Template, + key, opts.SectionName, path.root(), err.Template, ) default: - return nil, errors.Wrapf(err, "error while interpolating %s in %s %s", key, opts.SectionName, sectionkey) + return nil, errors.Wrapf(err, "error while interpolating %s in %s %s", key, opts.SectionName, path.root()) } out[key] = interpolatedValue } @@ -73,16 +83,24 @@ func interpolateSectionItem( return out, nil } -func recursiveInterpolate(value interface{}, opts Options) (interface{}, error) { +func recursiveInterpolate(value interface{}, path Path, opts Options) (interface{}, error) { switch value := value.(type) { case string: - return template.Substitute(value, template.Mapping(opts.LookupValue)) + newValue, err := template.Substitute(value, template.Mapping(opts.LookupValue)) + if err != nil || newValue == value { + return value, err + } + caster, ok := opts.getCasterForPath(path) + if !ok { + return newValue, nil + } + return caster(newValue) case map[string]interface{}: out := map[string]interface{}{} for key, elem := range value { - interpolatedElem, err := recursiveInterpolate(elem, opts) + interpolatedElem, err := recursiveInterpolate(elem, path.Next(key), opts) if err != nil { return nil, err } @@ -93,7 +111,7 @@ func recursiveInterpolate(value interface{}, opts Options) (interface{}, error) case []interface{}: out := make([]interface{}, len(value)) for i, elem := range value { - interpolatedElem, err := recursiveInterpolate(elem, opts) + interpolatedElem, err := recursiveInterpolate(elem, path, opts) if err != nil { return nil, err } @@ -106,3 +124,62 @@ func recursiveInterpolate(value interface{}, opts Options) (interface{}, error) } } + +const pathSeparator = "." + +// PathMatchAll is a token used as part of a Path to match any key at that level +// in the nested structure +const PathMatchAll = "*" + +// Path is a dotted path of keys to a value in a nested mapping structure. A * +// section in a path will match any key in the mapping structure. +type Path string + +// NewPath returns a new Path +func NewPath(items ...string) Path { + return Path(strings.Join(items, pathSeparator)) +} + +// Next returns a new path by append part to the current path +func (p Path) Next(part string) Path { + return Path(string(p) + pathSeparator + part) +} + +func (p Path) root() string { + parts := p.parts() + if len(parts) == 0 { + return "" + } + return parts[0] +} + +func (p Path) parts() []string { + return strings.Split(string(p), pathSeparator) +} + +func (p Path) matches(pattern Path) bool { + patternParts := pattern.parts() + parts := p.parts() + + if len(patternParts) != len(parts) { + return false + } + for index, part := range parts { + switch patternParts[index] { + case PathMatchAll, part: + continue + default: + return false + } + } + return true +} + +func (o Options) getCasterForPath(path Path) (Cast, bool) { + for pattern, caster := range o.TypeCastMapping { + if path.matches(pattern) { + return caster, true + } + } + return nil, false +} diff --git a/cli/compose/interpolation/interpolation_test.go b/cli/compose/interpolation/interpolation_test.go index 08de5c83e0..3248841026 100644 --- a/cli/compose/interpolation/interpolation_test.go +++ b/cli/compose/interpolation/interpolation_test.go @@ -3,13 +3,16 @@ package interpolation import ( "testing" + "strconv" + "github.com/gotestyourself/gotestyourself/env" "github.com/stretchr/testify/assert" ) var defaults = map[string]string{ - "USER": "jenny", - "FOO": "bar", + "USER": "jenny", + "FOO": "bar", + "count": "5", } func defaultMapping(name string) (string, bool) { @@ -80,3 +83,70 @@ func TestInterpolateWithDefaults(t *testing.T) { assert.NoError(t, err) assert.Equal(t, expected, result) } + +func TestInterpolateWithCast(t *testing.T) { + config := map[string]interface{}{ + "foo": map[string]interface{}{ + "replicas": "$count", + }, + } + toInt := func(value string) (interface{}, error) { + return strconv.Atoi(value) + } + result, err := Interpolate(config, Options{ + LookupValue: defaultMapping, + TypeCastMapping: map[Path]Cast{NewPath("*", "replicas"): toInt}, + }) + assert.NoError(t, err) + expected := map[string]interface{}{ + "foo": map[string]interface{}{ + "replicas": 5, + }, + } + assert.Equal(t, expected, result) +} + +func TestPathMatches(t *testing.T) { + var testcases = []struct { + doc string + path Path + pattern Path + expected bool + }{ + { + doc: "pattern too short", + path: NewPath("one", "two", "three"), + pattern: NewPath("one", "two"), + }, + { + doc: "pattern too long", + path: NewPath("one", "two"), + pattern: NewPath("one", "two", "three"), + }, + { + doc: "pattern mismatch", + path: NewPath("one", "three", "two"), + pattern: NewPath("one", "two", "three"), + }, + { + doc: "pattern mismatch with match-all part", + path: NewPath("one", "three", "two"), + pattern: NewPath(PathMatchAll, "two", "three"), + }, + { + doc: "pattern match with match-all part", + path: NewPath("one", "two", "three"), + pattern: NewPath("one", "*", "three"), + expected: true, + }, + { + doc: "pattern match", + path: NewPath("one", "two", "three"), + pattern: NewPath("one", "two", "three"), + expected: true, + }, + } + for _, testcase := range testcases { + assert.Equal(t, testcase.expected, testcase.path.matches(testcase.pattern)) + } +} diff --git a/cli/compose/loader/interpolate.go b/cli/compose/loader/interpolate.go new file mode 100644 index 0000000000..151d2034c3 --- /dev/null +++ b/cli/compose/loader/interpolate.go @@ -0,0 +1,95 @@ +package loader + +import ( + "strconv" + "strings" + + interp "github.com/docker/cli/cli/compose/interpolation" + "github.com/pkg/errors" +) + +var interpolateTypeCastMapping = map[string]map[interp.Path]interp.Cast{ + "services": { + iPath("configs", "mode"): toInt, + iPath("secrets", "mode"): toInt, + iPath("healthcheck", "retries"): toInt, + iPath("healthcheck", "disable"): toBoolean, + iPath("deploy", "replicas"): toInt, + iPath("deploy", "update_config", "parallelism:"): toInt, + iPath("deploy", "update_config", "max_failure_ratio"): toFloat, + iPath("deploy", "restart_policy", "max_attempts"): toInt, + iPath("ports", "target"): toInt, + iPath("ports", "published"): toInt, + iPath("ulimits", interp.PathMatchAll): toInt, + iPath("ulimits", interp.PathMatchAll, "hard"): toInt, + iPath("ulimits", interp.PathMatchAll, "soft"): toInt, + iPath("privileged"): toBoolean, + iPath("read_only"): toBoolean, + iPath("stdin_open"): toBoolean, + iPath("tty"): toBoolean, + iPath("volumes", "read_only"): toBoolean, + iPath("volumes", "volume", "nocopy"): toBoolean, + }, + "networks": { + iPath("external"): toBoolean, + iPath("internal"): toBoolean, + iPath("attachable"): toBoolean, + }, + "volumes": { + iPath("external"): toBoolean, + }, + "secrets": { + iPath("external"): toBoolean, + }, + "configs": { + iPath("external"): toBoolean, + }, +} + +func iPath(parts ...string) interp.Path { + return interp.NewPath(append([]string{interp.PathMatchAll}, parts...)...) +} + +func toInt(value string) (interface{}, error) { + return strconv.Atoi(value) +} + +func toFloat(value string) (interface{}, error) { + return strconv.ParseFloat(value, 64) +} + +// should match http://yaml.org/type/bool.html +func toBoolean(value string) (interface{}, error) { + switch strings.ToLower(value) { + case "y", "yes", "true", "on": + return true, nil + case "n", "no", "false", "off": + return false, nil + default: + return nil, errors.Errorf("invalid boolean: %s", value) + } +} + +func interpolateConfig(configDict map[string]interface{}, lookupEnv interp.LookupValue) (map[string]map[string]interface{}, error) { + config := make(map[string]map[string]interface{}) + + for _, key := range []string{"services", "networks", "volumes", "secrets", "configs"} { + section, ok := configDict[key] + if !ok { + config[key] = make(map[string]interface{}) + continue + } + var err error + config[key], err = interp.Interpolate( + section.(map[string]interface{}), + interp.Options{ + SectionName: key, + LookupValue: lookupEnv, + TypeCastMapping: interpolateTypeCastMapping[key], + }) + if err != nil { + return nil, err + } + } + return config, nil +} diff --git a/cli/compose/loader/loader.go b/cli/compose/loader/loader.go index 972395dd72..fe1c634347 100644 --- a/cli/compose/loader/loader.go +++ b/cli/compose/loader/loader.go @@ -8,7 +8,6 @@ import ( "sort" "strings" - "github.com/docker/cli/cli/compose/interpolation" "github.com/docker/cli/cli/compose/schema" "github.com/docker/cli/cli/compose/template" "github.com/docker/cli/cli/compose/types" @@ -51,14 +50,13 @@ func Load(configDetails types.ConfigDetails) (*types.Config, error) { configDict := getConfigDict(configDetails) - if services, ok := configDict["services"]; ok { - if servicesDict, ok := services.(map[string]interface{}); ok { - forbidden := getProperties(servicesDict, types.ForbiddenProperties) + if err := validateForbidden(configDict); err != nil { + return nil, err + } - if len(forbidden) > 0 { - return nil, &ForbiddenPropertiesError{Properties: forbidden} - } - } + config, err := interpolateConfig(configDict, configDetails.LookupEnv) + if err != nil { + return nil, err } if err := schema.Validate(configDict, schema.Version(configDict)); err != nil { @@ -66,12 +64,6 @@ func Load(configDetails types.ConfigDetails) (*types.Config, error) { } cfg := types.Config{} - - config, err := interpolateConfig(configDict, configDetails.LookupEnv) - if err != nil { - return nil, err - } - cfg.Services, err = LoadServices(config["services"], configDetails.WorkingDir, configDetails.LookupEnv) if err != nil { return nil, err @@ -96,22 +88,16 @@ func Load(configDetails types.ConfigDetails) (*types.Config, error) { return &cfg, err } -func interpolateConfig(configDict map[string]interface{}, lookupEnv template.Mapping) (map[string]map[string]interface{}, error) { - config := make(map[string]map[string]interface{}) - - for _, key := range []string{"services", "networks", "volumes", "secrets", "configs"} { - section, ok := configDict[key] - if !ok { - config[key] = make(map[string]interface{}) - continue - } - var err error - config[key], err = interpolation.Interpolate(section.(map[string]interface{}), key, lookupEnv) - if err != nil { - return nil, err - } +func validateForbidden(configDict map[string]interface{}) error { + servicesDict, ok := configDict["services"].(map[string]interface{}) + if !ok { + return nil } - return config, nil + forbidden := getProperties(servicesDict, types.ForbiddenProperties) + if len(forbidden) > 0 { + return &ForbiddenPropertiesError{Properties: forbidden} + } + return nil } // GetUnsupportedProperties returns the list of any unsupported properties that are diff --git a/cli/compose/loader/loader_test.go b/cli/compose/loader/loader_test.go index 863714dff1..0cb43ec816 100644 --- a/cli/compose/loader/loader_test.go +++ b/cli/compose/loader/loader_test.go @@ -465,7 +465,7 @@ services: assert.Contains(t, err.Error(), "services.dict-env.environment must be a mapping") } -func TestEnvironmentInterpolation(t *testing.T) { +func TestLoadWithEnvironmentInterpolation(t *testing.T) { home := "/home/foo" config, err := loadYAMLWithEnv(` version: "3" @@ -502,19 +502,88 @@ volumes: assert.Equal(t, home, config.Volumes["test"].Driver) } +func TestLoadWithInterpolationCastFull(t *testing.T) { + dict, err := ParseYAML([]byte(` +version: "3.4" +services: + web: + configs: + - source: appconfig + mode: $theint + secrets: + - source: super + mode: $theint + healthcheck: + retries: ${theint} + disable: $thebool + deploy: + replicas: $theint + update_config: + parallelism: $theint + max_failure_ratio: $thefloat + restart_policy: + max_attempts: $theint + ports: + - $theint + - "34567" + - target: $theint + published: $theint + ulimits: + - $theint + - hard: $theint + soft: $theint + privileged: $thebool + read_only: $thebool + stdin_open: ${thebool} + tty: $thebool + volumes: + - source: data + read_only: $thebool + volume: + nocopy: $thebool + +configs: + appconfig: + external: $thebool +secrets: + super: + external: $thebool +volumes: + data: + external: $thebool +networks: + front: + external: $thebool + internal: $thebool + attachable: $thebool + +`)) + require.NoError(t, err) + env := map[string]string{ + "theint": "555", + "thefloat": "3.14", + "thebool": "true", + } + + config, err := Load(buildConfigDetails(dict, env)) + require.NoError(t, err) + expected := &types.Config{} + assert.Equal(t, expected, config) +} + func TestUnsupportedProperties(t *testing.T) { dict, err := ParseYAML([]byte(` version: "3" services: web: image: web - build: + build: context: ./web links: - bar db: image: db - build: + build: context: ./db `)) require.NoError(t, err)