diff --git a/opts/env.go b/opts/env.go index e6ddd73309..d21c8ccbef 100644 --- a/opts/env.go +++ b/opts/env.go @@ -1,46 +1,30 @@ package opts import ( - "fmt" "os" - "runtime" "strings" + + "github.com/pkg/errors" ) // ValidateEnv validates an environment variable and returns it. -// If no value is specified, it returns the current value using os.Getenv. +// If no value is specified, it obtains its value from the current environment // // As on ParseEnvFile and related to #16585, environment variable names -// are not validate what so ever, it's up to application inside docker +// are not validated, and it's up to the application inside the container // to validate them or not. // // The only validation here is to check if name is empty, per #25099 func ValidateEnv(val string) (string, error) { - arr := strings.Split(val, "=") + arr := strings.SplitN(val, "=", 2) if arr[0] == "" { - return "", fmt.Errorf("invalid environment variable: %s", val) + return "", errors.New("invalid environment variable: " + val) } if len(arr) > 1 { return val, nil } - if !doesEnvExist(val) { - return val, nil + if envVal, ok := os.LookupEnv(arr[0]); ok { + return arr[0] + "=" + envVal, nil } - return fmt.Sprintf("%s=%s", val, os.Getenv(val)), nil -} - -func doesEnvExist(name string) bool { - for _, entry := range os.Environ() { - parts := strings.SplitN(entry, "=", 2) - if runtime.GOOS == "windows" { - // Environment variable are case-insensitive on Windows. PaTh, path and PATH are equivalent. - if strings.EqualFold(parts[0], name) { - return true - } - } - if parts[0] == name { - return true - } - } - return false + return val, nil } diff --git a/opts/env_test.go b/opts/env_test.go index 6f6c7a7a29..3561ceb314 100644 --- a/opts/env_test.go +++ b/opts/env_test.go @@ -5,38 +5,115 @@ import ( "os" "runtime" "testing" + + "gotest.tools/v3/assert" ) func TestValidateEnv(t *testing.T) { - valids := map[string]string{ - "a": "a", - "something": "something", - "_=a": "_=a", - "env1=value1": "env1=value1", - "_env1=value1": "_env1=value1", - "env2=value2=value3": "env2=value2=value3", - "env3=abc!qwe": "env3=abc!qwe", - "env_4=value 4": "env_4=value 4", - "PATH": fmt.Sprintf("PATH=%v", os.Getenv("PATH")), - "PATH=something": "PATH=something", - "asd!qwe": "asd!qwe", - "1asd": "1asd", - "123": "123", - "some space": "some space", - " some space before": " some space before", - "some space after ": "some space after ", + type testCase struct { + value string + expected string + err error } - // Environment variables are case in-sensitive on Windows + tests := []testCase{ + { + value: "a", + expected: "a", + }, + { + value: "something", + expected: "something", + }, + { + value: "_=a", + expected: "_=a", + }, + { + value: "env1=value1", + expected: "env1=value1", + }, + { + value: "_env1=value1", + expected: "_env1=value1", + }, + { + value: "env2=value2=value3", + expected: "env2=value2=value3", + }, + { + value: "env3=abc!qwe", + expected: "env3=abc!qwe", + }, + { + value: "env_4=value 4", + expected: "env_4=value 4", + }, + { + value: "PATH", + expected: fmt.Sprintf("PATH=%v", os.Getenv("PATH")), + }, + { + value: "=a", + err: fmt.Errorf("invalid environment variable: =a"), + }, + { + value: "PATH=", + expected: "PATH=", + }, + { + value: "PATH=something", + expected: "PATH=something", + }, + { + value: "asd!qwe", + expected: "asd!qwe", + }, + { + value: "1asd", + expected: "1asd", + }, + { + value: "123", + expected: "123", + }, + { + value: "some space", + expected: "some space", + }, + { + value: " some space before", + expected: " some space before", + }, + { + value: "some space after ", + expected: "some space after ", + }, + { + value: "=", + err: fmt.Errorf("invalid environment variable: ="), + }, + } + if runtime.GOOS == "windows" { - valids["PaTh"] = fmt.Sprintf("PaTh=%v", os.Getenv("PATH")) + // Environment variables are case in-sensitive on Windows + tests = append(tests, testCase{ + value: "PaTh", + expected: fmt.Sprintf("PaTh=%v", os.Getenv("PATH")), + err: nil, + }) } - for value, expected := range valids { - actual, err := ValidateEnv(value) - if err != nil { - t.Fatal(err) - } - if actual != expected { - t.Fatalf("Expected [%v], got [%v]", expected, actual) - } + + for _, tc := range tests { + tc := tc + t.Run(tc.value, func(t *testing.T) { + actual, err := ValidateEnv(tc.value) + + if tc.err == nil { + assert.NilError(t, err) + } else { + assert.Error(t, err, tc.err.Error()) + } + assert.Equal(t, actual, tc.expected) + }) } }