diff --git a/cli/compose/template/template.go b/cli/compose/template/template.go index 8426a23522..086194b8fe 100644 --- a/cli/compose/template/template.go +++ b/cli/compose/template/template.go @@ -7,7 +7,7 @@ import ( ) var delimiter = "\\$" -var substitution = "[_a-z][_a-z0-9]*(?::?-[^}]+)?" +var substitution = "[_a-z][_a-z0-9]*(?::?[-?][^}]*)?" var patternString = fmt.Sprintf( "%s(?i:(?P%s)|(?P%s)|{(?P%s)}|(?P))", @@ -37,57 +37,78 @@ func Substitute(template string, mapping Mapping) (string, error) { var err error result := pattern.ReplaceAllStringFunc(template, func(substring string) string { matches := pattern.FindStringSubmatch(substring) - groups := make(map[string]string) - for i, name := range pattern.SubexpNames() { - if i != 0 { - groups[name] = matches[i] - } + groups := matchGroups(matches) + if escaped := groups["escaped"]; escaped != "" { + return escaped } substitution := groups["named"] if substitution == "" { substitution = groups["braced"] } - if substitution != "" { - // Soft default (fall back if unset or empty) - if strings.Contains(substitution, ":-") { - name, defaultValue := partition(substitution, ":-") - value, ok := mapping(name) - if !ok || value == "" { - return defaultValue - } - return value - } - // Hard default (fall back if-and-only-if empty) - if strings.Contains(substitution, "-") { - name, defaultValue := partition(substitution, "-") - value, ok := mapping(name) - if !ok { - return defaultValue - } - return value - } + switch { - // No default (fall back to empty string) - value, ok := mapping(substitution) + case substitution == "": + err = &InvalidTemplateError{Template: template} + return "" + + // Soft default (fall back if unset or empty) + case strings.Contains(substitution, ":-"): + name, defaultValue := partition(substitution, ":-") + value, ok := mapping(name) + if !ok || value == "" { + return defaultValue + } + return value + + // Hard default (fall back if-and-only-if empty) + case strings.Contains(substitution, "-"): + name, defaultValue := partition(substitution, "-") + value, ok := mapping(name) if !ok { + return defaultValue + } + return value + + case strings.Contains(substitution, ":?"): + name, errorMessage := partition(substitution, ":?") + value, ok := mapping(name) + if !ok || value == "" { + err = &InvalidTemplateError{ + Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage), + } + return "" + } + return value + + case strings.Contains(substitution, "?"): + name, errorMessage := partition(substitution, "?") + value, ok := mapping(name) + if !ok { + err = &InvalidTemplateError{ + Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage), + } return "" } return value } - if escaped := groups["escaped"]; escaped != "" { - return escaped - } - - err = &InvalidTemplateError{Template: template} - return "" + value, _ := mapping(substitution) + return value }) return result, err } +func matchGroups(matches []string) map[string]string { + groups := make(map[string]string) + for i, name := range pattern.SubexpNames()[1:] { + groups[name] = matches[i+1] + } + return groups +} + // Split the string at the first occurrence of sep, and return the part before the separator, // and the part after the separator. // diff --git a/cli/compose/template/template_test.go b/cli/compose/template/template_test.go index 4ad3f252ac..4fec57e8e8 100644 --- a/cli/compose/template/template_test.go +++ b/cli/compose/template/template_test.go @@ -3,7 +3,9 @@ package template import ( "testing" + "github.com/docker/cli/internal/test/testutil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var defaults = map[string]string{ @@ -22,6 +24,12 @@ func TestEscaped(t *testing.T) { assert.Equal(t, "${foo}", result) } +func TestSubstituteNoMatch(t *testing.T) { + result, err := Substitute("foo", defaultMapping) + require.NoError(t, err) + require.Equal(t, "foo", result) +} + func TestInvalid(t *testing.T) { invalidTemplates := []string{ "${", @@ -81,3 +89,64 @@ func TestNonAlphanumericDefault(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "ok /non:-alphanumeric", result) } + +func TestMandatoryVariableErrors(t *testing.T) { + testCases := []struct { + template string + expectedError string + }{ + { + template: "not ok ${UNSET_VAR:?Mandatory Variable Unset}", + expectedError: "required variable UNSET_VAR is missing a value: Mandatory Variable Unset", + }, + { + template: "not ok ${BAR:?Mandatory Variable Empty}", + expectedError: "required variable BAR is missing a value: Mandatory Variable Empty", + }, + { + template: "not ok ${UNSET_VAR:?}", + expectedError: "required variable UNSET_VAR is missing a value", + }, + { + template: "not ok ${UNSET_VAR?Mandatory Variable Unset}", + expectedError: "required variable UNSET_VAR is missing a value: Mandatory Variable Unset", + }, + { + template: "not ok ${UNSET_VAR?}", + expectedError: "required variable UNSET_VAR is missing a value", + }, + } + + for _, tc := range testCases { + _, err := Substitute(tc.template, defaultMapping) + assert.Error(t, err) + assert.IsType(t, &InvalidTemplateError{}, err) + testutil.ErrorContains(t, err, tc.expectedError) + } +} + +func TestDefaultsForMandatoryVariables(t *testing.T) { + testCases := []struct { + template string + expected string + }{ + { + template: "ok ${FOO:?err}", + expected: "ok first", + }, + { + template: "ok ${FOO?err}", + expected: "ok first", + }, + { + template: "ok ${BAR?err}", + expected: "ok ", + }, + } + + for _, tc := range testCases { + result, err := Substitute(tc.template, defaultMapping) + assert.Nil(t, err) + assert.Equal(t, tc.expected, result) + } +}