From afb87e42f27f9b4866c47548496e18719b0bd220 Mon Sep 17 00:00:00 2001 From: Vincent Demeester Date: Wed, 1 Aug 2018 16:12:49 +0200 Subject: [PATCH] Add a new `ExtractVariables` function to `compose/template` package It allows to get easily all the variables defined in a composefile (the `map[string]interface{}` representation that `loader.ParseYAML` returns at least) and their default value too. This commit also does some small function extract on substitution funcs to reduce a tiny bit duplication. Signed-off-by: Vincent Demeester --- cli/compose/template/template.go | 110 ++++++++++++++++++++------ cli/compose/template/template_test.go | 88 +++++++++++++++++++++ 2 files changed, 173 insertions(+), 25 deletions(-) diff --git a/cli/compose/template/template.go b/cli/compose/template/template.go index be5d456053..81d579078f 100644 --- a/cli/compose/template/template.go +++ b/cli/compose/template/template.go @@ -93,25 +93,91 @@ func Substitute(template string, mapping Mapping) (string, error) { return SubstituteWith(template, mapping, pattern, DefaultSubstituteFuncs...) } +// ExtractVariables returns a map of all the variables defined in the specified +// composefile (dict representation) and their default value if any. +func ExtractVariables(configDict map[string]interface{}) map[string]string { + return recurseExtract(configDict) +} + +func recurseExtract(value interface{}) map[string]string { + m := map[string]string{} + + switch value := value.(type) { + case string: + if v, is := extractVariable(value); is { + m[v.name] = v.value + } + case map[string]interface{}: + for _, elem := range value { + submap := recurseExtract(elem) + for key, value := range submap { + m[key] = value + } + } + + case []interface{}: + for _, elem := range value { + if v, is := extractVariable(elem); is { + m[v.name] = v.value + } + } + } + + return m +} + +type extractedValue struct { + name string + value string +} + +func extractVariable(value interface{}) (extractedValue, bool) { + sValue, ok := value.(string) + if !ok { + return extractedValue{}, false + } + matches := pattern.FindStringSubmatch(sValue) + if len(matches) == 0 { + return extractedValue{}, false + } + groups := matchGroups(matches) + if escaped := groups["escaped"]; escaped != "" { + return extractedValue{}, false + } + val := groups["named"] + if val == "" { + val = groups["braced"] + } + name := val + var defaultValue string + switch { + case strings.Contains(val, ":?"): + name, _ = partition(val, ":?") + case strings.Contains(val, "?"): + name, _ = partition(val, "?") + case strings.Contains(val, ":-"): + name, defaultValue = partition(val, ":-") + case strings.Contains(val, "-"): + name, defaultValue = partition(val, "-") + } + return extractedValue{name: name, value: defaultValue}, true +} + // Soft default (fall back if unset or empty) func softDefault(substitution string, mapping Mapping) (string, bool, error) { - if !strings.Contains(substitution, ":-") { - return "", false, nil - } - name, defaultValue := partition(substitution, ":-") - value, ok := mapping(name) - if !ok || value == "" { - return defaultValue, true, nil - } - return value, true, nil + return withDefault(substitution, mapping, "-:") } // Hard default (fall back if-and-only-if empty) func hardDefault(substitution string, mapping Mapping) (string, bool, error) { - if !strings.Contains(substitution, "-") { + return withDefault(substitution, mapping, "-") +} + +func withDefault(substitution string, mapping Mapping, sep string) (string, bool, error) { + if !strings.Contains(substitution, sep) { return "", false, nil } - name, defaultValue := partition(substitution, "-") + name, defaultValue := partition(substitution, sep) value, ok := mapping(name) if !ok { return defaultValue, true, nil @@ -120,26 +186,20 @@ func hardDefault(substitution string, mapping Mapping) (string, bool, error) { } func requiredNonEmpty(substitution string, mapping Mapping) (string, bool, error) { - if !strings.Contains(substitution, ":?") { - return "", false, nil - } - name, errorMessage := partition(substitution, ":?") - value, ok := mapping(name) - if !ok || value == "" { - return "", true, &InvalidTemplateError{ - Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage), - } - } - return value, true, nil + return withRequired(substitution, mapping, ":?", func(v string) bool { return v != "" }) } func required(substitution string, mapping Mapping) (string, bool, error) { - if !strings.Contains(substitution, "?") { + return withRequired(substitution, mapping, "?", func(_ string) bool { return true }) +} + +func withRequired(substitution string, mapping Mapping, sep string, valid func(string) bool) (string, bool, error) { + if !strings.Contains(substitution, sep) { return "", false, nil } - name, errorMessage := partition(substitution, "?") + name, errorMessage := partition(substitution, sep) value, ok := mapping(name) - if !ok { + if !ok || !valid(value) { return "", true, &InvalidTemplateError{ Template: fmt.Sprintf("required variable %s is missing a value: %s", name, errorMessage), } diff --git a/cli/compose/template/template_test.go b/cli/compose/template/template_test.go index 4d2f030816..48c588c702 100644 --- a/cli/compose/template/template_test.go +++ b/cli/compose/template/template_test.go @@ -172,3 +172,91 @@ func TestSubstituteWithCustomFunc(t *testing.T) { _, err = SubstituteWith("ok ${NOTHERE}", defaultMapping, pattern, errIsMissing) assert.Check(t, is.ErrorContains(err, "required variable")) } + +func TestExtractVariables(t *testing.T) { + testCases := []struct { + dict map[string]interface{} + expected map[string]string + }{ + { + dict: map[string]interface{}{}, + expected: map[string]string{}, + }, + { + dict: map[string]interface{}{ + "foo": "bar", + }, + expected: map[string]string{}, + }, + { + dict: map[string]interface{}{ + "foo": "$bar", + }, + expected: map[string]string{ + "bar": "", + }, + }, + { + dict: map[string]interface{}{ + "foo": "${bar}", + }, + expected: map[string]string{ + "bar": "", + }, + }, + { + dict: map[string]interface{}{ + "foo": "${bar?:foo}", + }, + expected: map[string]string{ + "bar": "", + }, + }, + { + dict: map[string]interface{}{ + "foo": "${bar?foo}", + }, + expected: map[string]string{ + "bar": "", + }, + }, + { + dict: map[string]interface{}{ + "foo": "${bar:-foo}", + }, + expected: map[string]string{ + "bar": "foo", + }, + }, + { + dict: map[string]interface{}{ + "foo": "${bar-foo}", + }, + expected: map[string]string{ + "bar": "foo", + }, + }, + { + dict: map[string]interface{}{ + "foo": "${bar:-foo}", + "bar": map[string]interface{}{ + "foo": "${fruit:-banana}", + "bar": "vegetable", + }, + "baz": []interface{}{ + "foo", + "$toto", + }, + }, + expected: map[string]string{ + "bar": "foo", + "fruit": "banana", + "toto": "", + }, + }, + } + for _, tc := range testCases { + actual := ExtractVariables(tc.dict) + assert.Check(t, is.DeepEqual(actual, tc.expected)) + } +}