diff --git a/cli/compose/template/template.go b/cli/compose/template/template.go index 9bfc7cec5c..2792a09ba5 100644 --- a/cli/compose/template/template.go +++ b/cli/compose/template/template.go @@ -37,77 +37,74 @@ func Substitute(template string, mapping Mapping) (string, error) { var err error result := pattern.ReplaceAllStringFunc(template, func(substring string) string { matches := pattern.FindStringSubmatch(substring) - groups := make(map[string]string) - for i, name := range pattern.SubexpNames() { - if i != 0 { - groups[name] = matches[i] - } + groups := matchGroups(matches) + if escaped := groups["escaped"]; escaped != "" { + return escaped } substitution := groups["named"] if substitution == "" { substitution = groups["braced"] } - if substitution != "" { - // Soft default (fall back if unset or empty) - if strings.Contains(substitution, ":-") { - name, defaultValue := partition(substitution, ":-") - value, ok := mapping(name) - if !ok || value == "" { - return defaultValue - } - return value - } - // Hard default (fall back if-and-only-if empty) - if strings.Contains(substitution, "-") { - name, defaultValue := partition(substitution, "-") - value, ok := mapping(name) - if !ok { - return defaultValue - } - return value - } + switch { - if strings.Contains(substitution, ":?") { - name, errorMessage := partition(substitution, ":?") - value, ok := mapping(name) - if !ok || value == "" { - err = &InvalidTemplateError{Template: errorMessage} - return "" - } - return value - } + case substitution == "": + err = &InvalidTemplateError{Template: template} + return "" - if strings.Contains(substitution, "?") { - name, errorMessage := partition(substitution, "?") - value, ok := mapping(name) - if !ok { - err = &InvalidTemplateError{Template: errorMessage} - return "" - } - return value + // Soft default (fall back if unset or empty) + case strings.Contains(substitution, ":-"): + name, defaultValue := partition(substitution, ":-") + value, ok := mapping(name) + if !ok || value == "" { + return defaultValue } + return value - // No default (fall back to empty string) - value, ok := mapping(substitution) + // Hard default (fall back if-and-only-if empty) + case strings.Contains(substitution, "-"): + name, defaultValue := partition(substitution, "-") + value, ok := mapping(name) if !ok { + return defaultValue + } + return value + + case strings.Contains(substitution, ":?"): + name, errorMessage := partition(substitution, ":?") + value, ok := mapping(name) + if !ok || value == "" { + err = &InvalidTemplateError{Template: errorMessage} + return "" + } + return value + + case strings.Contains(substitution, "?"): + name, errorMessage := partition(substitution, "?") + value, ok := mapping(name) + if !ok { + err = &InvalidTemplateError{Template: errorMessage} return "" } return value } - if escaped := groups["escaped"]; escaped != "" { - return escaped - } - - err = &InvalidTemplateError{Template: template} - return "" + value, _ := mapping(substitution) + return value }) return result, err } +func matchGroups(matches []string) map[string]string { + groups := make(map[string]string) + for i, name := range pattern.SubexpNames()[1:] { + groups[name] = matches[i+1] + } + return groups +} + // Split the string at the first occurrence of sep, and return the part before the separator, // and the part after the separator. // diff --git a/cli/compose/template/template_test.go b/cli/compose/template/template_test.go index d11b0a8ec3..ce75216dde 100644 --- a/cli/compose/template/template_test.go +++ b/cli/compose/template/template_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var defaults = map[string]string{ @@ -22,6 +23,12 @@ func TestEscaped(t *testing.T) { assert.Equal(t, "${foo}", result) } +func TestSubstituteNoMatch(t *testing.T) { + result, err := Substitute("foo", defaultMapping) + require.NoError(t, err) + require.Equal(t, "foo", result) +} + func TestInvalid(t *testing.T) { invalidTemplates := []string{ "${",