diff --git a/cli/compose/template/template.go b/cli/compose/template/template.go index 81d579078f..1762ab11a9 100644 --- a/cli/compose/template/template.go +++ b/cli/compose/template/template.go @@ -14,7 +14,7 @@ var patternString = fmt.Sprintf( delimiter, delimiter, substitution, substitution, ) -var pattern = regexp.MustCompile(patternString) +var defaultPattern = regexp.MustCompile(patternString) // DefaultSubstituteFuncs contains the default SubstitueFunc used by the docker cli var DefaultSubstituteFuncs = []SubstituteFunc{ @@ -51,7 +51,7 @@ func SubstituteWith(template string, mapping Mapping, pattern *regexp.Regexp, su var err error result := pattern.ReplaceAllStringFunc(template, func(substring string) string { matches := pattern.FindStringSubmatch(substring) - groups := matchGroups(matches) + groups := matchGroups(matches, pattern) if escaped := groups["escaped"]; escaped != "" { return escaped } @@ -90,26 +90,31 @@ func SubstituteWith(template string, mapping Mapping, pattern *regexp.Regexp, su // Substitute variables in the string with their values func Substitute(template string, mapping Mapping) (string, error) { - return SubstituteWith(template, mapping, pattern, DefaultSubstituteFuncs...) + return SubstituteWith(template, mapping, defaultPattern, DefaultSubstituteFuncs...) } // ExtractVariables returns a map of all the variables defined in the specified // composefile (dict representation) and their default value if any. -func ExtractVariables(configDict map[string]interface{}) map[string]string { - return recurseExtract(configDict) +func ExtractVariables(configDict map[string]interface{}, pattern *regexp.Regexp) map[string]string { + if pattern == nil { + pattern = defaultPattern + } + return recurseExtract(configDict, pattern) } -func recurseExtract(value interface{}) map[string]string { +func recurseExtract(value interface{}, pattern *regexp.Regexp) map[string]string { m := map[string]string{} switch value := value.(type) { case string: - if v, is := extractVariable(value); is { - m[v.name] = v.value + if values, is := extractVariable(value, pattern); is { + for _, v := range values { + m[v.name] = v.value + } } case map[string]interface{}: for _, elem := range value { - submap := recurseExtract(elem) + submap := recurseExtract(elem, pattern) for key, value := range submap { m[key] = value } @@ -117,8 +122,10 @@ func recurseExtract(value interface{}) map[string]string { case []interface{}: for _, elem := range value { - if v, is := extractVariable(elem); is { - m[v.name] = v.value + if values, is := extractVariable(elem, pattern); is { + for _, v := range values { + m[v.name] = v.value + } } } } @@ -131,36 +138,40 @@ type extractedValue struct { value string } -func extractVariable(value interface{}) (extractedValue, bool) { +func extractVariable(value interface{}, pattern *regexp.Regexp) ([]extractedValue, bool) { sValue, ok := value.(string) if !ok { - return extractedValue{}, false + return []extractedValue{}, false } - matches := pattern.FindStringSubmatch(sValue) + matches := pattern.FindAllStringSubmatch(sValue, -1) if len(matches) == 0 { - return extractedValue{}, false + return []extractedValue{}, false } - groups := matchGroups(matches) - if escaped := groups["escaped"]; escaped != "" { - return extractedValue{}, false + values := []extractedValue{} + for _, match := range matches { + groups := matchGroups(match, pattern) + if escaped := groups["escaped"]; escaped != "" { + continue + } + val := groups["named"] + if val == "" { + val = groups["braced"] + } + name := val + var defaultValue string + switch { + case strings.Contains(val, ":?"): + name, _ = partition(val, ":?") + case strings.Contains(val, "?"): + name, _ = partition(val, "?") + case strings.Contains(val, ":-"): + name, defaultValue = partition(val, ":-") + case strings.Contains(val, "-"): + name, defaultValue = partition(val, "-") + } + values = append(values, extractedValue{name: name, value: defaultValue}) } - val := groups["named"] - if val == "" { - val = groups["braced"] - } - name := val - var defaultValue string - switch { - case strings.Contains(val, ":?"): - name, _ = partition(val, ":?") - case strings.Contains(val, "?"): - name, _ = partition(val, "?") - case strings.Contains(val, ":-"): - name, defaultValue = partition(val, ":-") - case strings.Contains(val, "-"): - name, defaultValue = partition(val, "-") - } - return extractedValue{name: name, value: defaultValue}, true + return values, len(values) > 0 } // Soft default (fall back if unset or empty) @@ -207,7 +218,7 @@ func withRequired(substitution string, mapping Mapping, sep string, valid func(s return value, true, nil } -func matchGroups(matches []string) map[string]string { +func matchGroups(matches []string, pattern *regexp.Regexp) map[string]string { groups := make(map[string]string) for i, name := range pattern.SubexpNames()[1:] { groups[name] = matches[i+1] diff --git a/cli/compose/template/template_test.go b/cli/compose/template/template_test.go index 48c588c702..abbc810c3c 100644 --- a/cli/compose/template/template_test.go +++ b/cli/compose/template/template_test.go @@ -161,15 +161,15 @@ func TestSubstituteWithCustomFunc(t *testing.T) { return value, true, nil } - result, err := SubstituteWith("ok ${FOO}", defaultMapping, pattern, errIsMissing) + result, err := SubstituteWith("ok ${FOO}", defaultMapping, defaultPattern, errIsMissing) assert.NilError(t, err) assert.Check(t, is.Equal("ok first", result)) - result, err = SubstituteWith("ok ${BAR}", defaultMapping, pattern, errIsMissing) + result, err = SubstituteWith("ok ${BAR}", defaultMapping, defaultPattern, errIsMissing) assert.NilError(t, err) assert.Check(t, is.Equal("ok ", result)) - _, err = SubstituteWith("ok ${NOTHERE}", defaultMapping, pattern, errIsMissing) + _, err = SubstituteWith("ok ${NOTHERE}", defaultMapping, defaultPattern, errIsMissing) assert.Check(t, is.ErrorContains(err, "required variable")) } @@ -245,18 +245,21 @@ func TestExtractVariables(t *testing.T) { }, "baz": []interface{}{ "foo", + "$docker:${project:-cli}", "$toto", }, }, expected: map[string]string{ - "bar": "foo", - "fruit": "banana", - "toto": "", + "bar": "foo", + "fruit": "banana", + "toto": "", + "docker": "", + "project": "cli", }, }, } for _, tc := range testCases { - actual := ExtractVariables(tc.dict) + actual := ExtractVariables(tc.dict, defaultPattern) assert.Check(t, is.DeepEqual(actual, tc.expected)) } }