Merge pull request #1249 from vdemeester/compose-template-pkg-enhancement

Add a new `ExtractVariables` function to `compose/template` package
This commit is contained in:
Vincent Demeester 2018-08-01 16:18:46 +02:00 committed by GitHub
commit da544e8938
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 173 additions and 25 deletions

View File

@ -93,25 +93,91 @@ func Substitute(template string, mapping Mapping) (string, error) {
return SubstituteWith(template, mapping, pattern, DefaultSubstituteFuncs...) return SubstituteWith(template, mapping, pattern, DefaultSubstituteFuncs...)
} }
// ExtractVariables returns a map of all the variables defined in the specified
// composefile (dict representation) and their default value if any.
func ExtractVariables(configDict map[string]interface{}) map[string]string {
return recurseExtract(configDict)
}
func recurseExtract(value interface{}) map[string]string {
m := map[string]string{}
switch value := value.(type) {
case string:
if v, is := extractVariable(value); is {
m[v.name] = v.value
}
case map[string]interface{}:
for _, elem := range value {
submap := recurseExtract(elem)
for key, value := range submap {
m[key] = value
}
}
case []interface{}:
for _, elem := range value {
if v, is := extractVariable(elem); is {
m[v.name] = v.value
}
}
}
return m
}
type extractedValue struct {
name string
value string
}
func extractVariable(value interface{}) (extractedValue, bool) {
sValue, ok := value.(string)
if !ok {
return extractedValue{}, false
}
matches := pattern.FindStringSubmatch(sValue)
if len(matches) == 0 {
return extractedValue{}, false
}
groups := matchGroups(matches)
if escaped := groups["escaped"]; escaped != "" {
return extractedValue{}, false
}
val := groups["named"]
if val == "" {
val = groups["braced"]
}
name := val
var defaultValue string
switch {
case strings.Contains(val, ":?"):
name, _ = partition(val, ":?")
case strings.Contains(val, "?"):
name, _ = partition(val, "?")
case strings.Contains(val, ":-"):
name, defaultValue = partition(val, ":-")
case strings.Contains(val, "-"):
name, defaultValue = partition(val, "-")
}
return extractedValue{name: name, value: defaultValue}, true
}
// Soft default (fall back if unset or empty) // Soft default (fall back if unset or empty)
func softDefault(substitution string, mapping Mapping) (string, bool, error) { func softDefault(substitution string, mapping Mapping) (string, bool, error) {
if !strings.Contains(substitution, ":-") { return withDefault(substitution, mapping, "-:")
return "", false, nil
}
name, defaultValue := partition(substitution, ":-")
value, ok := mapping(name)
if !ok || value == "" {
return defaultValue, true, nil
}
return value, true, nil
} }
// Hard default (fall back if-and-only-if empty) // Hard default (fall back if-and-only-if empty)
func hardDefault(substitution string, mapping Mapping) (string, bool, error) { func hardDefault(substitution string, mapping Mapping) (string, bool, error) {
if !strings.Contains(substitution, "-") { return withDefault(substitution, mapping, "-")
}
func withDefault(substitution string, mapping Mapping, sep string) (string, bool, error) {
if !strings.Contains(substitution, sep) {
return "", false, nil return "", false, nil
} }
name, defaultValue := partition(substitution, "-") name, defaultValue := partition(substitution, sep)
value, ok := mapping(name) value, ok := mapping(name)
if !ok { if !ok {
return defaultValue, true, nil return defaultValue, true, nil
@ -120,26 +186,20 @@ func hardDefault(substitution string, mapping Mapping) (string, bool, error) {
} }
func requiredNonEmpty(substitution string, mapping Mapping) (string, bool, error) { func requiredNonEmpty(substitution string, mapping Mapping) (string, bool, error) {
if !strings.Contains(substitution, ":?") { return withRequired(substitution, mapping, ":?", func(v string) bool { return v != "" })
return "", false, nil
}
name, errorMessage := partition(substitution, ":?")
value, ok := mapping(name)
if !ok || value == "" {
return "", true, &InvalidTemplateError{
Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage),
}
}
return value, true, nil
} }
func required(substitution string, mapping Mapping) (string, bool, error) { func required(substitution string, mapping Mapping) (string, bool, error) {
if !strings.Contains(substitution, "?") { return withRequired(substitution, mapping, "?", func(_ string) bool { return true })
}
func withRequired(substitution string, mapping Mapping, sep string, valid func(string) bool) (string, bool, error) {
if !strings.Contains(substitution, sep) {
return "", false, nil return "", false, nil
} }
name, errorMessage := partition(substitution, "?") name, errorMessage := partition(substitution, sep)
value, ok := mapping(name) value, ok := mapping(name)
if !ok { if !ok || !valid(value) {
return "", true, &InvalidTemplateError{ return "", true, &InvalidTemplateError{
Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage), Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage),
} }

View File

@ -172,3 +172,91 @@ func TestSubstituteWithCustomFunc(t *testing.T) {
_, err = SubstituteWith("ok ${NOTHERE}", defaultMapping, pattern, errIsMissing) _, err = SubstituteWith("ok ${NOTHERE}", defaultMapping, pattern, errIsMissing)
assert.Check(t, is.ErrorContains(err, "required variable")) assert.Check(t, is.ErrorContains(err, "required variable"))
} }
func TestExtractVariables(t *testing.T) {
testCases := []struct {
dict map[string]interface{}
expected map[string]string
}{
{
dict: map[string]interface{}{},
expected: map[string]string{},
},
{
dict: map[string]interface{}{
"foo": "bar",
},
expected: map[string]string{},
},
{
dict: map[string]interface{}{
"foo": "$bar",
},
expected: map[string]string{
"bar": "",
},
},
{
dict: map[string]interface{}{
"foo": "${bar}",
},
expected: map[string]string{
"bar": "",
},
},
{
dict: map[string]interface{}{
"foo": "${bar?:foo}",
},
expected: map[string]string{
"bar": "",
},
},
{
dict: map[string]interface{}{
"foo": "${bar?foo}",
},
expected: map[string]string{
"bar": "",
},
},
{
dict: map[string]interface{}{
"foo": "${bar:-foo}",
},
expected: map[string]string{
"bar": "foo",
},
},
{
dict: map[string]interface{}{
"foo": "${bar-foo}",
},
expected: map[string]string{
"bar": "foo",
},
},
{
dict: map[string]interface{}{
"foo": "${bar:-foo}",
"bar": map[string]interface{}{
"foo": "${fruit:-banana}",
"bar": "vegetable",
},
"baz": []interface{}{
"foo",
"$toto",
},
},
expected: map[string]string{
"bar": "foo",
"fruit": "banana",
"toto": "",
},
},
}
for _, tc := range testCases {
actual := ExtractVariables(tc.dict)
assert.Check(t, is.DeepEqual(actual, tc.expected))
}
}