diff --git a/cli/compose/interpolation/interpolation.go b/cli/compose/interpolation/interpolation.go index a1e7719860..d4f5c4a43f 100644 --- a/cli/compose/interpolation/interpolation.go +++ b/cli/compose/interpolation/interpolation.go @@ -14,6 +14,8 @@ type Options struct { LookupValue LookupValue // TypeCastMapping maps key paths to functions to cast to a type TypeCastMapping map[Path]Cast + // Substitution function to use + Substitute func(string, template.Mapping) (string, error) } // LookupValue is a function which maps from variable names to values. @@ -33,6 +35,9 @@ func Interpolate(config map[string]interface{}, opts Options) (map[string]interf if opts.TypeCastMapping == nil { opts.TypeCastMapping = make(map[Path]Cast) } + if opts.Substitute == nil { + opts.Substitute = template.Substitute + } out := map[string]interface{}{} @@ -51,7 +56,7 @@ func recursiveInterpolate(value interface{}, path Path, opts Options) (interface switch value := value.(type) { case string: - newValue, err := template.Substitute(value, template.Mapping(opts.LookupValue)) + newValue, err := opts.Substitute(value, template.Mapping(opts.LookupValue)) if err != nil || newValue == value { return value, newPathError(path, err) } diff --git a/cli/compose/loader/interpolate.go b/cli/compose/loader/interpolate.go index 5c3e1b8b0c..888d29b58c 100644 --- a/cli/compose/loader/interpolate.go +++ b/cli/compose/loader/interpolate.go @@ -64,11 +64,6 @@ func toBoolean(value string) (interface{}, error) { } } -func interpolateConfig(configDict map[string]interface{}, lookupEnv interp.LookupValue) (map[string]interface{}, error) { - return interp.Interpolate( - configDict, - interp.Options{ - LookupValue: lookupEnv, - TypeCastMapping: interpolateTypeCastMapping, - }) +func interpolateConfig(configDict map[string]interface{}, opts interp.Options) (map[string]interface{}, error) { + return interp.Interpolate(configDict, opts) } diff --git a/cli/compose/loader/loader.go b/cli/compose/loader/loader.go index 3c138d59ec..f11619b985 100644 --- a/cli/compose/loader/loader.go +++ b/cli/compose/loader/loader.go @@ -8,6 +8,7 @@ import ( "sort" "strings" + interp "github.com/docker/cli/cli/compose/interpolation" "github.com/docker/cli/cli/compose/schema" "github.com/docker/cli/cli/compose/template" "github.com/docker/cli/cli/compose/types" @@ -22,6 +23,16 @@ import ( yaml "gopkg.in/yaml.v2" ) +// Options supported by Load +type Options struct { + // Skip schema validation + SkipValidation bool + // Skip interpolation + SkipInterpolation bool + // Interpolation options + Interpolate *interp.Options +} + // ParseYAML reads the bytes from a file, parses the bytes into a mapping // structure, and returns it. func ParseYAML(source []byte) (map[string]interface{}, error) { @@ -41,12 +52,25 @@ func ParseYAML(source []byte) (map[string]interface{}, error) { } // Load reads a ConfigDetails and returns a fully loaded configuration -func Load(configDetails types.ConfigDetails) (*types.Config, error) { +func Load(configDetails types.ConfigDetails, options ...func(*Options)) (*types.Config, error) { if len(configDetails.ConfigFiles) < 1 { return nil, errors.Errorf("No files specified") } + opts := &Options{ + Interpolate: &interp.Options{ + Substitute: template.Substitute, + LookupValue: configDetails.LookupEnv, + TypeCastMapping: interpolateTypeCastMapping, + }, + } + + for _, op := range options { + op(opts) + } + configs := []*types.Config{} + var err error for _, file := range configDetails.ConfigFiles { configDict := file.Config @@ -62,14 +86,17 @@ func Load(configDetails types.ConfigDetails) (*types.Config, error) { return nil, err } - var err error - configDict, err = interpolateConfig(configDict, configDetails.LookupEnv) - if err != nil { - return nil, err + if !opts.SkipInterpolation { + configDict, err = interpolateConfig(configDict, *opts.Interpolate) + if err != nil { + return nil, err + } } - if err := schema.Validate(configDict, configDetails.Version); err != nil { - return nil, err + if !opts.SkipValidation { + if err := schema.Validate(configDict, configDetails.Version); err != nil { + return nil, err + } } cfg, err := loadSections(configDict, configDetails) diff --git a/cli/compose/template/template.go b/cli/compose/template/template.go index 086194b8fe..be5d456053 100644 --- a/cli/compose/template/template.go +++ b/cli/compose/template/template.go @@ -16,6 +16,14 @@ var patternString = fmt.Sprintf( var pattern = regexp.MustCompile(patternString) +// DefaultSubstituteFuncs contains the default SubstitueFunc used by the docker cli +var DefaultSubstituteFuncs = []SubstituteFunc{ + softDefault, + hardDefault, + requiredNonEmpty, + required, +} + // InvalidTemplateError is returned when a variable template is not in a valid // format type InvalidTemplateError struct { @@ -32,8 +40,14 @@ func (e InvalidTemplateError) Error() string { // and the absence of a value. type Mapping func(string) (string, bool) -// Substitute variables in the string with their values -func Substitute(template string, mapping Mapping) (string, error) { +// SubstituteFunc is a user-supplied function that apply substitution. +// Returns the value as a string, a bool indicating if the function could apply +// the substitution and an error. +type SubstituteFunc func(string, Mapping) (string, bool, error) + +// SubstituteWith subsitute variables in the string with their values. +// It accepts additional substitute function. +func SubstituteWith(template string, mapping Mapping, pattern *regexp.Regexp, subsFuncs ...SubstituteFunc) (string, error) { var err error result := pattern.ReplaceAllStringFunc(template, func(substring string) string { matches := pattern.FindStringSubmatch(substring) @@ -47,49 +61,22 @@ func Substitute(template string, mapping Mapping) (string, error) { substitution = groups["braced"] } - switch { - - case substitution == "": + if substitution == "" { err = &InvalidTemplateError{Template: template} return "" + } - // Soft default (fall back if unset or empty) - case strings.Contains(substitution, ":-"): - name, defaultValue := partition(substitution, ":-") - value, ok := mapping(name) - if !ok || value == "" { - return defaultValue - } - return value - - // Hard default (fall back if-and-only-if empty) - case strings.Contains(substitution, "-"): - name, defaultValue := partition(substitution, "-") - value, ok := mapping(name) - if !ok { - return defaultValue - } - return value - - case strings.Contains(substitution, ":?"): - name, errorMessage := partition(substitution, ":?") - value, ok := mapping(name) - if !ok || value == "" { - err = &InvalidTemplateError{ - Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage), - } + for _, f := range subsFuncs { + var ( + value string + applied bool + ) + value, applied, err = f(substitution, mapping) + if err != nil { return "" } - return value - - case strings.Contains(substitution, "?"): - name, errorMessage := partition(substitution, "?") - value, ok := mapping(name) - if !ok { - err = &InvalidTemplateError{ - Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage), - } - return "" + if !applied { + continue } return value } @@ -101,6 +88,65 @@ func Substitute(template string, mapping Mapping) (string, error) { return result, err } +// Substitute variables in the string with their values +func Substitute(template string, mapping Mapping) (string, error) { + return SubstituteWith(template, mapping, pattern, DefaultSubstituteFuncs...) +} + +// Soft default (fall back if unset or empty) +func softDefault(substitution string, mapping Mapping) (string, bool, error) { + if !strings.Contains(substitution, ":-") { + return "", false, nil + } + name, defaultValue := partition(substitution, ":-") + value, ok := mapping(name) + if !ok || value == "" { + return defaultValue, true, nil + } + return value, true, nil +} + +// Hard default (fall back if-and-only-if empty) +func hardDefault(substitution string, mapping Mapping) (string, bool, error) { + if !strings.Contains(substitution, "-") { + return "", false, nil + } + name, defaultValue := partition(substitution, "-") + value, ok := mapping(name) + if !ok { + return defaultValue, true, nil + } + return value, true, nil +} + +func requiredNonEmpty(substitution string, mapping Mapping) (string, bool, error) { + if !strings.Contains(substitution, ":?") { + return "", false, nil + } + name, errorMessage := partition(substitution, ":?") + value, ok := mapping(name) + if !ok || value == "" { + return "", true, &InvalidTemplateError{ + Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage), + } + } + return value, true, nil +} + +func required(substitution string, mapping Mapping) (string, bool, error) { + if !strings.Contains(substitution, "?") { + return "", false, nil + } + name, errorMessage := partition(substitution, "?") + value, ok := mapping(name) + if !ok { + return "", true, &InvalidTemplateError{ + Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage), + } + } + return value, true, nil +} + func matchGroups(matches []string) map[string]string { groups := make(map[string]string) for i, name := range pattern.SubexpNames()[1:] { diff --git a/cli/compose/template/template_test.go b/cli/compose/template/template_test.go index 48d76bb9fc..4d2f030816 100644 --- a/cli/compose/template/template_test.go +++ b/cli/compose/template/template_test.go @@ -1,6 +1,7 @@ package template import ( + "fmt" "reflect" "testing" @@ -148,3 +149,26 @@ func TestDefaultsForMandatoryVariables(t *testing.T) { assert.Check(t, is.Equal(tc.expected, result)) } } + +func TestSubstituteWithCustomFunc(t *testing.T) { + errIsMissing := func(substitution string, mapping Mapping) (string, bool, error) { + value, found := mapping(substitution) + if !found { + return "", true, &InvalidTemplateError{ + Template: fmt.Sprintf("required variable %s is missing a value", substitution), + } + } + return value, true, nil + } + + result, err := SubstituteWith("ok ${FOO}", defaultMapping, pattern, errIsMissing) + assert.NilError(t, err) + assert.Check(t, is.Equal("ok first", result)) + + result, err = SubstituteWith("ok ${BAR}", defaultMapping, pattern, errIsMissing) + assert.NilError(t, err) + assert.Check(t, is.Equal("ok ", result)) + + _, err = SubstituteWith("ok ${NOTHERE}", defaultMapping, pattern, errIsMissing) + assert.Check(t, is.ErrorContains(err, "required variable")) +}