diff --git a/cli/compose/template/template.go b/cli/compose/template/template.go index 8426a23522..9bfc7cec5c 100644 --- a/cli/compose/template/template.go +++ b/cli/compose/template/template.go @@ -7,7 +7,7 @@ import ( ) var delimiter = "\\$" -var substitution = "[_a-z][_a-z0-9]*(?::?-[^}]+)?" +var substitution = "[_a-z][_a-z0-9]*(?::?[-?][^}]+)?" var patternString = fmt.Sprintf( "%s(?i:(?P%s)|(?P%s)|{(?P%s)}|(?P))", @@ -69,6 +69,26 @@ func Substitute(template string, mapping Mapping) (string, error) { return value } + if strings.Contains(substitution, ":?") { + name, errorMessage := partition(substitution, ":?") + value, ok := mapping(name) + if !ok || value == "" { + err = &InvalidTemplateError{Template: errorMessage} + return "" + } + return value + } + + if strings.Contains(substitution, "?") { + name, errorMessage := partition(substitution, "?") + value, ok := mapping(name) + if !ok { + err = &InvalidTemplateError{Template: errorMessage} + return "" + } + return value + } + // No default (fall back to empty string) value, ok := mapping(substitution) if !ok { diff --git a/cli/compose/template/template_test.go b/cli/compose/template/template_test.go index 4ad3f252ac..d11b0a8ec3 100644 --- a/cli/compose/template/template_test.go +++ b/cli/compose/template/template_test.go @@ -81,3 +81,63 @@ func TestNonAlphanumericDefault(t *testing.T) { assert.Nil(t, err) assert.Equal(t, "ok /non:-alphanumeric", result) } + +func TestMandatoryVariableErrors(t *testing.T) { + testCases := []struct { + template string + expectedError string + }{ + { + template: "not ok ${UNSET_VAR:?Mandatory Variable Unset}", + expectedError: "Mandatory Variable Unset", + }, + { + template: "not ok ${BAR:?Mandatory Variable Empty}", + expectedError: "Mandatory Variable Empty", + }, + { + template: "not ok ${UNSET_VAR:?}", + expectedError: "", + }, + { + template: "not ok ${UNSET_VAR?Mandatory Variable Unset", + expectedError: "Mandatory Variable Unset", + }, + { + template: "not ok ${UNSET_VAR?}", + expectedError: "", + }, + } + + for _, tc := range testCases { + _, err := Substitute(tc.template, defaultMapping) + assert.Error(t, err) + assert.IsType(t, &InvalidTemplateError{tc.expectedError}, err) + } +} + +func TestDefaultsForMandatoryVariables(t *testing.T) { + testCases := []struct { + template string + expected string + }{ + { + template: "ok ${FOO:?err}", + expected: "ok first", + }, + { + template: "ok ${FOO?err}", + expected: "ok first", + }, + { + template: "ok ${BAR?err}", + expected: "ok ", + }, + } + + for _, tc := range testCases { + result, err := Substitute(tc.template, defaultMapping) + assert.Nil(t, err) + assert.Equal(t, tc.expected, result) + } +}