mirror of https://github.com/docker/cli.git
container: --gpus support
Signed-off-by: Tibor Vass <tibor@docker.com>
This commit is contained in:
parent
91339e1108
commit
1ba368a5ac
|
@ -46,6 +46,7 @@ type containerOptions struct {
|
||||||
labels opts.ListOpts
|
labels opts.ListOpts
|
||||||
deviceCgroupRules opts.ListOpts
|
deviceCgroupRules opts.ListOpts
|
||||||
devices opts.ListOpts
|
devices opts.ListOpts
|
||||||
|
gpus opts.GpuOpts
|
||||||
ulimits *opts.UlimitOpt
|
ulimits *opts.UlimitOpt
|
||||||
sysctls *opts.MapOpts
|
sysctls *opts.MapOpts
|
||||||
publish opts.ListOpts
|
publish opts.ListOpts
|
||||||
|
@ -166,6 +167,8 @@ func addFlags(flags *pflag.FlagSet) *containerOptions {
|
||||||
flags.VarP(&copts.attach, "attach", "a", "Attach to STDIN, STDOUT or STDERR")
|
flags.VarP(&copts.attach, "attach", "a", "Attach to STDIN, STDOUT or STDERR")
|
||||||
flags.Var(&copts.deviceCgroupRules, "device-cgroup-rule", "Add a rule to the cgroup allowed devices list")
|
flags.Var(&copts.deviceCgroupRules, "device-cgroup-rule", "Add a rule to the cgroup allowed devices list")
|
||||||
flags.Var(&copts.devices, "device", "Add a host device to the container")
|
flags.Var(&copts.devices, "device", "Add a host device to the container")
|
||||||
|
flags.Var(&copts.gpus, "gpus", "GPU devices to add to the container ('all' to pass all GPUs)")
|
||||||
|
flags.SetAnnotation("gpus", "version", []string{"1.40"})
|
||||||
flags.VarP(&copts.env, "env", "e", "Set environment variables")
|
flags.VarP(&copts.env, "env", "e", "Set environment variables")
|
||||||
flags.Var(&copts.envFile, "env-file", "Read in a file of environment variables")
|
flags.Var(&copts.envFile, "env-file", "Read in a file of environment variables")
|
||||||
flags.StringVar(&copts.entrypoint, "entrypoint", "", "Overwrite the default ENTRYPOINT of the image")
|
flags.StringVar(&copts.entrypoint, "entrypoint", "", "Overwrite the default ENTRYPOINT of the image")
|
||||||
|
@ -557,6 +560,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
|
||||||
Ulimits: copts.ulimits.GetList(),
|
Ulimits: copts.ulimits.GetList(),
|
||||||
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
|
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
|
||||||
Devices: deviceMappings,
|
Devices: deviceMappings,
|
||||||
|
DeviceRequests: copts.gpus.Value(),
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &container.Config{
|
config := &container.Config{
|
||||||
|
|
|
@ -0,0 +1,112 @@
|
||||||
|
package opts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/csv"
|
||||||
|
"fmt"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/docker/docker/api/types/container"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GpuOpts is a Value type for parsing mounts
|
||||||
|
type GpuOpts struct {
|
||||||
|
values []container.DeviceRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseCount(s string) (int, error) {
|
||||||
|
if s == "all" {
|
||||||
|
return -1, nil
|
||||||
|
}
|
||||||
|
i, err := strconv.Atoi(s)
|
||||||
|
return i, errors.Wrap(err, "count must be an integer")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set a new mount value
|
||||||
|
// nolint: gocyclo
|
||||||
|
func (o *GpuOpts) Set(value string) error {
|
||||||
|
csvReader := csv.NewReader(strings.NewReader(value))
|
||||||
|
fields, err := csvReader.Read()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req := container.DeviceRequest{}
|
||||||
|
|
||||||
|
seen := map[string]struct{}{}
|
||||||
|
// Set writable as the default
|
||||||
|
for _, field := range fields {
|
||||||
|
parts := strings.SplitN(field, "=", 2)
|
||||||
|
key := parts[0]
|
||||||
|
if _, ok := seen[key]; ok {
|
||||||
|
return fmt.Errorf("gpu request key '%s' can be specified only once", key)
|
||||||
|
}
|
||||||
|
seen[key] = struct{}{}
|
||||||
|
|
||||||
|
if len(parts) == 1 {
|
||||||
|
seen["count"] = struct{}{}
|
||||||
|
req.Count, err = parseCount(key)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
value := parts[1]
|
||||||
|
switch key {
|
||||||
|
case "driver":
|
||||||
|
req.Driver = value
|
||||||
|
case "count":
|
||||||
|
req.Count, err = parseCount(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
case "device":
|
||||||
|
req.DeviceIDs = strings.Split(value, ",")
|
||||||
|
case "capabilities":
|
||||||
|
req.Capabilities = [][]string{append(strings.Split(value, ","), "gpu")}
|
||||||
|
case "options":
|
||||||
|
r := csv.NewReader(strings.NewReader(value))
|
||||||
|
optFields, err := r.Read()
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to read gpu options")
|
||||||
|
}
|
||||||
|
req.Options = ConvertKVStringsToMap(optFields)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unexpected key '%s' in '%s'", key, field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := seen["count"]; !ok && req.DeviceIDs == nil {
|
||||||
|
req.Count = 1
|
||||||
|
}
|
||||||
|
if req.Options == nil {
|
||||||
|
req.Options = make(map[string]string)
|
||||||
|
}
|
||||||
|
if req.Capabilities == nil {
|
||||||
|
req.Capabilities = [][]string{{"gpu"}}
|
||||||
|
}
|
||||||
|
|
||||||
|
o.values = append(o.values, req)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Type returns the type of this option
|
||||||
|
func (o *GpuOpts) Type() string {
|
||||||
|
return "gpu-request"
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a string repr of this option
|
||||||
|
func (o *GpuOpts) String() string {
|
||||||
|
gpus := []string{}
|
||||||
|
for _, gpu := range o.values {
|
||||||
|
gpus = append(gpus, fmt.Sprintf("%v", gpu))
|
||||||
|
}
|
||||||
|
return strings.Join(gpus, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value returns the mounts
|
||||||
|
func (o *GpuOpts) Value() []container.DeviceRequest {
|
||||||
|
return o.values
|
||||||
|
}
|
|
@ -0,0 +1,48 @@
|
||||||
|
package opts
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/docker/docker/api/types/container"
|
||||||
|
"gotest.tools/assert"
|
||||||
|
is "gotest.tools/assert/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGpusOptAll(t *testing.T) {
|
||||||
|
for _, testcase := range []string{
|
||||||
|
"all",
|
||||||
|
"-1",
|
||||||
|
"count=all",
|
||||||
|
"count=-1",
|
||||||
|
} {
|
||||||
|
var gpus GpuOpts
|
||||||
|
gpus.Set(testcase)
|
||||||
|
gpuReqs := gpus.Value()
|
||||||
|
assert.Assert(t, is.Len(gpuReqs, 1))
|
||||||
|
assert.Check(t, is.DeepEqual(gpuReqs[0], container.DeviceRequest{
|
||||||
|
Count: -1,
|
||||||
|
Capabilities: [][]string{{"gpu"}},
|
||||||
|
Options: map[string]string{},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGpusOpts(t *testing.T) {
|
||||||
|
for _, testcase := range []string{
|
||||||
|
"driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"",
|
||||||
|
"1,driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"",
|
||||||
|
"count=1,driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\"",
|
||||||
|
"driver=nvidia,\"capabilities=compute,utility\",\"options=foo=bar,baz=qux\",count=1",
|
||||||
|
} {
|
||||||
|
var gpus GpuOpts
|
||||||
|
gpus.Set(testcase)
|
||||||
|
gpuReqs := gpus.Value()
|
||||||
|
assert.Assert(t, is.Len(gpuReqs, 1))
|
||||||
|
assert.Check(t, is.DeepEqual(gpuReqs[0], container.DeviceRequest{
|
||||||
|
Driver: "nvidia",
|
||||||
|
Count: 1,
|
||||||
|
Capabilities: [][]string{{"compute", "utility", "gpu"}},
|
||||||
|
Options: map[string]string{"foo": "bar", "baz": "qux"},
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue