diff --git a/cli/compose/interpolation/interpolation.go b/cli/compose/interpolation/interpolation.go index 38673c5c00..a1e7719860 100644 --- a/cli/compose/interpolation/interpolation.go +++ b/cli/compose/interpolation/interpolation.go @@ -1,75 +1,71 @@ package interpolation import ( + "os" + "strings" + "github.com/docker/cli/cli/compose/template" "github.com/pkg/errors" ) -// Interpolate replaces variables in a string with the values from a mapping -func Interpolate(config map[string]interface{}, section string, mapping template.Mapping) (map[string]interface{}, error) { - out := map[string]interface{}{} - - for name, item := range config { - if item == nil { - out[name] = nil - continue - } - mapItem, ok := item.(map[string]interface{}) - if !ok { - return nil, errors.Errorf("Invalid type for %s : %T instead of %T", name, item, out) - } - interpolatedItem, err := interpolateSectionItem(name, mapItem, section, mapping) - if err != nil { - return nil, err - } - out[name] = interpolatedItem - } - - return out, nil +// Options supported by Interpolate +type Options struct { + // LookupValue from a key + LookupValue LookupValue + // TypeCastMapping maps key paths to functions to cast to a type + TypeCastMapping map[Path]Cast } -func interpolateSectionItem( - name string, - item map[string]interface{}, - section string, - mapping template.Mapping, -) (map[string]interface{}, error) { +// LookupValue is a function which maps from variable names to values. +// Returns the value as a string and a bool indicating whether +// the value is present, to distinguish between an empty string +// and the absence of a value. +type LookupValue func(key string) (string, bool) + +// Cast a value to a new type, or return an error if the value can't be cast +type Cast func(value string) (interface{}, error) + +// Interpolate replaces variables in a string with the values from a mapping +func Interpolate(config map[string]interface{}, opts Options) (map[string]interface{}, error) { + if opts.LookupValue == nil { + opts.LookupValue = os.LookupEnv + } + if opts.TypeCastMapping == nil { + opts.TypeCastMapping = make(map[Path]Cast) + } out := map[string]interface{}{} - for key, value := range item { - interpolatedValue, err := recursiveInterpolate(value, mapping) - switch err := err.(type) { - case nil: - case *template.InvalidTemplateError: - return nil, errors.Errorf( - "Invalid interpolation format for %#v option in %s %#v: %#v. You may need to escape any $ with another $.", - key, section, name, err.Template, - ) - default: - return nil, errors.Wrapf(err, "error while interpolating %s in %s %s", key, section, name) + for key, value := range config { + interpolatedValue, err := recursiveInterpolate(value, NewPath(key), opts) + if err != nil { + return out, err } out[key] = interpolatedValue } return out, nil - } -func recursiveInterpolate( - value interface{}, - mapping template.Mapping, -) (interface{}, error) { - +func recursiveInterpolate(value interface{}, path Path, opts Options) (interface{}, error) { switch value := value.(type) { case string: - return template.Substitute(value, mapping) + newValue, err := template.Substitute(value, template.Mapping(opts.LookupValue)) + if err != nil || newValue == value { + return value, newPathError(path, err) + } + caster, ok := opts.getCasterForPath(path) + if !ok { + return newValue, nil + } + casted, err := caster(newValue) + return casted, newPathError(path, errors.Wrap(err, "failed to cast to expected type")) case map[string]interface{}: out := map[string]interface{}{} for key, elem := range value { - interpolatedElem, err := recursiveInterpolate(elem, mapping) + interpolatedElem, err := recursiveInterpolate(elem, path.Next(key), opts) if err != nil { return nil, err } @@ -80,7 +76,7 @@ func recursiveInterpolate( case []interface{}: out := make([]interface{}, len(value)) for i, elem := range value { - interpolatedElem, err := recursiveInterpolate(elem, mapping) + interpolatedElem, err := recursiveInterpolate(elem, path.Next(PathMatchList), opts) if err != nil { return nil, err } @@ -92,5 +88,71 @@ func recursiveInterpolate( return value, nil } - +} + +func newPathError(path Path, err error) error { + switch err := err.(type) { + case nil: + return nil + case *template.InvalidTemplateError: + return errors.Errorf( + "invalid interpolation format for %s: %#v. You may need to escape any $ with another $.", + path, err.Template) + default: + return errors.Wrapf(err, "error while interpolating %s", path) + } +} + +const pathSeparator = "." + +// PathMatchAll is a token used as part of a Path to match any key at that level +// in the nested structure +const PathMatchAll = "*" + +// PathMatchList is a token used as part of a Path to match items in a list +const PathMatchList = "[]" + +// Path is a dotted path of keys to a value in a nested mapping structure. A * +// section in a path will match any key in the mapping structure. +type Path string + +// NewPath returns a new Path +func NewPath(items ...string) Path { + return Path(strings.Join(items, pathSeparator)) +} + +// Next returns a new path by append part to the current path +func (p Path) Next(part string) Path { + return Path(string(p) + pathSeparator + part) +} + +func (p Path) parts() []string { + return strings.Split(string(p), pathSeparator) +} + +func (p Path) matches(pattern Path) bool { + patternParts := pattern.parts() + parts := p.parts() + + if len(patternParts) != len(parts) { + return false + } + for index, part := range parts { + switch patternParts[index] { + case PathMatchAll, part: + continue + default: + return false + } + } + return true +} + +func (o Options) getCasterForPath(path Path) (Cast, bool) { + for pattern, caster := range o.TypeCastMapping { + if path.matches(pattern) { + return caster, true + } + } + return nil, false } diff --git a/cli/compose/interpolation/interpolation_test.go b/cli/compose/interpolation/interpolation_test.go index 9b055f4703..8f5d50db65 100644 --- a/cli/compose/interpolation/interpolation_test.go +++ b/cli/compose/interpolation/interpolation_test.go @@ -3,12 +3,16 @@ package interpolation import ( "testing" + "strconv" + + "github.com/gotestyourself/gotestyourself/env" "github.com/stretchr/testify/assert" ) var defaults = map[string]string{ - "USER": "jenny", - "FOO": "bar", + "USER": "jenny", + "FOO": "bar", + "count": "5", } func defaultMapping(name string) (string, bool) { @@ -41,7 +45,7 @@ func TestInterpolate(t *testing.T) { }, }, } - result, err := Interpolate(services, "service", defaultMapping) + result, err := Interpolate(services, Options{LookupValue: defaultMapping}) assert.NoError(t, err) assert.Equal(t, expected, result) } @@ -52,6 +56,91 @@ func TestInvalidInterpolation(t *testing.T) { "image": "${", }, } - _, err := Interpolate(services, "service", defaultMapping) - assert.EqualError(t, err, `Invalid interpolation format for "image" option in service "servicea": "${". You may need to escape any $ with another $.`) + _, err := Interpolate(services, Options{LookupValue: defaultMapping}) + assert.EqualError(t, err, `invalid interpolation format for servicea.image: "${". You may need to escape any $ with another $.`) +} + +func TestInterpolateWithDefaults(t *testing.T) { + defer env.Patch(t, "FOO", "BARZ")() + + config := map[string]interface{}{ + "networks": map[string]interface{}{ + "foo": "thing_${FOO}", + }, + } + expected := map[string]interface{}{ + "networks": map[string]interface{}{ + "foo": "thing_BARZ", + }, + } + result, err := Interpolate(config, Options{}) + assert.NoError(t, err) + assert.Equal(t, expected, result) +} + +func TestInterpolateWithCast(t *testing.T) { + config := map[string]interface{}{ + "foo": map[string]interface{}{ + "replicas": "$count", + }, + } + toInt := func(value string) (interface{}, error) { + return strconv.Atoi(value) + } + result, err := Interpolate(config, Options{ + LookupValue: defaultMapping, + TypeCastMapping: map[Path]Cast{NewPath(PathMatchAll, "replicas"): toInt}, + }) + assert.NoError(t, err) + expected := map[string]interface{}{ + "foo": map[string]interface{}{ + "replicas": 5, + }, + } + assert.Equal(t, expected, result) +} + +func TestPathMatches(t *testing.T) { + var testcases = []struct { + doc string + path Path + pattern Path + expected bool + }{ + { + doc: "pattern too short", + path: NewPath("one", "two", "three"), + pattern: NewPath("one", "two"), + }, + { + doc: "pattern too long", + path: NewPath("one", "two"), + pattern: NewPath("one", "two", "three"), + }, + { + doc: "pattern mismatch", + path: NewPath("one", "three", "two"), + pattern: NewPath("one", "two", "three"), + }, + { + doc: "pattern mismatch with match-all part", + path: NewPath("one", "three", "two"), + pattern: NewPath(PathMatchAll, "two", "three"), + }, + { + doc: "pattern match with match-all part", + path: NewPath("one", "two", "three"), + pattern: NewPath("one", "*", "three"), + expected: true, + }, + { + doc: "pattern match", + path: NewPath("one", "two", "three"), + pattern: NewPath("one", "two", "three"), + expected: true, + }, + } + for _, testcase := range testcases { + assert.Equal(t, testcase.expected, testcase.path.matches(testcase.pattern)) + } } diff --git a/cli/compose/loader/interpolate.go b/cli/compose/loader/interpolate.go new file mode 100644 index 0000000000..5c3e1b8b0c --- /dev/null +++ b/cli/compose/loader/interpolate.go @@ -0,0 +1,74 @@ +package loader + +import ( + "strconv" + "strings" + + interp "github.com/docker/cli/cli/compose/interpolation" + "github.com/pkg/errors" +) + +var interpolateTypeCastMapping = map[interp.Path]interp.Cast{ + servicePath("configs", interp.PathMatchList, "mode"): toInt, + servicePath("secrets", interp.PathMatchList, "mode"): toInt, + servicePath("healthcheck", "retries"): toInt, + servicePath("healthcheck", "disable"): toBoolean, + servicePath("deploy", "replicas"): toInt, + servicePath("deploy", "update_config", "parallelism"): toInt, + servicePath("deploy", "update_config", "max_failure_ratio"): toFloat, + servicePath("deploy", "restart_policy", "max_attempts"): toInt, + servicePath("ports", interp.PathMatchList, "target"): toInt, + servicePath("ports", interp.PathMatchList, "published"): toInt, + servicePath("ulimits", interp.PathMatchAll): toInt, + servicePath("ulimits", interp.PathMatchAll, "hard"): toInt, + servicePath("ulimits", interp.PathMatchAll, "soft"): toInt, + servicePath("privileged"): toBoolean, + servicePath("read_only"): toBoolean, + servicePath("stdin_open"): toBoolean, + servicePath("tty"): toBoolean, + servicePath("volumes", interp.PathMatchList, "read_only"): toBoolean, + servicePath("volumes", interp.PathMatchList, "volume", "nocopy"): toBoolean, + iPath("networks", interp.PathMatchAll, "external"): toBoolean, + iPath("networks", interp.PathMatchAll, "internal"): toBoolean, + iPath("networks", interp.PathMatchAll, "attachable"): toBoolean, + iPath("volumes", interp.PathMatchAll, "external"): toBoolean, + iPath("secrets", interp.PathMatchAll, "external"): toBoolean, + iPath("configs", interp.PathMatchAll, "external"): toBoolean, +} + +func iPath(parts ...string) interp.Path { + return interp.NewPath(parts...) +} + +func servicePath(parts ...string) interp.Path { + return iPath(append([]string{"services", interp.PathMatchAll}, parts...)...) +} + +func toInt(value string) (interface{}, error) { + return strconv.Atoi(value) +} + +func toFloat(value string) (interface{}, error) { + return strconv.ParseFloat(value, 64) +} + +// should match http://yaml.org/type/bool.html +func toBoolean(value string) (interface{}, error) { + switch strings.ToLower(value) { + case "y", "yes", "true", "on": + return true, nil + case "n", "no", "false", "off": + return false, nil + default: + return nil, errors.Errorf("invalid boolean: %s", value) + } +} + +func interpolateConfig(configDict map[string]interface{}, lookupEnv interp.LookupValue) (map[string]interface{}, error) { + return interp.Interpolate( + configDict, + interp.Options{ + LookupValue: lookupEnv, + TypeCastMapping: interpolateTypeCastMapping, + }) +} diff --git a/cli/compose/loader/loader.go b/cli/compose/loader/loader.go index 972395dd72..5ab95b21f6 100644 --- a/cli/compose/loader/loader.go +++ b/cli/compose/loader/loader.go @@ -8,7 +8,6 @@ import ( "sort" "strings" - "github.com/docker/cli/cli/compose/interpolation" "github.com/docker/cli/cli/compose/schema" "github.com/docker/cli/cli/compose/template" "github.com/docker/cli/cli/compose/types" @@ -51,67 +50,92 @@ func Load(configDetails types.ConfigDetails) (*types.Config, error) { configDict := getConfigDict(configDetails) - if services, ok := configDict["services"]; ok { - if servicesDict, ok := services.(map[string]interface{}); ok { - forbidden := getProperties(servicesDict, types.ForbiddenProperties) + if err := validateForbidden(configDict); err != nil { + return nil, err + } - if len(forbidden) > 0 { - return nil, &ForbiddenPropertiesError{Properties: forbidden} - } - } + var err error + configDict, err = interpolateConfig(configDict, configDetails.LookupEnv) + if err != nil { + return nil, err } if err := schema.Validate(configDict, schema.Version(configDict)); err != nil { return nil, err } - - cfg := types.Config{} - - config, err := interpolateConfig(configDict, configDetails.LookupEnv) - if err != nil { - return nil, err - } - - cfg.Services, err = LoadServices(config["services"], configDetails.WorkingDir, configDetails.LookupEnv) - if err != nil { - return nil, err - } - - cfg.Networks, err = LoadNetworks(config["networks"]) - if err != nil { - return nil, err - } - - cfg.Volumes, err = LoadVolumes(config["volumes"]) - if err != nil { - return nil, err - } - - cfg.Secrets, err = LoadSecrets(config["secrets"], configDetails.WorkingDir) - if err != nil { - return nil, err - } - - cfg.Configs, err = LoadConfigObjs(config["configs"], configDetails.WorkingDir) - return &cfg, err + return loadSections(configDict, configDetails) } -func interpolateConfig(configDict map[string]interface{}, lookupEnv template.Mapping) (map[string]map[string]interface{}, error) { - config := make(map[string]map[string]interface{}) +func validateForbidden(configDict map[string]interface{}) error { + servicesDict, ok := configDict["services"].(map[string]interface{}) + if !ok { + return nil + } + forbidden := getProperties(servicesDict, types.ForbiddenProperties) + if len(forbidden) > 0 { + return &ForbiddenPropertiesError{Properties: forbidden} + } + return nil +} - for _, key := range []string{"services", "networks", "volumes", "secrets", "configs"} { - section, ok := configDict[key] - if !ok { - config[key] = make(map[string]interface{}) - continue - } - var err error - config[key], err = interpolation.Interpolate(section.(map[string]interface{}), key, lookupEnv) - if err != nil { +func loadSections(config map[string]interface{}, configDetails types.ConfigDetails) (*types.Config, error) { + var err error + cfg := types.Config{} + + var loaders = []struct { + key string + fnc func(config map[string]interface{}) error + }{ + { + key: "services", + fnc: func(config map[string]interface{}) error { + cfg.Services, err = LoadServices(config, configDetails.WorkingDir, configDetails.LookupEnv) + return err + }, + }, + { + key: "networks", + fnc: func(config map[string]interface{}) error { + cfg.Networks, err = LoadNetworks(config) + return err + }, + }, + { + key: "volumes", + fnc: func(config map[string]interface{}) error { + cfg.Volumes, err = LoadVolumes(config) + return err + }, + }, + { + key: "secrets", + fnc: func(config map[string]interface{}) error { + cfg.Secrets, err = LoadSecrets(config, configDetails.WorkingDir) + return err + }, + }, + { + key: "configs", + fnc: func(config map[string]interface{}) error { + cfg.Configs, err = LoadConfigObjs(config, configDetails.WorkingDir) + return err + }, + }, + } + for _, loader := range loaders { + if err := loader.fnc(getSection(config, loader.key)); err != nil { return nil, err } } - return config, nil + return &cfg, nil +} + +func getSection(config map[string]interface{}, key string) map[string]interface{} { + section, ok := config[key] + if !ok { + return make(map[string]interface{}) + } + return section.(map[string]interface{}) } // GetUnsupportedProperties returns the list of any unsupported properties that are diff --git a/cli/compose/loader/loader_test.go b/cli/compose/loader/loader_test.go index 863714dff1..f7e7ff3026 100644 --- a/cli/compose/loader/loader_test.go +++ b/cli/compose/loader/loader_test.go @@ -465,7 +465,7 @@ services: assert.Contains(t, err.Error(), "services.dict-env.environment must be a mapping") } -func TestEnvironmentInterpolation(t *testing.T) { +func TestLoadWithEnvironmentInterpolation(t *testing.T) { home := "/home/foo" config, err := loadYAMLWithEnv(` version: "3" @@ -502,19 +502,161 @@ volumes: assert.Equal(t, home, config.Volumes["test"].Driver) } +func TestLoadWithInterpolationCastFull(t *testing.T) { + dict, err := ParseYAML([]byte(` +version: "3.4" +services: + web: + configs: + - source: appconfig + mode: $theint + secrets: + - source: super + mode: $theint + healthcheck: + retries: ${theint} + disable: $thebool + deploy: + replicas: $theint + update_config: + parallelism: $theint + max_failure_ratio: $thefloat + restart_policy: + max_attempts: $theint + ports: + - $theint + - "34567" + - target: $theint + published: $theint + ulimits: + nproc: $theint + nofile: + hard: $theint + soft: $theint + privileged: $thebool + read_only: $thebool + stdin_open: ${thebool} + tty: $thebool + volumes: + - source: data + type: volume + read_only: $thebool + volume: + nocopy: $thebool + +configs: + appconfig: + external: $thebool +secrets: + super: + external: $thebool +volumes: + data: + external: $thebool +networks: + front: + external: $thebool + internal: $thebool + attachable: $thebool + +`)) + require.NoError(t, err) + env := map[string]string{ + "theint": "555", + "thefloat": "3.14", + "thebool": "true", + } + + config, err := Load(buildConfigDetails(dict, env)) + require.NoError(t, err) + expected := &types.Config{ + Services: []types.ServiceConfig{ + { + Name: "web", + Configs: []types.ServiceConfigObjConfig{ + { + Source: "appconfig", + Mode: uint32Ptr(555), + }, + }, + Secrets: []types.ServiceSecretConfig{ + { + Source: "super", + Mode: uint32Ptr(555), + }, + }, + HealthCheck: &types.HealthCheckConfig{ + Retries: uint64Ptr(555), + Disable: true, + }, + Deploy: types.DeployConfig{ + Replicas: uint64Ptr(555), + UpdateConfig: &types.UpdateConfig{ + Parallelism: uint64Ptr(555), + MaxFailureRatio: 3.14, + }, + RestartPolicy: &types.RestartPolicy{ + MaxAttempts: uint64Ptr(555), + }, + }, + Ports: []types.ServicePortConfig{ + {Target: 555, Mode: "ingress", Protocol: "tcp"}, + {Target: 34567, Mode: "ingress", Protocol: "tcp"}, + {Target: 555, Published: 555}, + }, + Ulimits: map[string]*types.UlimitsConfig{ + "nproc": {Single: 555}, + "nofile": {Hard: 555, Soft: 555}, + }, + Privileged: true, + ReadOnly: true, + StdinOpen: true, + Tty: true, + Volumes: []types.ServiceVolumeConfig{ + { + Source: "data", + Type: "volume", + ReadOnly: true, + Volume: &types.ServiceVolumeVolume{NoCopy: true}, + }, + }, + Environment: types.MappingWithEquals{}, + }, + }, + Configs: map[string]types.ConfigObjConfig{ + "appconfig": {External: types.External{External: true, Name: "appconfig"}}, + }, + Secrets: map[string]types.SecretConfig{ + "super": {External: types.External{External: true, Name: "super"}}, + }, + Volumes: map[string]types.VolumeConfig{ + "data": {External: types.External{External: true, Name: "data"}}, + }, + Networks: map[string]types.NetworkConfig{ + "front": { + External: types.External{External: true, Name: "front"}, + Internal: true, + Attachable: true, + }, + }, + } + + assert.Equal(t, expected, config) +} + func TestUnsupportedProperties(t *testing.T) { dict, err := ParseYAML([]byte(` version: "3" services: web: image: web - build: + build: context: ./web links: - bar db: image: db - build: + build: context: ./db `)) require.NoError(t, err) @@ -679,6 +821,10 @@ func uint64Ptr(value uint64) *uint64 { return &value } +func uint32Ptr(value uint32) *uint32 { + return &value +} + func TestFullExample(t *testing.T) { bytes, err := ioutil.ReadFile("full-example.yml") require.NoError(t, err) diff --git a/vendor/github.com/gotestyourself/gotestyourself/env/env.go b/vendor/github.com/gotestyourself/gotestyourself/env/env.go new file mode 100644 index 0000000000..61a45a438d --- /dev/null +++ b/vendor/github.com/gotestyourself/gotestyourself/env/env.go @@ -0,0 +1,58 @@ +/*Package env provides functions to test code that read environment variables + */ +package env + +import ( + "os" + "strings" + + "github.com/stretchr/testify/require" +) + +// Patch changes the value of an environment variable, and returns a +// function which will reset the the value of that variable back to the +// previous state. +func Patch(t require.TestingT, key, value string) func() { + oldValue, ok := os.LookupEnv(key) + require.NoError(t, os.Setenv(key, value)) + return func() { + if !ok { + require.NoError(t, os.Unsetenv(key)) + return + } + require.NoError(t, os.Setenv(key, oldValue)) + } +} + +// PatchAll sets the environment to env, and returns a function which will +// reset the environment back to the previous state. +func PatchAll(t require.TestingT, env map[string]string) func() { + oldEnv := os.Environ() + os.Clearenv() + + for key, value := range env { + require.NoError(t, os.Setenv(key, value)) + } + return func() { + os.Clearenv() + for key, oldVal := range ToMap(oldEnv) { + require.NoError(t, os.Setenv(key, oldVal)) + } + } +} + +// ToMap takes a list of strings in the format returned by os.Environ() and +// returns a mapping of keys to values. +func ToMap(env []string) map[string]string { + result := map[string]string{} + for _, raw := range env { + parts := strings.SplitN(raw, "=", 2) + switch len(parts) { + case 1: + result[raw] = "" + case 2: + result[parts[0]] = parts[1] + } + } + return result +}