diff --git a/cli/command/container/opts.go b/cli/command/container/opts.go index 8fe4ded9c8..6b10185081 100644 --- a/cli/command/container/opts.go +++ b/cli/command/container/opts.go @@ -46,6 +46,7 @@ type containerOptions struct { labels opts.ListOpts deviceCgroupRules opts.ListOpts devices opts.ListOpts + gpus opts.GpuOpts ulimits *opts.UlimitOpt sysctls *opts.MapOpts publish opts.ListOpts @@ -166,6 +167,8 @@ func addFlags(flags *pflag.FlagSet) *containerOptions { flags.VarP(&copts.attach, "attach", "a", "Attach to STDIN, STDOUT or STDERR") flags.Var(&copts.deviceCgroupRules, "device-cgroup-rule", "Add a rule to the cgroup allowed devices list") flags.Var(&copts.devices, "device", "Add a host device to the container") + flags.Var(&copts.gpus, "gpus", "GPU devices to add to the container ('all' to pass all GPUs)") + flags.SetAnnotation("gpus", "version", []string{"1.40"}) flags.VarP(&copts.env, "env", "e", "Set environment variables") flags.Var(&copts.envFile, "env-file", "Read in a file of environment variables") flags.StringVar(&copts.entrypoint, "entrypoint", "", "Overwrite the default ENTRYPOINT of the image") @@ -557,6 +560,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con Ulimits: copts.ulimits.GetList(), DeviceCgroupRules: copts.deviceCgroupRules.GetAll(), Devices: deviceMappings, + DeviceRequests: copts.gpus.Value(), } config := &container.Config{ diff --git a/opts/gpus.go b/opts/gpus.go new file mode 100644 index 0000000000..e110a4771e --- /dev/null +++ b/opts/gpus.go @@ -0,0 +1,112 @@ +package opts + +import ( + "encoding/csv" + "fmt" + "strconv" + "strings" + + "github.com/docker/docker/api/types/container" + "github.com/pkg/errors" +) + +// GpuOpts is a Value type for parsing mounts +type GpuOpts struct { + values []container.DeviceRequest +} + +func parseCount(s string) (int, error) { + if s == "all" { + return -1, nil + } + i, err := strconv.Atoi(s) + return i, errors.Wrap(err, "count must be an integer") +} + +// Set a new mount value +// nolint: gocyclo +func (o *GpuOpts) Set(value string) error { + csvReader := csv.NewReader(strings.NewReader(value)) + fields, err := csvReader.Read() + if err != nil { + return err + } + + req := container.DeviceRequest{} + + seen := map[string]struct{}{} + // Set writable as the default + for _, field := range fields { + parts := strings.SplitN(field, "=", 2) + key := parts[0] + if _, ok := seen[key]; ok { + return fmt.Errorf("gpu request key '%s' can be specified only once", key) + } + seen[key] = struct{}{} + + if len(parts) == 1 { + seen["count"] = struct{}{} + req.Count, err = parseCount(key) + if err != nil { + return err + } + continue + } + + value := parts[1] + switch key { + case "driver": + req.Driver = value + case "count": + req.Count, err = parseCount(value) + if err != nil { + return err + } + case "device": + req.DeviceIDs = strings.Split(value, ",") + case "capabilities": + req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")} + case "options": + r := csv.NewReader(strings.NewReader(value)) + optFields, err := r.Read() + if err != nil { + return errors.Wrap(err, "failed to read gpu options") + } + req.Options = ConvertKVStringsToMap(optFields) + default: + return fmt.Errorf("unexpected key '%s' in '%s'", key, field) + } + } + + if _, ok := seen["count"]; !ok && req.DeviceIDs == nil { + req.Count = 1 + } + if req.Options == nil { + req.Options = make(map[string]string) + } + if req.Capabilities == nil { + req.Capabilities = [][]string{{"gpu"}} + } + + o.values = append(o.values, req) + return nil +} + +// Type returns the type of this option +func (o *GpuOpts) Type() string { + return "gpu-request" +} + +// String returns a string repr of this option +func (o *GpuOpts) String() string { + gpus := []string{} + for _, gpu := range o.values { + gpus = append(gpus, fmt.Sprintf("%v", gpu)) + } + return strings.Join(gpus, ", ") +} + +// Value returns the mounts +func (o *GpuOpts) Value() []container.DeviceRequest { + return o.values +} diff --git a/opts/gpus_test.go b/opts/gpus_test.go new file mode 100644 index 0000000000..23fe7992e3 --- /dev/null +++ b/opts/gpus_test.go @@ -0,0 +1,48 @@ +package opts + +import ( + "testing" + + "github.com/docker/docker/api/types/container" + "gotest.tools/assert" + is "gotest.tools/assert/cmp" +) + +func TestGpusOptAll(t *testing.T) { + for _, testcase := range []string{ + "all", + "-1", + "count=all", + "count=-1", + } { + var gpus GpuOpts + gpus.Set(testcase) + gpuReqs := gpus.Value() + assert.Assert(t, is.Len(gpuReqs, 1)) + assert.Check(t, is.DeepEqual(gpuReqs[0], container.DeviceRequest{ + Count: -1, + Capabilities: [][]string{{"gpu"}}, + Options: map[string]string{}, + })) + } +} + +func TestGpusOpts(t *testing.T) { + for _, testcase := range []string{ + "driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"", + "1,driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"", + "count=1,driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"", + "driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\",count=1", + } { + var gpus GpuOpts + gpus.Set(testcase) + gpuReqs := gpus.Value() + assert.Assert(t, is.Len(gpuReqs, 1)) + assert.Check(t, is.DeepEqual(gpuReqs[0], container.DeviceRequest{ + Driver: "nvidia", + Count: 1, + Capabilities: [][]string{{"compute", "utility", "gpu"}}, + Options: map[string]string{"foo": "bar", "baz": "qux"}, + })) + } +}