opts: use strings.Cut for handling key/value pairs

Signed-off-by: Sebastiaan van Stijn <github@gone.nl>
This commit is contained in:
Sebastiaan van Stijn 2022-12-27 16:24:23 +01:00
parent a473c5b38a
commit 6c39bc1f60
No known key found for this signature in database
GPG Key ID: 76698F39D527CE8C
15 changed files with 175 additions and 189 deletions

View File

@ -649,8 +649,8 @@ func TestRunFlagsParseShmSize(t *testing.T) {
func TestParseRestartPolicy(t *testing.T) { func TestParseRestartPolicy(t *testing.T) {
invalids := map[string]string{ invalids := map[string]string{
"always:2:3": "invalid restart policy format", "always:2:3": "invalid restart policy format: maximum retry count must be an integer",
"on-failure:invalid": "maximum retry count must be an integer", "on-failure:invalid": "invalid restart policy format: maximum retry count must be an integer",
} }
valids := map[string]container.RestartPolicy{ valids := map[string]container.RestartPolicy{
"": {}, "": {},

View File

@ -40,25 +40,23 @@ func (o *ConfigOpt) Set(value string) error {
} }
for _, field := range fields { for _, field := range fields {
parts := strings.SplitN(field, "=", 2) key, val, ok := strings.Cut(field, "=")
key := strings.ToLower(parts[0]) if !ok || key == "" {
if len(parts) != 2 {
return fmt.Errorf("invalid field '%s' must be a key=value pair", field) return fmt.Errorf("invalid field '%s' must be a key=value pair", field)
} }
value := parts[1] // TODO(thaJeztah): these options should not be case-insensitive.
switch key { switch strings.ToLower(key) {
case "source", "src": case "source", "src":
options.ConfigName = value options.ConfigName = val
case "target": case "target":
options.File.Name = value options.File.Name = val
case "uid": case "uid":
options.File.UID = value options.File.UID = val
case "gid": case "gid":
options.File.GID = value options.File.GID = val
case "mode": case "mode":
m, err := strconv.ParseUint(value, 0, 32) m, err := strconv.ParseUint(val, 0, 32)
if err != nil { if err != nil {
return fmt.Errorf("invalid mode specified: %v", err) return fmt.Errorf("invalid mode specified: %v", err)
} }

View File

@ -16,15 +16,16 @@ import (
// //
// The only validation here is to check if name is empty, per #25099 // The only validation here is to check if name is empty, per #25099
func ValidateEnv(val string) (string, error) { func ValidateEnv(val string) (string, error) {
arr := strings.SplitN(val, "=", 2) k, _, hasValue := strings.Cut(val, "=")
if arr[0] == "" { if k == "" {
return "", errors.New("invalid environment variable: " + val) return "", errors.New("invalid environment variable: " + val)
} }
if len(arr) > 1 { if hasValue {
// val contains a "=" (but value may be an empty string)
return val, nil return val, nil
} }
if envVal, ok := os.LookupEnv(arr[0]); ok { if envVal, ok := os.LookupEnv(k); ok {
return arr[0] + "=" + envVal, nil return k + "=" + envVal, nil
} }
return val, nil return val, nil
} }

View File

@ -46,10 +46,10 @@ func parseKeyValueFile(filename string, emptyFn func(string) (string, bool)) ([]
currentLine++ currentLine++
// line is not empty, and not starting with '#' // line is not empty, and not starting with '#'
if len(line) > 0 && !strings.HasPrefix(line, "#") { if len(line) > 0 && !strings.HasPrefix(line, "#") {
data := strings.SplitN(line, "=", 2) variable, value, hasValue := strings.Cut(line, "=")
// trim the front of a variable, but nothing else // trim the front of a variable, but nothing else
variable := strings.TrimLeft(data[0], whiteSpaces) variable = strings.TrimLeft(variable, whiteSpaces)
if strings.ContainsAny(variable, whiteSpaces) { if strings.ContainsAny(variable, whiteSpaces) {
return []string{}, ErrBadKey{fmt.Sprintf("variable '%s' contains whitespaces", variable)} return []string{}, ErrBadKey{fmt.Sprintf("variable '%s' contains whitespaces", variable)}
} }
@ -57,18 +57,17 @@ func parseKeyValueFile(filename string, emptyFn func(string) (string, bool)) ([]
return []string{}, ErrBadKey{fmt.Sprintf("no variable name on line '%s'", line)} return []string{}, ErrBadKey{fmt.Sprintf("no variable name on line '%s'", line)}
} }
if len(data) > 1 { if hasValue {
// pass the value through, no trimming // pass the value through, no trimming
lines = append(lines, fmt.Sprintf("%s=%s", variable, data[1])) lines = append(lines, variable+"="+value)
} else { } else {
var value string
var present bool var present bool
if emptyFn != nil { if emptyFn != nil {
value, present = emptyFn(line) value, present = emptyFn(line)
} }
if present { if present {
// if only a pass-through variable is given, clean it up. // if only a pass-through variable is given, clean it up.
lines = append(lines, fmt.Sprintf("%s=%s", strings.TrimSpace(line), value)) lines = append(lines, strings.TrimSpace(variable)+"="+value)
} }
} }
} }

View File

@ -38,14 +38,13 @@ func (o *GpuOpts) Set(value string) error {
seen := map[string]struct{}{} seen := map[string]struct{}{}
// Set writable as the default // Set writable as the default
for _, field := range fields { for _, field := range fields {
parts := strings.SplitN(field, "=", 2) key, val, withValue := strings.Cut(field, "=")
key := parts[0]
if _, ok := seen[key]; ok { if _, ok := seen[key]; ok {
return fmt.Errorf("gpu request key '%s' can be specified only once", key) return fmt.Errorf("gpu request key '%s' can be specified only once", key)
} }
seen[key] = struct{}{} seen[key] = struct{}{}
if len(parts) == 1 { if !withValue {
seen["count"] = struct{}{} seen["count"] = struct{}{}
req.Count, err = parseCount(key) req.Count, err = parseCount(key)
if err != nil { if err != nil {
@ -54,21 +53,20 @@ func (o *GpuOpts) Set(value string) error {
continue continue
} }
value := parts[1]
switch key { switch key {
case "driver": case "driver":
req.Driver = value req.Driver = val
case "count": case "count":
req.Count, err = parseCount(value) req.Count, err = parseCount(val)
if err != nil { if err != nil {
return err return err
} }
case "device": case "device":
req.DeviceIDs = strings.Split(value, ",") req.DeviceIDs = strings.Split(val, ",")
case "capabilities": case "capabilities":
req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")} req.Capabilities = [][]string{append(strings.Split(val, ","), "gpu")}
case "options": case "options":
r := csv.NewReader(strings.NewReader(value)) r := csv.NewReader(strings.NewReader(val))
optFields, err := r.Read() optFields, err := r.Read()
if err != nil { if err != nil {
return errors.Wrap(err, "failed to read gpu options") return errors.Wrap(err, "failed to read gpu options")

View File

@ -33,6 +33,8 @@ const (
) )
// ValidateHost validates that the specified string is a valid host and returns it. // ValidateHost validates that the specified string is a valid host and returns it.
//
// TODO(thaJeztah): ValidateHost appears to be unused; deprecate it.
func ValidateHost(val string) (string, error) { func ValidateHost(val string) (string, error) {
host := strings.TrimSpace(val) host := strings.TrimSpace(val)
// The empty string means default and is not handled by parseDockerDaemonHost // The empty string means default and is not handled by parseDockerDaemonHost
@ -69,18 +71,19 @@ func ParseHost(defaultToTLS bool, val string) (string, error) {
// parseDockerDaemonHost parses the specified address and returns an address that will be used as the host. // parseDockerDaemonHost parses the specified address and returns an address that will be used as the host.
// Depending of the address specified, this may return one of the global Default* strings defined in hosts.go. // Depending of the address specified, this may return one of the global Default* strings defined in hosts.go.
func parseDockerDaemonHost(addr string) (string, error) { func parseDockerDaemonHost(addr string) (string, error) {
addrParts := strings.SplitN(addr, "://", 2) proto, host, hasProto := strings.Cut(addr, "://")
if len(addrParts) == 1 && addrParts[0] != "" { if !hasProto && proto != "" {
addrParts = []string{"tcp", addrParts[0]} host = proto
proto = "tcp"
} }
switch addrParts[0] { switch proto {
case "tcp": case "tcp":
return ParseTCPAddr(addrParts[1], defaultTCPHost) return ParseTCPAddr(host, defaultTCPHost)
case "unix": case "unix":
return parseSimpleProtoAddr("unix", addrParts[1], defaultUnixSocket) return parseSimpleProtoAddr(proto, host, defaultUnixSocket)
case "npipe": case "npipe":
return parseSimpleProtoAddr("npipe", addrParts[1], defaultNamedPipe) return parseSimpleProtoAddr(proto, host, defaultNamedPipe)
case "fd": case "fd":
return addr, nil return addr, nil
case "ssh": case "ssh":
@ -160,16 +163,18 @@ func ParseTCPAddr(tryAddr string, defaultAddr string) (string, error) {
// ValidateExtraHost validates that the specified string is a valid extrahost and returns it. // ValidateExtraHost validates that the specified string is a valid extrahost and returns it.
// ExtraHost is in the form of name:ip where the ip has to be a valid ip (IPv4 or IPv6). // ExtraHost is in the form of name:ip where the ip has to be a valid ip (IPv4 or IPv6).
//
// TODO(thaJeztah): remove client-side validation, and delegate to the API server.
func ValidateExtraHost(val string) (string, error) { func ValidateExtraHost(val string) (string, error) {
// allow for IPv6 addresses in extra hosts by only splitting on first ":" // allow for IPv6 addresses in extra hosts by only splitting on first ":"
arr := strings.SplitN(val, ":", 2) k, v, ok := strings.Cut(val, ":")
if len(arr) != 2 || len(arr[0]) == 0 { if !ok || k == "" {
return "", fmt.Errorf("bad format for add-host: %q", val) return "", fmt.Errorf("bad format for add-host: %q", val)
} }
// Skip IPaddr validation for "host-gateway" string // Skip IPaddr validation for "host-gateway" string
if arr[1] != hostGatewayName { if v != hostGatewayName {
if _, err := ValidateIPAddress(arr[1]); err != nil { if _, err := ValidateIPAddress(v); err != nil {
return "", fmt.Errorf("invalid IP address in add-host: %q", arr[1]) return "", fmt.Errorf("invalid IP address in add-host: %q", v)
} }
} }
return val, nil return val, nil

View File

@ -56,21 +56,21 @@ func (m *MountOpt) Set(value string) error {
} }
setValueOnMap := func(target map[string]string, value string) { setValueOnMap := func(target map[string]string, value string) {
parts := strings.SplitN(value, "=", 2) k, v, _ := strings.Cut(value, "=")
if len(parts) == 1 { if k != "" {
target[value] = "" target[k] = v
} else {
target[parts[0]] = parts[1]
} }
} }
mount.Type = mounttypes.TypeVolume // default to volume mounts mount.Type = mounttypes.TypeVolume // default to volume mounts
// Set writable as the default // Set writable as the default
for _, field := range fields { for _, field := range fields {
parts := strings.SplitN(field, "=", 2) key, val, ok := strings.Cut(field, "=")
key := strings.ToLower(parts[0])
if len(parts) == 1 { // TODO(thaJeztah): these options should not be case-insensitive.
key = strings.ToLower(key)
if !ok {
switch key { switch key {
case "readonly", "ro": case "readonly", "ro":
mount.ReadOnly = true mount.ReadOnly = true
@ -81,64 +81,61 @@ func (m *MountOpt) Set(value string) error {
case "bind-nonrecursive": case "bind-nonrecursive":
bindOptions().NonRecursive = true bindOptions().NonRecursive = true
continue continue
default:
return fmt.Errorf("invalid field '%s' must be a key=value pair", field)
} }
} }
if len(parts) != 2 {
return fmt.Errorf("invalid field '%s' must be a key=value pair", field)
}
value := parts[1]
switch key { switch key {
case "type": case "type":
mount.Type = mounttypes.Type(strings.ToLower(value)) mount.Type = mounttypes.Type(strings.ToLower(val))
case "source", "src": case "source", "src":
mount.Source = value mount.Source = val
if strings.HasPrefix(value, "."+string(filepath.Separator)) || value == "." { if strings.HasPrefix(val, "."+string(filepath.Separator)) || val == "." {
if abs, err := filepath.Abs(value); err == nil { if abs, err := filepath.Abs(val); err == nil {
mount.Source = abs mount.Source = abs
} }
} }
case "target", "dst", "destination": case "target", "dst", "destination":
mount.Target = value mount.Target = val
case "readonly", "ro": case "readonly", "ro":
mount.ReadOnly, err = strconv.ParseBool(value) mount.ReadOnly, err = strconv.ParseBool(val)
if err != nil { if err != nil {
return fmt.Errorf("invalid value for %s: %s", key, value) return fmt.Errorf("invalid value for %s: %s", key, val)
} }
case "consistency": case "consistency":
mount.Consistency = mounttypes.Consistency(strings.ToLower(value)) mount.Consistency = mounttypes.Consistency(strings.ToLower(val))
case "bind-propagation": case "bind-propagation":
bindOptions().Propagation = mounttypes.Propagation(strings.ToLower(value)) bindOptions().Propagation = mounttypes.Propagation(strings.ToLower(val))
case "bind-nonrecursive": case "bind-nonrecursive":
bindOptions().NonRecursive, err = strconv.ParseBool(value) bindOptions().NonRecursive, err = strconv.ParseBool(val)
if err != nil { if err != nil {
return fmt.Errorf("invalid value for %s: %s", key, value) return fmt.Errorf("invalid value for %s: %s", key, val)
} }
case "volume-nocopy": case "volume-nocopy":
volumeOptions().NoCopy, err = strconv.ParseBool(value) volumeOptions().NoCopy, err = strconv.ParseBool(val)
if err != nil { if err != nil {
return fmt.Errorf("invalid value for volume-nocopy: %s", value) return fmt.Errorf("invalid value for volume-nocopy: %s", val)
} }
case "volume-label": case "volume-label":
setValueOnMap(volumeOptions().Labels, value) setValueOnMap(volumeOptions().Labels, val)
case "volume-driver": case "volume-driver":
volumeOptions().DriverConfig.Name = value volumeOptions().DriverConfig.Name = val
case "volume-opt": case "volume-opt":
if volumeOptions().DriverConfig.Options == nil { if volumeOptions().DriverConfig.Options == nil {
volumeOptions().DriverConfig.Options = make(map[string]string) volumeOptions().DriverConfig.Options = make(map[string]string)
} }
setValueOnMap(volumeOptions().DriverConfig.Options, value) setValueOnMap(volumeOptions().DriverConfig.Options, val)
case "tmpfs-size": case "tmpfs-size":
sizeBytes, err := units.RAMInBytes(value) sizeBytes, err := units.RAMInBytes(val)
if err != nil { if err != nil {
return fmt.Errorf("invalid value for %s: %s", key, value) return fmt.Errorf("invalid value for %s: %s", key, val)
} }
tmpfsOptions().SizeBytes = sizeBytes tmpfsOptions().SizeBytes = sizeBytes
case "tmpfs-mode": case "tmpfs-mode":
ui64, err := strconv.ParseUint(value, 8, 32) ui64, err := strconv.ParseUint(val, 8, 32)
if err != nil { if err != nil {
return fmt.Errorf("invalid value for %s: %s", key, value) return fmt.Errorf("invalid value for %s: %s", key, val)
} }
tmpfsOptions().Mode = os.FileMode(ui64) tmpfsOptions().Mode = os.FileMode(ui64)
default: default:

View File

@ -48,34 +48,33 @@ func (n *NetworkOpt) Set(value string) error {
netOpt.Aliases = []string{} netOpt.Aliases = []string{}
for _, field := range fields { for _, field := range fields {
parts := strings.SplitN(field, "=", 2) // TODO(thaJeztah): these options should not be case-insensitive.
key, val, ok := strings.Cut(strings.ToLower(field), "=")
if len(parts) < 2 { if !ok || key == "" {
return fmt.Errorf("invalid field %s", field) return fmt.Errorf("invalid field %s", field)
} }
key := strings.TrimSpace(strings.ToLower(parts[0])) key = strings.TrimSpace(key)
value := strings.TrimSpace(strings.ToLower(parts[1])) val = strings.TrimSpace(val)
switch key { switch key {
case networkOptName: case networkOptName:
netOpt.Target = value netOpt.Target = val
case networkOptAlias: case networkOptAlias:
netOpt.Aliases = append(netOpt.Aliases, value) netOpt.Aliases = append(netOpt.Aliases, val)
case networkOptIPv4Address: case networkOptIPv4Address:
netOpt.IPv4Address = value netOpt.IPv4Address = val
case networkOptIPv6Address: case networkOptIPv6Address:
netOpt.IPv6Address = value netOpt.IPv6Address = val
case driverOpt: case driverOpt:
key, value, err = parseDriverOpt(value) key, val, err = parseDriverOpt(val)
if err == nil { if err != nil {
if netOpt.DriverOpts == nil {
netOpt.DriverOpts = make(map[string]string)
}
netOpt.DriverOpts[key] = value
} else {
return err return err
} }
if netOpt.DriverOpts == nil {
netOpt.DriverOpts = make(map[string]string)
}
netOpt.DriverOpts[key] = val
default: default:
return fmt.Errorf("invalid field key %s", key) return fmt.Errorf("invalid field key %s", key)
} }
@ -116,11 +115,13 @@ func (n *NetworkOpt) NetworkMode() string {
} }
func parseDriverOpt(driverOpt string) (string, string, error) { func parseDriverOpt(driverOpt string) (string, string, error) {
parts := strings.SplitN(driverOpt, "=", 2) // TODO(thaJeztah): these options should not be case-insensitive.
if len(parts) != 2 { // TODO(thaJeztah): should value be converted to lowercase as well, or only the key?
key, value, ok := strings.Cut(strings.ToLower(driverOpt), "=")
if !ok || key == "" {
return "", "", fmt.Errorf("invalid key value pair format in driver options") return "", "", fmt.Errorf("invalid key value pair format in driver options")
} }
key := strings.TrimSpace(strings.ToLower(parts[0])) key = strings.TrimSpace(key)
value := strings.TrimSpace(strings.ToLower(parts[1])) value = strings.TrimSpace(value)
return key, value, nil return key, value, nil
} }

View File

@ -165,12 +165,8 @@ func (opts *MapOpts) Set(value string) error {
} }
value = v value = v
} }
vals := strings.SplitN(value, "=", 2) k, v, _ := strings.Cut(value, "=")
if len(vals) == 1 { opts.values[k] = v
opts.values[vals[0]] = ""
} else {
opts.values[vals[0]] = vals[1]
}
return nil return nil
} }
@ -277,16 +273,16 @@ func validateDomain(val string) (string, error) {
// //
// TODO discuss if quotes (and other special characters) should be valid or invalid for keys // TODO discuss if quotes (and other special characters) should be valid or invalid for keys
// TODO discuss if leading/trailing whitespace in keys should be preserved (and valid) // TODO discuss if leading/trailing whitespace in keys should be preserved (and valid)
func ValidateLabel(val string) (string, error) { func ValidateLabel(value string) (string, error) {
arr := strings.SplitN(val, "=", 2) key, _, _ := strings.Cut(value, "=")
key := strings.TrimLeft(arr[0], whiteSpaces) key = strings.TrimLeft(key, whiteSpaces)
if key == "" { if key == "" {
return "", fmt.Errorf("invalid label '%s': empty name", val) return "", fmt.Errorf("invalid label '%s': empty name", value)
} }
if strings.ContainsAny(key, whiteSpaces) { if strings.ContainsAny(key, whiteSpaces) {
return "", fmt.Errorf("label '%s' contains whitespaces", key) return "", fmt.Errorf("label '%s' contains whitespaces", key)
} }
return val, nil return value, nil
} }
// ValidateSysctl validates a sysctl and returns it. // ValidateSysctl validates a sysctl and returns it.
@ -305,20 +301,19 @@ func ValidateSysctl(val string) (string, error) {
"net.", "net.",
"fs.mqueue.", "fs.mqueue.",
} }
arr := strings.Split(val, "=") k, _, ok := strings.Cut(val, "=")
if len(arr) < 2 { if !ok || k == "" {
return "", fmt.Errorf("sysctl '%s' is not whitelisted", val) return "", fmt.Errorf("sysctl '%s' is not allowed", val)
} }
if validSysctlMap[arr[0]] { if validSysctlMap[k] {
return val, nil return val, nil
} }
for _, vp := range validSysctlPrefixes { for _, vp := range validSysctlPrefixes {
if strings.HasPrefix(arr[0], vp) { if strings.HasPrefix(k, vp) {
return val, nil return val, nil
} }
} }
return "", fmt.Errorf("sysctl '%s' is not whitelisted", val) return "", fmt.Errorf("sysctl '%s' is not allowed", val)
} }
// FilterOpt is a flag type for validating filters // FilterOpt is a flag type for validating filters
@ -347,11 +342,12 @@ func (o *FilterOpt) Set(value string) error {
if !strings.Contains(value, "=") { if !strings.Contains(value, "=") {
return errors.New("bad format of filter (expected name=value)") return errors.New("bad format of filter (expected name=value)")
} }
f := strings.SplitN(value, "=", 2) name, val, _ := strings.Cut(value, "=")
name := strings.ToLower(strings.TrimSpace(f[0]))
value = strings.TrimSpace(f[1])
o.filter.Add(name, value) // TODO(thaJeztah): these options should not be case-insensitive.
name = strings.ToLower(strings.TrimSpace(name))
val = strings.TrimSpace(val)
o.filter.Add(name, val)
return nil return nil
} }
@ -411,10 +407,14 @@ func ParseLink(val string) (string, string, error) {
if val == "" { if val == "" {
return "", "", fmt.Errorf("empty string specified for links") return "", "", fmt.Errorf("empty string specified for links")
} }
arr := strings.Split(val, ":") // We expect two parts, but restrict to three to allow detecting invalid formats.
arr := strings.SplitN(val, ":", 3)
// TODO(thaJeztah): clean up this logic!!
if len(arr) > 2 { if len(arr) > 2 {
return "", "", fmt.Errorf("bad format for links: %s", val) return "", "", fmt.Errorf("bad format for links: %s", val)
} }
// TODO(thaJeztah): this should trim the "/" prefix as well??
if len(arr) == 1 { if len(arr) == 1 {
return val, val, nil return val, val, nil
} }
@ -422,6 +422,7 @@ func ParseLink(val string) (string, string, error) {
// from an already created container and the format is not `foo:bar` // from an already created container and the format is not `foo:bar`
// but `/foo:/c1/bar` // but `/foo:/c1/bar`
if strings.HasPrefix(arr[0], "/") { if strings.HasPrefix(arr[0], "/") {
// TODO(thaJeztah): clean up this logic!!
_, alias := path.Split(arr[1]) _, alias := path.Split(arr[1])
return arr[0][1:], alias, nil return arr[0][1:], alias, nil
} }

View File

@ -301,12 +301,12 @@ func TestValidateLabel(t *testing.T) {
} }
func sampleValidator(val string) (string, error) { func sampleValidator(val string) (string, error) {
allowedKeys := map[string]string{"max-size": "1", "max-file": "2"} allowedKeys := map[string]string{"valid-option": "1", "valid-option2": "2"}
vals := strings.Split(val, "=") k, _, _ := strings.Cut(val, "=")
if allowedKeys[vals[0]] != "" { if allowedKeys[k] != "" {
return val, nil return val, nil
} }
return "", fmt.Errorf("invalid key %s", vals[0]) return "", fmt.Errorf("invalid key %s", k)
} }
func TestNamedListOpts(t *testing.T) { func TestNamedListOpts(t *testing.T) {

View File

@ -41,12 +41,8 @@ func readKVStrings(files []string, override []string, emptyFn func(string) (stri
func ConvertKVStringsToMap(values []string) map[string]string { func ConvertKVStringsToMap(values []string) map[string]string {
result := make(map[string]string, len(values)) result := make(map[string]string, len(values))
for _, value := range values { for _, value := range values {
kv := strings.SplitN(value, "=", 2) k, v, _ := strings.Cut(value, "=")
if len(kv) == 1 { result[k] = v
result[kv[0]] = ""
} else {
result[kv[0]] = kv[1]
}
} }
return result return result
@ -62,11 +58,11 @@ func ConvertKVStringsToMap(values []string) map[string]string {
func ConvertKVStringsToMapWithNil(values []string) map[string]*string { func ConvertKVStringsToMapWithNil(values []string) map[string]*string {
result := make(map[string]*string, len(values)) result := make(map[string]*string, len(values))
for _, value := range values { for _, value := range values {
kv := strings.SplitN(value, "=", 2) k, v, ok := strings.Cut(value, "=")
if len(kv) == 1 { if !ok {
result[kv[0]] = nil result[k] = nil
} else { } else {
result[kv[0]] = &kv[1] result[k] = &v
} }
} }
@ -81,21 +77,15 @@ func ParseRestartPolicy(policy string) (container.RestartPolicy, error) {
return p, nil return p, nil
} }
parts := strings.Split(policy, ":") k, v, _ := strings.Cut(policy, ":")
if v != "" {
if len(parts) > 2 { count, err := strconv.Atoi(v)
return p, fmt.Errorf("invalid restart policy format")
}
if len(parts) == 2 {
count, err := strconv.Atoi(parts[1])
if err != nil { if err != nil {
return p, fmt.Errorf("maximum retry count must be an integer") return p, fmt.Errorf("invalid restart policy format: maximum retry count must be an integer")
} }
p.MaximumRetryCount = count p.MaximumRetryCount = count
} }
p.Name = parts[0] p.Name = k
return p, nil return p, nil
} }

View File

@ -42,36 +42,33 @@ func (p *PortOpt) Set(value string) error {
pConfig := swarm.PortConfig{} pConfig := swarm.PortConfig{}
for _, field := range fields { for _, field := range fields {
parts := strings.SplitN(field, "=", 2) // TODO(thaJeztah): these options should not be case-insensitive.
if len(parts) != 2 { key, val, ok := strings.Cut(strings.ToLower(field), "=")
if !ok || key == "" {
return fmt.Errorf("invalid field %s", field) return fmt.Errorf("invalid field %s", field)
} }
key := strings.ToLower(parts[0])
value := strings.ToLower(parts[1])
switch key { switch key {
case portOptProtocol: case portOptProtocol:
if value != string(swarm.PortConfigProtocolTCP) && value != string(swarm.PortConfigProtocolUDP) && value != string(swarm.PortConfigProtocolSCTP) { if val != string(swarm.PortConfigProtocolTCP) && val != string(swarm.PortConfigProtocolUDP) && val != string(swarm.PortConfigProtocolSCTP) {
return fmt.Errorf("invalid protocol value %s", value) return fmt.Errorf("invalid protocol value %s", val)
} }
pConfig.Protocol = swarm.PortConfigProtocol(value) pConfig.Protocol = swarm.PortConfigProtocol(val)
case portOptMode: case portOptMode:
if value != string(swarm.PortConfigPublishModeIngress) && value != string(swarm.PortConfigPublishModeHost) { if val != string(swarm.PortConfigPublishModeIngress) && val != string(swarm.PortConfigPublishModeHost) {
return fmt.Errorf("invalid publish mode value %s", value) return fmt.Errorf("invalid publish mode value %s", val)
} }
pConfig.PublishMode = swarm.PortConfigPublishMode(value) pConfig.PublishMode = swarm.PortConfigPublishMode(val)
case portOptTargetPort: case portOptTargetPort:
tPort, err := strconv.ParseUint(value, 10, 16) tPort, err := strconv.ParseUint(val, 10, 16)
if err != nil { if err != nil {
return err return err
} }
pConfig.TargetPort = uint32(tPort) pConfig.TargetPort = uint32(tPort)
case portOptPublishedPort: case portOptPublishedPort:
pPort, err := strconv.ParseUint(value, 10, 16) pPort, err := strconv.ParseUint(val, 10, 16)
if err != nil { if err != nil {
return err return err
} }

View File

@ -40,25 +40,22 @@ func (o *SecretOpt) Set(value string) error {
} }
for _, field := range fields { for _, field := range fields {
parts := strings.SplitN(field, "=", 2) key, val, ok := strings.Cut(field, "=")
key := strings.ToLower(parts[0]) if !ok || key == "" {
if len(parts) != 2 {
return fmt.Errorf("invalid field '%s' must be a key=value pair", field) return fmt.Errorf("invalid field '%s' must be a key=value pair", field)
} }
// TODO(thaJeztah): these options should not be case-insensitive.
value := parts[1] switch strings.ToLower(key) {
switch key {
case "source", "src": case "source", "src":
options.SecretName = value options.SecretName = val
case "target": case "target":
options.File.Name = value options.File.Name = val
case "uid": case "uid":
options.File.UID = value options.File.UID = val
case "gid": case "gid":
options.File.GID = value options.File.GID = val
case "mode": case "mode":
m, err := strconv.ParseUint(value, 0, 32) m, err := strconv.ParseUint(val, 0, 32)
if err != nil { if err != nil {
return fmt.Errorf("invalid mode specified: %v", err) return fmt.Errorf("invalid mode specified: %v", err)
} }

View File

@ -14,14 +14,15 @@ type ValidatorThrottleFctType func(val string) (*blkiodev.ThrottleDevice, error)
// ValidateThrottleBpsDevice validates that the specified string has a valid device-rate format. // ValidateThrottleBpsDevice validates that the specified string has a valid device-rate format.
func ValidateThrottleBpsDevice(val string) (*blkiodev.ThrottleDevice, error) { func ValidateThrottleBpsDevice(val string) (*blkiodev.ThrottleDevice, error) {
split := strings.SplitN(val, ":", 2) k, v, ok := strings.Cut(val, ":")
if len(split) != 2 { if !ok || k == "" {
return nil, fmt.Errorf("bad format: %s", val) return nil, fmt.Errorf("bad format: %s", val)
} }
if !strings.HasPrefix(split[0], "/dev/") { // TODO(thaJeztah): should we really validate this on the client?
if !strings.HasPrefix(k, "/dev/") {
return nil, fmt.Errorf("bad format for device path: %s", val) return nil, fmt.Errorf("bad format for device path: %s", val)
} }
rate, err := units.RAMInBytes(split[1]) rate, err := units.RAMInBytes(v)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid rate for device: %s. The correct format is <device-path>:<number>[<unit>]. Number must be a positive integer. Unit is optional and can be kb, mb, or gb", val) return nil, fmt.Errorf("invalid rate for device: %s. The correct format is <device-path>:<number>[<unit>]. Number must be a positive integer. Unit is optional and can be kb, mb, or gb", val)
} }
@ -30,26 +31,27 @@ func ValidateThrottleBpsDevice(val string) (*blkiodev.ThrottleDevice, error) {
} }
return &blkiodev.ThrottleDevice{ return &blkiodev.ThrottleDevice{
Path: split[0], Path: v,
Rate: uint64(rate), Rate: uint64(rate),
}, nil }, nil
} }
// ValidateThrottleIOpsDevice validates that the specified string has a valid device-rate format. // ValidateThrottleIOpsDevice validates that the specified string has a valid device-rate format.
func ValidateThrottleIOpsDevice(val string) (*blkiodev.ThrottleDevice, error) { func ValidateThrottleIOpsDevice(val string) (*blkiodev.ThrottleDevice, error) {
split := strings.SplitN(val, ":", 2) k, v, ok := strings.Cut(val, ":")
if len(split) != 2 { if !ok || k == "" {
return nil, fmt.Errorf("bad format: %s", val) return nil, fmt.Errorf("bad format: %s", val)
} }
if !strings.HasPrefix(split[0], "/dev/") { // TODO(thaJeztah): should we really validate this on the client?
if !strings.HasPrefix(k, "/dev/") {
return nil, fmt.Errorf("bad format for device path: %s", val) return nil, fmt.Errorf("bad format for device path: %s", val)
} }
rate, err := strconv.ParseUint(split[1], 10, 64) rate, err := strconv.ParseUint(v, 10, 64)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid rate for device: %s. The correct format is <device-path>:<number>. Number must be a positive integer", val) return nil, fmt.Errorf("invalid rate for device: %s. The correct format is <device-path>:<number>. Number must be a positive integer", val)
} }
return &blkiodev.ThrottleDevice{Path: split[0], Rate: rate}, nil return &blkiodev.ThrottleDevice{Path: k, Rate: rate}, nil
} }
// ThrottledeviceOpt defines a map of ThrottleDevices // ThrottledeviceOpt defines a map of ThrottleDevices

View File

@ -13,14 +13,15 @@ type ValidatorWeightFctType func(val string) (*blkiodev.WeightDevice, error)
// ValidateWeightDevice validates that the specified string has a valid device-weight format. // ValidateWeightDevice validates that the specified string has a valid device-weight format.
func ValidateWeightDevice(val string) (*blkiodev.WeightDevice, error) { func ValidateWeightDevice(val string) (*blkiodev.WeightDevice, error) {
split := strings.SplitN(val, ":", 2) k, v, ok := strings.Cut(val, ":")
if len(split) != 2 { if !ok || k == "" {
return nil, fmt.Errorf("bad format: %s", val) return nil, fmt.Errorf("bad format: %s", val)
} }
if !strings.HasPrefix(split[0], "/dev/") { // TODO(thaJeztah): should we really validate this on the client?
if !strings.HasPrefix(k, "/dev/") {
return nil, fmt.Errorf("bad format for device path: %s", val) return nil, fmt.Errorf("bad format for device path: %s", val)
} }
weight, err := strconv.ParseUint(split[1], 10, 16) weight, err := strconv.ParseUint(v, 10, 16)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid weight for device: %s", val) return nil, fmt.Errorf("invalid weight for device: %s", val)
} }
@ -29,7 +30,7 @@ func ValidateWeightDevice(val string) (*blkiodev.WeightDevice, error) {
} }
return &blkiodev.WeightDevice{ return &blkiodev.WeightDevice{
Path: split[0], Path: k,
Weight: uint16(weight), Weight: uint16(weight),
}, nil }, nil
} }
@ -42,9 +43,8 @@ type WeightdeviceOpt struct {
// NewWeightdeviceOpt creates a new WeightdeviceOpt // NewWeightdeviceOpt creates a new WeightdeviceOpt
func NewWeightdeviceOpt(validator ValidatorWeightFctType) WeightdeviceOpt { func NewWeightdeviceOpt(validator ValidatorWeightFctType) WeightdeviceOpt {
values := []*blkiodev.WeightDevice{}
return WeightdeviceOpt{ return WeightdeviceOpt{
values: values, values: []*blkiodev.WeightDevice{},
validator: validator, validator: validator,
} }
} }