diff --git a/cli/command/container/cp.go b/cli/command/container/cp.go index 525866f2a9..67e5e598ea 100644 --- a/cli/command/container/cp.go +++ b/cli/command/container/cp.go @@ -386,13 +386,12 @@ func splitCpArg(arg string) (container, path string) { return "", arg } - parts := strings.SplitN(arg, ":", 2) - - if len(parts) == 1 || strings.HasPrefix(parts[0], ".") { + container, path, ok := strings.Cut(arg, ":") + if !ok || strings.HasPrefix(container, ".") { // Either there's no `:` in the arg // OR it's an explicit local relative path like `./file:name.txt`. return "", arg } - return parts[0], parts[1] + return container, path } diff --git a/cli/command/container/opts.go b/cli/command/container/opts.go index fd32566867..e4db20ee3e 100644 --- a/cli/command/container/opts.go +++ b/cli/command/container/opts.go @@ -354,14 +354,13 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con toBind := bind if parsed.Type == string(mounttypes.TypeBind) { - if arr := strings.SplitN(bind, ":", 2); len(arr) == 2 { - hostPart := arr[0] + if hostPart, targetPath, ok := strings.Cut(bind, ":"); ok { if strings.HasPrefix(hostPart, "."+string(filepath.Separator)) || hostPart == "." { if absHostPart, err := filepath.Abs(hostPart); err == nil { hostPart = absHostPart } } - toBind = hostPart + ":" + arr[1] + toBind = hostPart + ":" + targetPath } } @@ -377,11 +376,8 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con // Can't evaluate options passed into --tmpfs until we actually mount tmpfs := make(map[string]string) for _, t := range copts.tmpfs.GetAll() { - if arr := strings.SplitN(t, ":", 2); len(arr) > 1 { - tmpfs[arr[0]] = arr[1] - } else { - tmpfs[arr[0]] = "" - } + k, v, _ := strings.Cut(t, ":") + tmpfs[k] = v } var ( @@ -390,7 +386,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con ) if len(copts.Args) > 0 { - runCmd = strslice.StrSlice(copts.Args) + runCmd = copts.Args } if copts.entrypoint != "" { @@ -529,13 +525,11 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con if haveHealthSettings { return nil, errors.Errorf("--no-healthcheck conflicts with --health-* options") } - test := strslice.StrSlice{"NONE"} - healthConfig = &container.HealthConfig{Test: test} + healthConfig = &container.HealthConfig{Test: strslice.StrSlice{"NONE"}} } else if haveHealthSettings { var probe strslice.StrSlice if copts.healthCmd != "" { - args := []string{"CMD-SHELL", copts.healthCmd} - probe = strslice.StrSlice(args) + probe = []string{"CMD-SHELL", copts.healthCmd} } if copts.healthInterval < 0 { return nil, errors.Errorf("--health-interval cannot be negative") @@ -822,12 +816,11 @@ func convertToStandardNotation(ports []string) ([]string, error) { if strings.Contains(publish, "=") { params := map[string]string{"protocol": "tcp"} for _, param := range strings.Split(publish, ",") { - opt := strings.Split(param, "=") - if len(opt) < 2 { + k, v, ok := strings.Cut(param, "=") + if !ok || k == "" { return optsList, errors.Errorf("invalid publish opts format (should be name=value but got '%s')", param) } - - params[opt[0]] = opt[1] + params[k] = v } optsList = append(optsList, fmt.Sprintf("%s:%s/%s", params["published"], params["target"], params["protocol"])) } else { @@ -848,22 +841,22 @@ func parseLoggingOpts(loggingDriver string, loggingOpts []string) (map[string]st // takes a local seccomp daemon, reads the file contents for sending to the daemon func parseSecurityOpts(securityOpts []string) ([]string, error) { for key, opt := range securityOpts { - con := strings.SplitN(opt, "=", 2) - if len(con) == 1 && con[0] != "no-new-privileges" { - if strings.Contains(opt, ":") { - con = strings.SplitN(opt, ":", 2) - } else { - return securityOpts, errors.Errorf("Invalid --security-opt: %q", opt) - } + k, v, ok := strings.Cut(opt, "=") + if !ok && k != "no-new-privileges" { + k, v, ok = strings.Cut(opt, ":") } - if con[0] == "seccomp" && con[1] != "unconfined" { - f, err := os.ReadFile(con[1]) + if (!ok || v == "") && k != "no-new-privileges" { + // "no-new-privileges" is the only option that does not require a value. + return securityOpts, errors.Errorf("Invalid --security-opt: %q", opt) + } + if k == "seccomp" && v != "unconfined" { + f, err := os.ReadFile(v) if err != nil { - return securityOpts, errors.Errorf("opening seccomp profile (%s) failed: %v", con[1], err) + return securityOpts, errors.Errorf("opening seccomp profile (%s) failed: %v", v, err) } b := bytes.NewBuffer(nil) if err := json.Compact(b, f); err != nil { - return securityOpts, errors.Errorf("compacting json for seccomp profile (%s) failed: %v", con[1], err) + return securityOpts, errors.Errorf("compacting json for seccomp profile (%s) failed: %v", v, err) } securityOpts[key] = fmt.Sprintf("seccomp=%s", b.Bytes()) } @@ -895,12 +888,11 @@ func parseSystemPaths(securityOpts []string) (filtered, maskedPaths, readonlyPat func parseStorageOpts(storageOpts []string) (map[string]string, error) { m := make(map[string]string) for _, option := range storageOpts { - if strings.Contains(option, "=") { - opt := strings.SplitN(option, "=", 2) - m[opt[0]] = opt[1] - } else { + k, v, ok := strings.Cut(option, "=") + if !ok { return nil, errors.Errorf("invalid storage option") } + m[k] = v } return m, nil } @@ -921,7 +913,8 @@ func parseDevice(device, serverOS string) (container.DeviceMapping, error) { func parseLinuxDevice(device string) (container.DeviceMapping, error) { var src, dst string permissions := "rwm" - arr := strings.Split(device, ":") + // We expect 3 parts at maximum; limit to 4 parts to detect invalid options. + arr := strings.SplitN(device, ":", 4) switch len(arr) { case 3: permissions = arr[2] diff --git a/cli/command/container/opts_test.go b/cli/command/container/opts_test.go index 7e78562a48..6421f80a70 100644 --- a/cli/command/container/opts_test.go +++ b/cli/command/container/opts_test.go @@ -649,8 +649,8 @@ func TestRunFlagsParseShmSize(t *testing.T) { func TestParseRestartPolicy(t *testing.T) { invalids := map[string]string{ - "always:2:3": "invalid restart policy format", - "on-failure:invalid": "maximum retry count must be an integer", + "always:2:3": "invalid restart policy format: maximum retry count must be an integer", + "on-failure:invalid": "invalid restart policy format: maximum retry count must be an integer", } valids := map[string]container.RestartPolicy{ "": {}, diff --git a/cli/command/network/connect.go b/cli/command/network/connect.go index ea4a7aa023..0c7552542d 100644 --- a/cli/command/network/connect.go +++ b/cli/command/network/connect.go @@ -81,13 +81,13 @@ func runConnect(dockerCli command.Cli, options connectOptions) error { func convertDriverOpt(opts []string) (map[string]string, error) { driverOpt := make(map[string]string) for _, opt := range opts { - parts := strings.SplitN(opt, "=", 2) - if len(parts) != 2 { + k, v, ok := strings.Cut(opt, "=") + // TODO(thaJeztah): we should probably not accept whitespace here (both for key and value). + k = strings.TrimSpace(k) + if !ok || k == "" { return nil, fmt.Errorf("invalid key/value pair format in driver options") } - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - driverOpt[key] = value + driverOpt[k] = strings.TrimSpace(v) } return driverOpt, nil } diff --git a/cli/command/service/opts.go b/cli/command/service/opts.go index 05b3b8055d..fa44b84962 100644 --- a/cli/command/service/opts.go +++ b/cli/command/service/opts.go @@ -93,17 +93,17 @@ func (opts *placementPrefOpts) String() string { // Note: in the future strategies other than "spread", may be supported, // as well as additional comma-separated options. func (opts *placementPrefOpts) Set(value string) error { - fields := strings.Split(value, "=") - if len(fields) != 2 { + strategy, arg, ok := strings.Cut(value, "=") + if !ok || strategy == "" { return errors.New(`placement preference must be of the format "="`) } - if fields[0] != "spread" { - return errors.Errorf("unsupported placement preference %s (only spread is supported)", fields[0]) + if strategy != "spread" { + return errors.Errorf("unsupported placement preference %s (only spread is supported)", strategy) } opts.prefs = append(opts.prefs, swarm.PlacementPreference{ Spread: &swarm.SpreadOver{ - SpreadDescriptor: fields[1], + SpreadDescriptor: arg, }, }) opts.strings = append(opts.strings, value) @@ -121,8 +121,11 @@ type ShlexOpt []string // Set the value func (s *ShlexOpt) Set(value string) error { valueSlice, err := shlex.Split(value) - *s = ShlexOpt(valueSlice) - return err + if err != nil { + return err + } + *s = valueSlice + return nil } // Type returns the tyep of the value @@ -475,10 +478,12 @@ func (opts *healthCheckOptions) toHealthConfig() (*container.HealthConfig, error // // This assumes input value (:) has already been validated func convertExtraHostsToSwarmHosts(extraHosts []string) []string { - hosts := []string{} + hosts := make([]string, 0, len(extraHosts)) for _, extraHost := range extraHosts { - parts := strings.SplitN(extraHost, ":", 2) - hosts = append(hosts, fmt.Sprintf("%s %s", parts[1], parts[0])) + host, ip, ok := strings.Cut(extraHost, ":") + if ok { + hosts = append(hosts, ip+" "+host) + } } return hosts } @@ -628,7 +633,7 @@ func (options *serviceOptions) makeEnv() ([]string, error) { } currentEnv := make([]string, 0, len(envVariables)) for _, env := range envVariables { // need to process each var, in order - k := strings.SplitN(env, "=", 2)[0] + k, _, _ := strings.Cut(env, "=") for i, current := range currentEnv { // remove duplicates if current == env { continue // no update required, may hide this behind flag to preserve order of envVariables diff --git a/cli/command/service/scale.go b/cli/command/service/scale.go index 4bc00e5050..f2625445eb 100644 --- a/cli/command/service/scale.go +++ b/cli/command/service/scale.go @@ -43,7 +43,7 @@ func scaleArgs(cmd *cobra.Command, args []string) error { return err } for _, arg := range args { - if parts := strings.SplitN(arg, "=", 2); len(parts) != 2 { + if k, v, ok := strings.Cut(arg, "="); !ok || k == "" || v == "" { return errors.Errorf( "Invalid scale specifier '%s'.\nSee '%s --help'.\n\nUsage: %s\n\n%s", arg, @@ -62,8 +62,7 @@ func runScale(dockerCli command.Cli, options *scaleOptions, args []string) error ctx := context.Background() for _, arg := range args { - parts := strings.SplitN(arg, "=", 2) - serviceID, scaleStr := parts[0], parts[1] + serviceID, scaleStr, _ := strings.Cut(arg, "=") // validate input arg scale number scale, err := strconv.ParseUint(scaleStr, 10, 64) diff --git a/cli/command/service/update.go b/cli/command/service/update.go index ff6f2fcf95..482b91bc2f 100644 --- a/cli/command/service/update.go +++ b/cli/command/service/update.go @@ -879,8 +879,8 @@ func removeConfigs(flags *pflag.FlagSet, spec *swarm.ContainerSpec, credSpecName } func envKey(value string) string { - kv := strings.SplitN(value, "=", 2) - return kv[0] + k, _, _ := strings.Cut(value, "=") + return k } func buildToRemoveSet(flags *pflag.FlagSet, flag string) map[string]struct{} { @@ -1174,12 +1174,8 @@ func updateHosts(flags *pflag.FlagSet, hosts *[]string) error { if flags.Changed(flagHostRemove) { extraHostsToRemove := flags.Lookup(flagHostRemove).Value.(*opts.ListOpts).GetAll() for _, entry := range extraHostsToRemove { - v := strings.SplitN(entry, ":", 2) - if len(v) > 1 { - toRemove = append(toRemove, hostMapping{IPAddr: v[1], Host: v[0]}) - } else { - toRemove = append(toRemove, hostMapping{Host: v[0]}) - } + hostName, ipAddr, _ := strings.Cut(entry, ":") + toRemove = append(toRemove, hostMapping{IPAddr: ipAddr, Host: hostName}) } } diff --git a/cli/command/stack/loader/loader.go b/cli/command/stack/loader/loader.go index 22352f434e..105e84e644 100644 --- a/cli/command/stack/loader/loader.go +++ b/cli/command/stack/loader/loader.go @@ -104,12 +104,12 @@ func GetConfigDetails(composefiles []string, stdin io.Reader) (composetypes.Conf func buildEnvironment(env []string) (map[string]string, error) { result := make(map[string]string, len(env)) for _, s := range env { - // if value is empty, s is like "K=", not "K". - if !strings.Contains(s, "=") { + k, v, ok := strings.Cut(s, "=") + if !ok || k == "" { return result, errors.Errorf("unexpected environment %q", s) } - kv := strings.SplitN(s, "=", 2) - result[kv[0]] = kv[1] + // value may be set, but empty if "s" is like "K=", not "K". + result[k] = v } return result, nil } diff --git a/cli/command/swarm/opts.go b/cli/command/swarm/opts.go index 4db2e303b1..acbadf9d0e 100644 --- a/cli/command/swarm/opts.go +++ b/cli/command/swarm/opts.go @@ -175,14 +175,12 @@ func parseExternalCA(caSpec string) (*swarm.ExternalCA, error) { ) for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - - if len(parts) != 2 { + key, value, ok := strings.Cut(field, "=") + if !ok { return nil, errors.Errorf("invalid field '%s' must be a key=value pair", field) } - key, value := parts[0], parts[1] - + // TODO(thaJeztah): these options should not be case-insensitive. switch strings.ToLower(key) { case "protocol": hasProtocol = true diff --git a/cli/command/utils.go b/cli/command/utils.go index 8913b9a361..753f428aa0 100644 --- a/cli/command/utils.go +++ b/cli/command/utils.go @@ -96,26 +96,26 @@ func PruneFilters(dockerCli Cli, pruneFilters filters.Args) filters.Args { return pruneFilters } for _, f := range dockerCli.ConfigFile().PruneFilters { - parts := strings.SplitN(f, "=", 2) - if len(parts) != 2 { + k, v, ok := strings.Cut(f, "=") + if !ok { continue } - if parts[0] == "label" { + if k == "label" { // CLI label filter supersede config.json. // If CLI label filter conflict with config.json, // skip adding label! filter in config.json. - if pruneFilters.Contains("label!") && pruneFilters.ExactMatch("label!", parts[1]) { + if pruneFilters.Contains("label!") && pruneFilters.ExactMatch("label!", v) { continue } - } else if parts[0] == "label!" { + } else if k == "label!" { // CLI label! filter supersede config.json. // If CLI label! filter conflict with config.json, // skip adding label filter in config.json. - if pruneFilters.Contains("label") && pruneFilters.ExactMatch("label", parts[1]) { + if pruneFilters.Contains("label") && pruneFilters.ExactMatch("label", v) { continue } } - pruneFilters.Add(parts[0], parts[1]) + pruneFilters.Add(k, v) } return pruneFilters diff --git a/cli/command/volume/create.go b/cli/command/volume/create.go index 6f89933597..a3219da92a 100644 --- a/cli/command/volume/create.go +++ b/cli/command/volume/create.go @@ -165,9 +165,9 @@ func runCreate(dockerCli command.Cli, options createOptions) error { // comma-separated list of equal separated maps segments := map[string]string{} for _, segment := range strings.Split(top, ",") { - parts := strings.SplitN(segment, "=", 2) // TODO(dperny): validate topology syntax - segments[parts[0]] = parts[1] + k, v, _ := strings.Cut(segment, "=") + segments[k] = v } topology.Requisite = append( topology.Requisite, @@ -180,9 +180,9 @@ func runCreate(dockerCli command.Cli, options createOptions) error { // comma-separated list of equal separated maps segments := map[string]string{} for _, segment := range strings.Split(top, ",") { - parts := strings.SplitN(segment, "=", 2) // TODO(dperny): validate topology syntax - segments[parts[0]] = parts[1] + k, v, _ := strings.Cut(segment, "=") + segments[k] = v } topology.Preferred = append( diff --git a/cli/compose/convert/service.go b/cli/compose/convert/service.go index e1ba6ffb99..7e0b8cfc99 100644 --- a/cli/compose/convert/service.go +++ b/cli/compose/convert/service.go @@ -427,11 +427,11 @@ func uint32Ptr(value uint32) *uint32 { // convertExtraHosts converts : mappings to SwarmKit notation: // "IP-address hostname(s)". The original order of mappings is preserved. func convertExtraHosts(extraHosts composetypes.HostsList) []string { - hosts := []string{} + hosts := make([]string, 0, len(extraHosts)) for _, hostIP := range extraHosts { - if v := strings.SplitN(hostIP, ":", 2); len(v) == 2 { + if hostName, ipAddr, ok := strings.Cut(hostIP, ":"); ok { // Convert to SwarmKit notation: IP-address hostname(s) - hosts = append(hosts, fmt.Sprintf("%s %s", v[1], v[0])) + hosts = append(hosts, ipAddr+" "+hostName) } } return hosts diff --git a/cli/compose/loader/loader.go b/cli/compose/loader/loader.go index 57a4bffc32..a9b8f1f1f9 100644 --- a/cli/compose/loader/loader.go +++ b/cli/compose/loader/loader.go @@ -829,21 +829,20 @@ func transformListOrMapping(listOrMapping interface{}, sep string, allowNil bool } func transformMappingOrList(mappingOrList interface{}, sep string, allowNil bool) interface{} { - switch value := mappingOrList.(type) { + switch values := mappingOrList.(type) { case map[string]interface{}: - return toMapStringString(value, allowNil) - case ([]interface{}): + return toMapStringString(values, allowNil) + case []interface{}: result := make(map[string]interface{}) - for _, value := range value { - parts := strings.SplitN(value.(string), sep, 2) - key := parts[0] + for _, v := range values { + key, val, hasValue := strings.Cut(v.(string), sep) switch { - case len(parts) == 1 && allowNil: + case !hasValue && allowNil: result[key] = nil - case len(parts) == 1 && !allowNil: + case !hasValue && !allowNil: result[key] = "" default: - result[key] = parts[1] + result[key] = val } } return result diff --git a/cli/compose/template/template.go b/cli/compose/template/template.go index 12927ddeda..f8e507a90b 100644 --- a/cli/compose/template/template.go +++ b/cli/compose/template/template.go @@ -239,9 +239,9 @@ func matchGroups(matches []string, pattern *regexp.Regexp) map[string]string { // // If the separator is not found, return the string itself, followed by an empty string. func partition(s, sep string) (string, string) { - if strings.Contains(s, sep) { - parts := strings.SplitN(s, sep, 2) - return parts[0], parts[1] + k, v, ok := strings.Cut(s, sep) + if !ok { + return s, "" } - return s, "" + return k, v } diff --git a/cli/config/configfile/file.go b/cli/config/configfile/file.go index 796b0a0aed..92ff69368a 100644 --- a/cli/config/configfile/file.go +++ b/cli/config/configfile/file.go @@ -241,12 +241,11 @@ func decodeAuth(authStr string) (string, string, error) { if n > decLen { return "", "", errors.Errorf("Something went wrong decoding auth config") } - arr := strings.SplitN(string(decoded), ":", 2) - if len(arr) != 2 { + userName, password, ok := strings.Cut(string(decoded), ":") + if !ok || userName == "" { return "", "", errors.Errorf("Invalid auth configuration file") } - password := strings.Trim(arr[1], "\x00") - return arr[0], password, nil + return userName, strings.Trim(password, "\x00"), nil } // GetCredentialsStore returns a new credentials store from the settings in the diff --git a/cli/config/credentials/file_store.go b/cli/config/credentials/file_store.go index e509820b73..de1c676e50 100644 --- a/cli/config/credentials/file_store.go +++ b/cli/config/credentials/file_store.go @@ -75,7 +75,6 @@ func ConvertToHostname(url string) string { stripped = strings.TrimPrefix(url, "https://") } - nameParts := strings.SplitN(stripped, "/", 2) - - return nameParts[0] + hostName, _, _ := strings.Cut(stripped, "/") + return hostName } diff --git a/internal/test/strings.go b/internal/test/strings.go index d001c8f3f8..ccbfbbccdd 100644 --- a/internal/test/strings.go +++ b/internal/test/strings.go @@ -14,16 +14,14 @@ func CompareMultipleValues(t *testing.T, value, expected string) { // be guaranteed to have the same order as our expected value // We'll create maps and use reflect.DeepEquals to check instead: entriesMap := make(map[string]string) - expMap := make(map[string]string) - entries := strings.Split(value, ",") - expectedEntries := strings.Split(expected, ",") - for _, entry := range entries { - keyval := strings.Split(entry, "=") - entriesMap[keyval[0]] = keyval[1] + for _, entry := range strings.Split(value, ",") { + k, v, _ := strings.Cut(entry, "=") + entriesMap[k] = v } - for _, expected := range expectedEntries { - keyval := strings.Split(expected, "=") - expMap[keyval[0]] = keyval[1] + expMap := make(map[string]string) + for _, exp := range strings.Split(expected, ",") { + k, v, _ := strings.Cut(exp, "=") + expMap[k] = v } assert.Check(t, is.DeepEqual(expMap, entriesMap)) } diff --git a/opts/config.go b/opts/config.go index 40bc13e251..3be0fa93dd 100644 --- a/opts/config.go +++ b/opts/config.go @@ -40,25 +40,23 @@ func (o *ConfigOpt) Set(value string) error { } for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - key := strings.ToLower(parts[0]) - - if len(parts) != 2 { + key, val, ok := strings.Cut(field, "=") + if !ok || key == "" { return fmt.Errorf("invalid field '%s' must be a key=value pair", field) } - value := parts[1] - switch key { + // TODO(thaJeztah): these options should not be case-insensitive. + switch strings.ToLower(key) { case "source", "src": - options.ConfigName = value + options.ConfigName = val case "target": - options.File.Name = value + options.File.Name = val case "uid": - options.File.UID = value + options.File.UID = val case "gid": - options.File.GID = value + options.File.GID = val case "mode": - m, err := strconv.ParseUint(value, 0, 32) + m, err := strconv.ParseUint(val, 0, 32) if err != nil { return fmt.Errorf("invalid mode specified: %v", err) } diff --git a/opts/env.go b/opts/env.go index d21c8ccbef..214d6f4400 100644 --- a/opts/env.go +++ b/opts/env.go @@ -16,15 +16,16 @@ import ( // // The only validation here is to check if name is empty, per #25099 func ValidateEnv(val string) (string, error) { - arr := strings.SplitN(val, "=", 2) - if arr[0] == "" { + k, _, hasValue := strings.Cut(val, "=") + if k == "" { return "", errors.New("invalid environment variable: " + val) } - if len(arr) > 1 { + if hasValue { + // val contains a "=" (but value may be an empty string) return val, nil } - if envVal, ok := os.LookupEnv(arr[0]); ok { - return arr[0] + "=" + envVal, nil + if envVal, ok := os.LookupEnv(k); ok { + return k + "=" + envVal, nil } return val, nil } diff --git a/opts/file.go b/opts/file.go index 2346cc1670..72b90e117f 100644 --- a/opts/file.go +++ b/opts/file.go @@ -46,10 +46,10 @@ func parseKeyValueFile(filename string, emptyFn func(string) (string, bool)) ([] currentLine++ // line is not empty, and not starting with '#' if len(line) > 0 && !strings.HasPrefix(line, "#") { - data := strings.SplitN(line, "=", 2) + variable, value, hasValue := strings.Cut(line, "=") // trim the front of a variable, but nothing else - variable := strings.TrimLeft(data[0], whiteSpaces) + variable = strings.TrimLeft(variable, whiteSpaces) if strings.ContainsAny(variable, whiteSpaces) { return []string{}, ErrBadKey{fmt.Sprintf("variable '%s' contains whitespaces", variable)} } @@ -57,18 +57,17 @@ func parseKeyValueFile(filename string, emptyFn func(string) (string, bool)) ([] return []string{}, ErrBadKey{fmt.Sprintf("no variable name on line '%s'", line)} } - if len(data) > 1 { + if hasValue { // pass the value through, no trimming - lines = append(lines, fmt.Sprintf("%s=%s", variable, data[1])) + lines = append(lines, variable+"="+value) } else { - var value string var present bool if emptyFn != nil { value, present = emptyFn(line) } if present { // if only a pass-through variable is given, clean it up. - lines = append(lines, fmt.Sprintf("%s=%s", strings.TrimSpace(line), value)) + lines = append(lines, strings.TrimSpace(variable)+"="+value) } } } diff --git a/opts/gpus.go b/opts/gpus.go index 8796a805d4..93bf939786 100644 --- a/opts/gpus.go +++ b/opts/gpus.go @@ -38,14 +38,13 @@ func (o *GpuOpts) Set(value string) error { seen := map[string]struct{}{} // Set writable as the default for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - key := parts[0] + key, val, withValue := strings.Cut(field, "=") if _, ok := seen[key]; ok { return fmt.Errorf("gpu request key '%s' can be specified only once", key) } seen[key] = struct{}{} - if len(parts) == 1 { + if !withValue { seen["count"] = struct{}{} req.Count, err = parseCount(key) if err != nil { @@ -54,21 +53,20 @@ func (o *GpuOpts) Set(value string) error { continue } - value := parts[1] switch key { case "driver": - req.Driver = value + req.Driver = val case "count": - req.Count, err = parseCount(value) + req.Count, err = parseCount(val) if err != nil { return err } case "device": - req.DeviceIDs = strings.Split(value, ",") + req.DeviceIDs = strings.Split(val, ",") case "capabilities": - req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")} + req.Capabilities = [][]string{append(strings.Split(val, ","), "gpu")} case "options": - r := csv.NewReader(strings.NewReader(value)) + r := csv.NewReader(strings.NewReader(val)) optFields, err := r.Read() if err != nil { return errors.Wrap(err, "failed to read gpu options") diff --git a/opts/hosts.go b/opts/hosts.go index d59421b308..7cdd1218f7 100644 --- a/opts/hosts.go +++ b/opts/hosts.go @@ -33,6 +33,8 @@ const ( ) // ValidateHost validates that the specified string is a valid host and returns it. +// +// TODO(thaJeztah): ValidateHost appears to be unused; deprecate it. func ValidateHost(val string) (string, error) { host := strings.TrimSpace(val) // The empty string means default and is not handled by parseDockerDaemonHost @@ -69,18 +71,19 @@ func ParseHost(defaultToTLS bool, val string) (string, error) { // parseDockerDaemonHost parses the specified address and returns an address that will be used as the host. // Depending of the address specified, this may return one of the global Default* strings defined in hosts.go. func parseDockerDaemonHost(addr string) (string, error) { - addrParts := strings.SplitN(addr, "://", 2) - if len(addrParts) == 1 && addrParts[0] != "" { - addrParts = []string{"tcp", addrParts[0]} + proto, host, hasProto := strings.Cut(addr, "://") + if !hasProto && proto != "" { + host = proto + proto = "tcp" } - switch addrParts[0] { + switch proto { case "tcp": - return ParseTCPAddr(addrParts[1], defaultTCPHost) + return ParseTCPAddr(host, defaultTCPHost) case "unix": - return parseSimpleProtoAddr("unix", addrParts[1], defaultUnixSocket) + return parseSimpleProtoAddr(proto, host, defaultUnixSocket) case "npipe": - return parseSimpleProtoAddr("npipe", addrParts[1], defaultNamedPipe) + return parseSimpleProtoAddr(proto, host, defaultNamedPipe) case "fd": return addr, nil case "ssh": @@ -160,16 +163,18 @@ func ParseTCPAddr(tryAddr string, defaultAddr string) (string, error) { // ValidateExtraHost validates that the specified string is a valid extrahost and returns it. // ExtraHost is in the form of name:ip where the ip has to be a valid ip (IPv4 or IPv6). +// +// TODO(thaJeztah): remove client-side validation, and delegate to the API server. func ValidateExtraHost(val string) (string, error) { // allow for IPv6 addresses in extra hosts by only splitting on first ":" - arr := strings.SplitN(val, ":", 2) - if len(arr) != 2 || len(arr[0]) == 0 { + k, v, ok := strings.Cut(val, ":") + if !ok || k == "" { return "", fmt.Errorf("bad format for add-host: %q", val) } // Skip IPaddr validation for "host-gateway" string - if arr[1] != hostGatewayName { - if _, err := ValidateIPAddress(arr[1]); err != nil { - return "", fmt.Errorf("invalid IP address in add-host: %q", arr[1]) + if v != hostGatewayName { + if _, err := ValidateIPAddress(v); err != nil { + return "", fmt.Errorf("invalid IP address in add-host: %q", v) } } return val, nil diff --git a/opts/mount.go b/opts/mount.go index 7ffc3acc93..2b531127eb 100644 --- a/opts/mount.go +++ b/opts/mount.go @@ -56,21 +56,21 @@ func (m *MountOpt) Set(value string) error { } setValueOnMap := func(target map[string]string, value string) { - parts := strings.SplitN(value, "=", 2) - if len(parts) == 1 { - target[value] = "" - } else { - target[parts[0]] = parts[1] + k, v, _ := strings.Cut(value, "=") + if k != "" { + target[k] = v } } mount.Type = mounttypes.TypeVolume // default to volume mounts // Set writable as the default for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - key := strings.ToLower(parts[0]) + key, val, ok := strings.Cut(field, "=") - if len(parts) == 1 { + // TODO(thaJeztah): these options should not be case-insensitive. + key = strings.ToLower(key) + + if !ok { switch key { case "readonly", "ro": mount.ReadOnly = true @@ -81,64 +81,61 @@ func (m *MountOpt) Set(value string) error { case "bind-nonrecursive": bindOptions().NonRecursive = true continue + default: + return fmt.Errorf("invalid field '%s' must be a key=value pair", field) } } - if len(parts) != 2 { - return fmt.Errorf("invalid field '%s' must be a key=value pair", field) - } - - value := parts[1] switch key { case "type": - mount.Type = mounttypes.Type(strings.ToLower(value)) + mount.Type = mounttypes.Type(strings.ToLower(val)) case "source", "src": - mount.Source = value - if strings.HasPrefix(value, "."+string(filepath.Separator)) || value == "." { - if abs, err := filepath.Abs(value); err == nil { + mount.Source = val + if strings.HasPrefix(val, "."+string(filepath.Separator)) || val == "." { + if abs, err := filepath.Abs(val); err == nil { mount.Source = abs } } case "target", "dst", "destination": - mount.Target = value + mount.Target = val case "readonly", "ro": - mount.ReadOnly, err = strconv.ParseBool(value) + mount.ReadOnly, err = strconv.ParseBool(val) if err != nil { - return fmt.Errorf("invalid value for %s: %s", key, value) + return fmt.Errorf("invalid value for %s: %s", key, val) } case "consistency": - mount.Consistency = mounttypes.Consistency(strings.ToLower(value)) + mount.Consistency = mounttypes.Consistency(strings.ToLower(val)) case "bind-propagation": - bindOptions().Propagation = mounttypes.Propagation(strings.ToLower(value)) + bindOptions().Propagation = mounttypes.Propagation(strings.ToLower(val)) case "bind-nonrecursive": - bindOptions().NonRecursive, err = strconv.ParseBool(value) + bindOptions().NonRecursive, err = strconv.ParseBool(val) if err != nil { - return fmt.Errorf("invalid value for %s: %s", key, value) + return fmt.Errorf("invalid value for %s: %s", key, val) } case "volume-nocopy": - volumeOptions().NoCopy, err = strconv.ParseBool(value) + volumeOptions().NoCopy, err = strconv.ParseBool(val) if err != nil { - return fmt.Errorf("invalid value for volume-nocopy: %s", value) + return fmt.Errorf("invalid value for volume-nocopy: %s", val) } case "volume-label": - setValueOnMap(volumeOptions().Labels, value) + setValueOnMap(volumeOptions().Labels, val) case "volume-driver": - volumeOptions().DriverConfig.Name = value + volumeOptions().DriverConfig.Name = val case "volume-opt": if volumeOptions().DriverConfig.Options == nil { volumeOptions().DriverConfig.Options = make(map[string]string) } - setValueOnMap(volumeOptions().DriverConfig.Options, value) + setValueOnMap(volumeOptions().DriverConfig.Options, val) case "tmpfs-size": - sizeBytes, err := units.RAMInBytes(value) + sizeBytes, err := units.RAMInBytes(val) if err != nil { - return fmt.Errorf("invalid value for %s: %s", key, value) + return fmt.Errorf("invalid value for %s: %s", key, val) } tmpfsOptions().SizeBytes = sizeBytes case "tmpfs-mode": - ui64, err := strconv.ParseUint(value, 8, 32) + ui64, err := strconv.ParseUint(val, 8, 32) if err != nil { - return fmt.Errorf("invalid value for %s: %s", key, value) + return fmt.Errorf("invalid value for %s: %s", key, val) } tmpfsOptions().Mode = os.FileMode(ui64) default: diff --git a/opts/network.go b/opts/network.go index ce7370ee0e..12c3977b1b 100644 --- a/opts/network.go +++ b/opts/network.go @@ -48,34 +48,33 @@ func (n *NetworkOpt) Set(value string) error { netOpt.Aliases = []string{} for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - - if len(parts) < 2 { + // TODO(thaJeztah): these options should not be case-insensitive. + key, val, ok := strings.Cut(strings.ToLower(field), "=") + if !ok || key == "" { return fmt.Errorf("invalid field %s", field) } - key := strings.TrimSpace(strings.ToLower(parts[0])) - value := strings.TrimSpace(strings.ToLower(parts[1])) + key = strings.TrimSpace(key) + val = strings.TrimSpace(val) switch key { case networkOptName: - netOpt.Target = value + netOpt.Target = val case networkOptAlias: - netOpt.Aliases = append(netOpt.Aliases, value) + netOpt.Aliases = append(netOpt.Aliases, val) case networkOptIPv4Address: - netOpt.IPv4Address = value + netOpt.IPv4Address = val case networkOptIPv6Address: - netOpt.IPv6Address = value + netOpt.IPv6Address = val case driverOpt: - key, value, err = parseDriverOpt(value) - if err == nil { - if netOpt.DriverOpts == nil { - netOpt.DriverOpts = make(map[string]string) - } - netOpt.DriverOpts[key] = value - } else { + key, val, err = parseDriverOpt(val) + if err != nil { return err } + if netOpt.DriverOpts == nil { + netOpt.DriverOpts = make(map[string]string) + } + netOpt.DriverOpts[key] = val default: return fmt.Errorf("invalid field key %s", key) } @@ -116,11 +115,13 @@ func (n *NetworkOpt) NetworkMode() string { } func parseDriverOpt(driverOpt string) (string, string, error) { - parts := strings.SplitN(driverOpt, "=", 2) - if len(parts) != 2 { + // TODO(thaJeztah): these options should not be case-insensitive. + // TODO(thaJeztah): should value be converted to lowercase as well, or only the key? + key, value, ok := strings.Cut(strings.ToLower(driverOpt), "=") + if !ok || key == "" { return "", "", fmt.Errorf("invalid key value pair format in driver options") } - key := strings.TrimSpace(strings.ToLower(parts[0])) - value := strings.TrimSpace(strings.ToLower(parts[1])) + key = strings.TrimSpace(key) + value = strings.TrimSpace(value) return key, value, nil } diff --git a/opts/opts.go b/opts/opts.go index 03550023b0..4e7790f0fc 100644 --- a/opts/opts.go +++ b/opts/opts.go @@ -55,7 +55,7 @@ func (opts *ListOpts) Set(value string) error { } value = v } - (*opts.values) = append((*opts.values), value) + *opts.values = append(*opts.values, value) return nil } @@ -63,7 +63,7 @@ func (opts *ListOpts) Set(value string) error { func (opts *ListOpts) Delete(key string) { for i, k := range *opts.values { if k == key { - (*opts.values) = append((*opts.values)[:i], (*opts.values)[i+1:]...) + *opts.values = append((*opts.values)[:i], (*opts.values)[i+1:]...) return } } @@ -81,7 +81,7 @@ func (opts *ListOpts) GetMap() map[string]struct{} { // GetAll returns the values of slice. func (opts *ListOpts) GetAll() []string { - return (*opts.values) + return *opts.values } // GetAllOrEmpty returns the values of the slice @@ -106,7 +106,7 @@ func (opts *ListOpts) Get(key string) bool { // Len returns the amount of element in the slice. func (opts *ListOpts) Len() int { - return len((*opts.values)) + return len(*opts.values) } // Type returns a string name for this Option type @@ -165,12 +165,8 @@ func (opts *MapOpts) Set(value string) error { } value = v } - vals := strings.SplitN(value, "=", 2) - if len(vals) == 1 { - (opts.values)[vals[0]] = "" - } else { - (opts.values)[vals[0]] = vals[1] - } + k, v, _ := strings.Cut(value, "=") + opts.values[k] = v return nil } @@ -277,16 +273,16 @@ func validateDomain(val string) (string, error) { // // TODO discuss if quotes (and other special characters) should be valid or invalid for keys // TODO discuss if leading/trailing whitespace in keys should be preserved (and valid) -func ValidateLabel(val string) (string, error) { - arr := strings.SplitN(val, "=", 2) - key := strings.TrimLeft(arr[0], whiteSpaces) +func ValidateLabel(value string) (string, error) { + key, _, _ := strings.Cut(value, "=") + key = strings.TrimLeft(key, whiteSpaces) if key == "" { - return "", fmt.Errorf("invalid label '%s': empty name", val) + return "", fmt.Errorf("invalid label '%s': empty name", value) } if strings.ContainsAny(key, whiteSpaces) { return "", fmt.Errorf("label '%s' contains whitespaces", key) } - return val, nil + return value, nil } // ValidateSysctl validates a sysctl and returns it. @@ -305,20 +301,19 @@ func ValidateSysctl(val string) (string, error) { "net.", "fs.mqueue.", } - arr := strings.Split(val, "=") - if len(arr) < 2 { - return "", fmt.Errorf("sysctl '%s' is not whitelisted", val) + k, _, ok := strings.Cut(val, "=") + if !ok || k == "" { + return "", fmt.Errorf("sysctl '%s' is not allowed", val) } - if validSysctlMap[arr[0]] { + if validSysctlMap[k] { return val, nil } - for _, vp := range validSysctlPrefixes { - if strings.HasPrefix(arr[0], vp) { + if strings.HasPrefix(k, vp) { return val, nil } } - return "", fmt.Errorf("sysctl '%s' is not whitelisted", val) + return "", fmt.Errorf("sysctl '%s' is not allowed", val) } // FilterOpt is a flag type for validating filters @@ -347,11 +342,12 @@ func (o *FilterOpt) Set(value string) error { if !strings.Contains(value, "=") { return errors.New("bad format of filter (expected name=value)") } - f := strings.SplitN(value, "=", 2) - name := strings.ToLower(strings.TrimSpace(f[0])) - value = strings.TrimSpace(f[1]) + name, val, _ := strings.Cut(value, "=") - o.filter.Add(name, value) + // TODO(thaJeztah): these options should not be case-insensitive. + name = strings.ToLower(strings.TrimSpace(name)) + val = strings.TrimSpace(val) + o.filter.Add(name, val) return nil } @@ -411,10 +407,14 @@ func ParseLink(val string) (string, string, error) { if val == "" { return "", "", fmt.Errorf("empty string specified for links") } - arr := strings.Split(val, ":") + // We expect two parts, but restrict to three to allow detecting invalid formats. + arr := strings.SplitN(val, ":", 3) + + // TODO(thaJeztah): clean up this logic!! if len(arr) > 2 { return "", "", fmt.Errorf("bad format for links: %s", val) } + // TODO(thaJeztah): this should trim the "/" prefix as well?? if len(arr) == 1 { return val, val, nil } @@ -422,6 +422,7 @@ func ParseLink(val string) (string, string, error) { // from an already created container and the format is not `foo:bar` // but `/foo:/c1/bar` if strings.HasPrefix(arr[0], "/") { + // TODO(thaJeztah): clean up this logic!! _, alias := path.Split(arr[1]) return arr[0][1:], alias, nil } diff --git a/opts/opts_test.go b/opts/opts_test.go index 2b1c443417..1c91df5a4e 100644 --- a/opts/opts_test.go +++ b/opts/opts_test.go @@ -32,24 +32,30 @@ func TestValidateIPAddress(t *testing.T) { func TestMapOpts(t *testing.T) { tmpMap := make(map[string]string) - o := NewMapOpts(tmpMap, logOptsValidator) - o.Set("max-size=1") - if o.String() != "map[max-size:1]" { - t.Errorf("%s != [map[max-size:1]", o.String()) + o := NewMapOpts(tmpMap, sampleValidator) + err := o.Set("valid-option=1") + if err != nil { + t.Error(err) + } + if o.String() != "map[valid-option:1]" { + t.Errorf("%s != [map[valid-option:1]", o.String()) } - o.Set("max-file=2") + err = o.Set("valid-option2=2") + if err != nil { + t.Error(err) + } if len(tmpMap) != 2 { t.Errorf("map length %d != 2", len(tmpMap)) } - if tmpMap["max-file"] != "2" { - t.Errorf("max-file = %s != 2", tmpMap["max-file"]) + if tmpMap["valid-option"] != "1" { + t.Errorf("valid-option = %s != 1", tmpMap["valid-option"]) + } + if tmpMap["valid-option2"] != "2" { + t.Errorf("valid-option2 = %s != 2", tmpMap["valid-option2"]) } - if tmpMap["max-size"] != "1" { - t.Errorf("max-size = %s != 1", tmpMap["max-size"]) - } if o.Set("dummy-val=3") == nil { t.Error("validator is not being called") } @@ -57,15 +63,24 @@ func TestMapOpts(t *testing.T) { func TestListOptsWithoutValidator(t *testing.T) { o := NewListOpts(nil) - o.Set("foo") + err := o.Set("foo") + if err != nil { + t.Error(err) + } if o.String() != "[foo]" { t.Errorf("%s != [foo]", o.String()) } - o.Set("bar") + err = o.Set("bar") + if err != nil { + t.Error(err) + } if o.Len() != 2 { t.Errorf("%d != 2", o.Len()) } - o.Set("bar") + err = o.Set("bar") + if err != nil { + t.Error(err) + } if o.Len() != 3 { t.Errorf("%d != 3", o.Len()) } @@ -90,27 +105,35 @@ func TestListOptsWithoutValidator(t *testing.T) { } func TestListOptsWithValidator(t *testing.T) { - // Re-using logOptsvalidator (used by MapOpts) - o := NewListOpts(logOptsValidator) - o.Set("foo") + o := NewListOpts(sampleValidator) + err := o.Set("foo") + if err == nil { + t.Error(err) + } if o.String() != "" { t.Errorf(`%s != ""`, o.String()) } - o.Set("foo=bar") + err = o.Set("foo=bar") + if err == nil { + t.Error(err) + } if o.String() != "" { t.Errorf(`%s != ""`, o.String()) } - o.Set("max-file=2") + err = o.Set("valid-option2=2") + if err != nil { + t.Error(err) + } if o.Len() != 1 { t.Errorf("%d != 1", o.Len()) } - if !o.Get("max-file=2") { - t.Error("o.Get(\"max-file=2\") == false") + if !o.Get("valid-option2=2") { + t.Error(`o.Get("valid-option2=2") == false`) } if o.Get("baz") { - t.Error("o.Get(\"baz\") == true") + t.Error(`o.Get("baz") == true`) } - o.Delete("max-file=2") + o.Delete("valid-option2=2") if o.String() != "" { t.Errorf(`%s != ""`, o.String()) } @@ -277,13 +300,13 @@ func TestValidateLabel(t *testing.T) { } } -func logOptsValidator(val string) (string, error) { - allowedKeys := map[string]string{"max-size": "1", "max-file": "2"} - vals := strings.Split(val, "=") - if allowedKeys[vals[0]] != "" { +func sampleValidator(val string) (string, error) { + allowedKeys := map[string]string{"valid-option": "1", "valid-option2": "2"} + k, _, _ := strings.Cut(val, "=") + if allowedKeys[k] != "" { return val, nil } - return "", fmt.Errorf("invalid key %s", vals[0]) + return "", fmt.Errorf("invalid key %s", k) } func TestNamedListOpts(t *testing.T) { diff --git a/opts/parse.go b/opts/parse.go index 4012c461fb..017577e4bf 100644 --- a/opts/parse.go +++ b/opts/parse.go @@ -41,12 +41,8 @@ func readKVStrings(files []string, override []string, emptyFn func(string) (stri func ConvertKVStringsToMap(values []string) map[string]string { result := make(map[string]string, len(values)) for _, value := range values { - kv := strings.SplitN(value, "=", 2) - if len(kv) == 1 { - result[kv[0]] = "" - } else { - result[kv[0]] = kv[1] - } + k, v, _ := strings.Cut(value, "=") + result[k] = v } return result @@ -62,11 +58,11 @@ func ConvertKVStringsToMap(values []string) map[string]string { func ConvertKVStringsToMapWithNil(values []string) map[string]*string { result := make(map[string]*string, len(values)) for _, value := range values { - kv := strings.SplitN(value, "=", 2) - if len(kv) == 1 { - result[kv[0]] = nil + k, v, ok := strings.Cut(value, "=") + if !ok { + result[k] = nil } else { - result[kv[0]] = &kv[1] + result[k] = &v } } @@ -81,21 +77,15 @@ func ParseRestartPolicy(policy string) (container.RestartPolicy, error) { return p, nil } - parts := strings.Split(policy, ":") - - if len(parts) > 2 { - return p, fmt.Errorf("invalid restart policy format") - } - if len(parts) == 2 { - count, err := strconv.Atoi(parts[1]) + k, v, _ := strings.Cut(policy, ":") + if v != "" { + count, err := strconv.Atoi(v) if err != nil { - return p, fmt.Errorf("maximum retry count must be an integer") + return p, fmt.Errorf("invalid restart policy format: maximum retry count must be an integer") } - p.MaximumRetryCount = count } - p.Name = parts[0] - + p.Name = k return p, nil } diff --git a/opts/port.go b/opts/port.go index c0814dbd9a..fe41cdd288 100644 --- a/opts/port.go +++ b/opts/port.go @@ -42,36 +42,33 @@ func (p *PortOpt) Set(value string) error { pConfig := swarm.PortConfig{} for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - if len(parts) != 2 { + // TODO(thaJeztah): these options should not be case-insensitive. + key, val, ok := strings.Cut(strings.ToLower(field), "=") + if !ok || key == "" { return fmt.Errorf("invalid field %s", field) } - - key := strings.ToLower(parts[0]) - value := strings.ToLower(parts[1]) - switch key { case portOptProtocol: - if value != string(swarm.PortConfigProtocolTCP) && value != string(swarm.PortConfigProtocolUDP) && value != string(swarm.PortConfigProtocolSCTP) { - return fmt.Errorf("invalid protocol value %s", value) + if val != string(swarm.PortConfigProtocolTCP) && val != string(swarm.PortConfigProtocolUDP) && val != string(swarm.PortConfigProtocolSCTP) { + return fmt.Errorf("invalid protocol value %s", val) } - pConfig.Protocol = swarm.PortConfigProtocol(value) + pConfig.Protocol = swarm.PortConfigProtocol(val) case portOptMode: - if value != string(swarm.PortConfigPublishModeIngress) && value != string(swarm.PortConfigPublishModeHost) { - return fmt.Errorf("invalid publish mode value %s", value) + if val != string(swarm.PortConfigPublishModeIngress) && val != string(swarm.PortConfigPublishModeHost) { + return fmt.Errorf("invalid publish mode value %s", val) } - pConfig.PublishMode = swarm.PortConfigPublishMode(value) + pConfig.PublishMode = swarm.PortConfigPublishMode(val) case portOptTargetPort: - tPort, err := strconv.ParseUint(value, 10, 16) + tPort, err := strconv.ParseUint(val, 10, 16) if err != nil { return err } pConfig.TargetPort = uint32(tPort) case portOptPublishedPort: - pPort, err := strconv.ParseUint(value, 10, 16) + pPort, err := strconv.ParseUint(val, 10, 16) if err != nil { return err } diff --git a/opts/secret.go b/opts/secret.go index fabc62c01a..750dbe4f30 100644 --- a/opts/secret.go +++ b/opts/secret.go @@ -40,25 +40,22 @@ func (o *SecretOpt) Set(value string) error { } for _, field := range fields { - parts := strings.SplitN(field, "=", 2) - key := strings.ToLower(parts[0]) - - if len(parts) != 2 { + key, val, ok := strings.Cut(field, "=") + if !ok || key == "" { return fmt.Errorf("invalid field '%s' must be a key=value pair", field) } - - value := parts[1] - switch key { + // TODO(thaJeztah): these options should not be case-insensitive. + switch strings.ToLower(key) { case "source", "src": - options.SecretName = value + options.SecretName = val case "target": - options.File.Name = value + options.File.Name = val case "uid": - options.File.UID = value + options.File.UID = val case "gid": - options.File.GID = value + options.File.GID = val case "mode": - m, err := strconv.ParseUint(value, 0, 32) + m, err := strconv.ParseUint(val, 0, 32) if err != nil { return fmt.Errorf("invalid mode specified: %v", err) } diff --git a/opts/throttledevice.go b/opts/throttledevice.go index 0bf5dd666f..9fb788433b 100644 --- a/opts/throttledevice.go +++ b/opts/throttledevice.go @@ -14,14 +14,15 @@ type ValidatorThrottleFctType func(val string) (*blkiodev.ThrottleDevice, error) // ValidateThrottleBpsDevice validates that the specified string has a valid device-rate format. func ValidateThrottleBpsDevice(val string) (*blkiodev.ThrottleDevice, error) { - split := strings.SplitN(val, ":", 2) - if len(split) != 2 { + k, v, ok := strings.Cut(val, ":") + if !ok || k == "" { return nil, fmt.Errorf("bad format: %s", val) } - if !strings.HasPrefix(split[0], "/dev/") { + // TODO(thaJeztah): should we really validate this on the client? + if !strings.HasPrefix(k, "/dev/") { return nil, fmt.Errorf("bad format for device path: %s", val) } - rate, err := units.RAMInBytes(split[1]) + rate, err := units.RAMInBytes(v) if err != nil { return nil, fmt.Errorf("invalid rate for device: %s. The correct format is :[]. Number must be a positive integer. Unit is optional and can be kb, mb, or gb", val) } @@ -30,26 +31,27 @@ func ValidateThrottleBpsDevice(val string) (*blkiodev.ThrottleDevice, error) { } return &blkiodev.ThrottleDevice{ - Path: split[0], + Path: v, Rate: uint64(rate), }, nil } // ValidateThrottleIOpsDevice validates that the specified string has a valid device-rate format. func ValidateThrottleIOpsDevice(val string) (*blkiodev.ThrottleDevice, error) { - split := strings.SplitN(val, ":", 2) - if len(split) != 2 { + k, v, ok := strings.Cut(val, ":") + if !ok || k == "" { return nil, fmt.Errorf("bad format: %s", val) } - if !strings.HasPrefix(split[0], "/dev/") { + // TODO(thaJeztah): should we really validate this on the client? + if !strings.HasPrefix(k, "/dev/") { return nil, fmt.Errorf("bad format for device path: %s", val) } - rate, err := strconv.ParseUint(split[1], 10, 64) + rate, err := strconv.ParseUint(v, 10, 64) if err != nil { return nil, fmt.Errorf("invalid rate for device: %s. The correct format is :. Number must be a positive integer", val) } - return &blkiodev.ThrottleDevice{Path: split[0], Rate: rate}, nil + return &blkiodev.ThrottleDevice{Path: k, Rate: rate}, nil } // ThrottledeviceOpt defines a map of ThrottleDevices @@ -77,7 +79,7 @@ func (opt *ThrottledeviceOpt) Set(val string) error { } value = v } - (opt.values) = append((opt.values), value) + opt.values = append(opt.values, value) return nil } @@ -93,10 +95,7 @@ func (opt *ThrottledeviceOpt) String() string { // GetList returns a slice of pointers to ThrottleDevices. func (opt *ThrottledeviceOpt) GetList() []*blkiodev.ThrottleDevice { - var throttledevice []*blkiodev.ThrottleDevice - throttledevice = append(throttledevice, opt.values...) - - return throttledevice + return append([]*blkiodev.ThrottleDevice{}, opt.values...) } // Type returns the option type diff --git a/opts/weightdevice.go b/opts/weightdevice.go index f8057d0fb7..3077e3da7f 100644 --- a/opts/weightdevice.go +++ b/opts/weightdevice.go @@ -13,14 +13,15 @@ type ValidatorWeightFctType func(val string) (*blkiodev.WeightDevice, error) // ValidateWeightDevice validates that the specified string has a valid device-weight format. func ValidateWeightDevice(val string) (*blkiodev.WeightDevice, error) { - split := strings.SplitN(val, ":", 2) - if len(split) != 2 { + k, v, ok := strings.Cut(val, ":") + if !ok || k == "" { return nil, fmt.Errorf("bad format: %s", val) } - if !strings.HasPrefix(split[0], "/dev/") { + // TODO(thaJeztah): should we really validate this on the client? + if !strings.HasPrefix(k, "/dev/") { return nil, fmt.Errorf("bad format for device path: %s", val) } - weight, err := strconv.ParseUint(split[1], 10, 16) + weight, err := strconv.ParseUint(v, 10, 16) if err != nil { return nil, fmt.Errorf("invalid weight for device: %s", val) } @@ -29,7 +30,7 @@ func ValidateWeightDevice(val string) (*blkiodev.WeightDevice, error) { } return &blkiodev.WeightDevice{ - Path: split[0], + Path: k, Weight: uint16(weight), }, nil } @@ -42,9 +43,8 @@ type WeightdeviceOpt struct { // NewWeightdeviceOpt creates a new WeightdeviceOpt func NewWeightdeviceOpt(validator ValidatorWeightFctType) WeightdeviceOpt { - values := []*blkiodev.WeightDevice{} return WeightdeviceOpt{ - values: values, + values: []*blkiodev.WeightDevice{}, validator: validator, } } @@ -59,7 +59,7 @@ func (opt *WeightdeviceOpt) Set(val string) error { } value = v } - (opt.values) = append((opt.values), value) + opt.values = append(opt.values, value) return nil } diff --git a/service/logs/parse_logs.go b/service/logs/parse_logs.go index c01564ced5..9771f484f6 100644 --- a/service/logs/parse_logs.go +++ b/service/logs/parse_logs.go @@ -20,16 +20,17 @@ func ParseLogDetails(details string) (map[string]string, error) { pairs := strings.Split(details, ",") detailsMap := make(map[string]string, len(pairs)) for _, pair := range pairs { - p := strings.SplitN(pair, "=", 2) - // if there is no equals sign, we will only get 1 part back - if len(p) != 2 { + k, v, ok := strings.Cut(pair, "=") + if !ok || k == "" { + // missing equal sign, or no key. return nil, errors.New("invalid details format") } - k, err := url.QueryUnescape(p[0]) + var err error + k, err = url.QueryUnescape(k) if err != nil { return nil, err } - v, err := url.QueryUnescape(p[1]) + v, err = url.QueryUnescape(v) if err != nil { return nil, err } diff --git a/service/logs/parse_logs_test.go b/service/logs/parse_logs_test.go index 0a7ad3d967..9f6abdf714 100644 --- a/service/logs/parse_logs_test.go +++ b/service/logs/parse_logs_test.go @@ -3,33 +3,59 @@ package logs import ( "testing" - "github.com/pkg/errors" "gotest.tools/v3/assert" is "gotest.tools/v3/assert/cmp" ) func TestParseLogDetails(t *testing.T) { testCases := []struct { - line string - expected map[string]string - err error + line string + expected map[string]string + expectedErr string }{ - {"key=value", map[string]string{"key": "value"}, nil}, - {"key1=value1,key2=value2", map[string]string{"key1": "value1", "key2": "value2"}, nil}, - {"key+with+spaces=value%3Dequals,asdf%2C=", map[string]string{"key with spaces": "value=equals", "asdf,": ""}, nil}, - {"key=,=nothing", map[string]string{"key": "", "": "nothing"}, nil}, - {"=", map[string]string{"": ""}, nil}, - {"errors", nil, errors.New("invalid details format")}, + { + line: "key=value", + expected: map[string]string{"key": "value"}, + }, + { + line: "key1=value1,key2=value2", + expected: map[string]string{"key1": "value1", "key2": "value2"}, + }, + { + line: "key+with+spaces=value%3Dequals,asdf%2C=", + expected: map[string]string{"key with spaces": "value=equals", "asdf,": ""}, + }, + { + line: "key=,key2=", + expected: map[string]string{"key": "", "key2": ""}, + }, + { + line: "key=,=nothing", + expectedErr: "invalid details format", + }, + { + line: "=nothing", + expectedErr: "invalid details format", + }, + { + line: "=", + expectedErr: "invalid details format", + }, + { + line: "errors", + expectedErr: "invalid details format", + }, } - for _, testcase := range testCases { - testcase := testcase - t.Run(testcase.line, func(t *testing.T) { - actual, err := ParseLogDetails(testcase.line) - if testcase.err != nil { - assert.Error(t, err, testcase.err.Error()) - return + for _, tc := range testCases { + tc := tc + t.Run(tc.line, func(t *testing.T) { + actual, err := ParseLogDetails(tc.line) + if tc.expectedErr != "" { + assert.Check(t, is.ErrorContains(err, tc.expectedErr)) + } else { + assert.Check(t, err) } - assert.Check(t, is.DeepEqual(testcase.expected, actual)) + assert.Check(t, is.DeepEqual(tc.expected, actual)) }) } }