mirror of https://github.com/docker/cli.git
Support CDI devices in --device flag
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
dc2eb3bf7c
commit
dad225d1e2
|
@ -13,6 +13,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
cdi "github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
|
||||
"github.com/docker/cli/cli/compose/loader"
|
||||
"github.com/docker/cli/opts"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
|
@ -449,12 +450,17 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
|
|||
// parsing flags, we haven't yet sent a _ping to the daemon to determine
|
||||
// what operating system it is.
|
||||
deviceMappings := []container.DeviceMapping{}
|
||||
var cdiDeviceNames []string
|
||||
for _, device := range copts.devices.GetAll() {
|
||||
var (
|
||||
validated string
|
||||
deviceMapping container.DeviceMapping
|
||||
err error
|
||||
)
|
||||
if cdi.IsQualifiedName(device) {
|
||||
cdiDeviceNames = append(cdiDeviceNames, device)
|
||||
continue
|
||||
}
|
||||
validated, err = validateDevice(device, serverOS)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -559,6 +565,15 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
|
|||
}
|
||||
}
|
||||
|
||||
deviceRequests := copts.gpus.Value()
|
||||
if len(cdiDeviceNames) > 0 {
|
||||
cdiDeviceRequest := container.DeviceRequest{
|
||||
Driver: "cdi",
|
||||
DeviceIDs: cdiDeviceNames,
|
||||
}
|
||||
deviceRequests = append(deviceRequests, cdiDeviceRequest)
|
||||
}
|
||||
|
||||
resources := container.Resources{
|
||||
CgroupParent: copts.cgroupParent,
|
||||
Memory: copts.memory.Value(),
|
||||
|
@ -589,7 +604,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
|
|||
Ulimits: copts.ulimits.GetList(),
|
||||
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
|
||||
Devices: deviceMappings,
|
||||
DeviceRequests: copts.gpus.Value(),
|
||||
DeviceRequests: deviceRequests,
|
||||
}
|
||||
|
||||
config := &container.Config{
|
||||
|
|
|
@ -417,39 +417,91 @@ func TestParseWithExpose(t *testing.T) {
|
|||
|
||||
func TestParseDevice(t *testing.T) {
|
||||
skip.If(t, runtime.GOOS != "linux") // Windows and macOS validate server-side
|
||||
valids := map[string]container.DeviceMapping{
|
||||
"/dev/snd": {
|
||||
testCases := []struct {
|
||||
devices []string
|
||||
deviceMapping *container.DeviceMapping
|
||||
deviceRequests []container.DeviceRequest
|
||||
}{
|
||||
{
|
||||
devices: []string{"/dev/snd"},
|
||||
deviceMapping: &container.DeviceMapping{
|
||||
PathOnHost: "/dev/snd",
|
||||
PathInContainer: "/dev/snd",
|
||||
CgroupPermissions: "rwm",
|
||||
},
|
||||
"/dev/snd:rw": {
|
||||
},
|
||||
{
|
||||
devices: []string{"/dev/snd:rw"},
|
||||
deviceMapping: &container.DeviceMapping{
|
||||
PathOnHost: "/dev/snd",
|
||||
PathInContainer: "/dev/snd",
|
||||
CgroupPermissions: "rw",
|
||||
},
|
||||
"/dev/snd:/something": {
|
||||
},
|
||||
{
|
||||
devices: []string{"/dev/snd:/something"},
|
||||
deviceMapping: &container.DeviceMapping{
|
||||
PathOnHost: "/dev/snd",
|
||||
PathInContainer: "/something",
|
||||
CgroupPermissions: "rwm",
|
||||
},
|
||||
"/dev/snd:/something:rw": {
|
||||
},
|
||||
{
|
||||
devices: []string{"/dev/snd:/something:rw"},
|
||||
deviceMapping: &container.DeviceMapping{
|
||||
PathOnHost: "/dev/snd",
|
||||
PathInContainer: "/something",
|
||||
CgroupPermissions: "rw",
|
||||
},
|
||||
},
|
||||
{
|
||||
devices: []string{"vendor.com/class=name"},
|
||||
deviceMapping: nil,
|
||||
deviceRequests: []container.DeviceRequest{
|
||||
{
|
||||
Driver: "cdi",
|
||||
DeviceIDs: []string{"vendor.com/class=name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
devices: []string{"vendor.com/class=name", "/dev/snd:/something:rw"},
|
||||
deviceMapping: &container.DeviceMapping{
|
||||
PathOnHost: "/dev/snd",
|
||||
PathInContainer: "/something",
|
||||
CgroupPermissions: "rw",
|
||||
},
|
||||
deviceRequests: []container.DeviceRequest{
|
||||
{
|
||||
Driver: "cdi",
|
||||
DeviceIDs: []string{"vendor.com/class=name"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for device, deviceMapping := range valids {
|
||||
_, hostconfig, _, err := parseRun([]string{fmt.Sprintf("--device=%v", device), "img", "cmd"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(fmt.Sprintf("%s", tc.devices), func(t *testing.T) {
|
||||
var args []string
|
||||
for _, d := range tc.devices {
|
||||
args = append(args, fmt.Sprintf("--device=%v", d))
|
||||
}
|
||||
if len(hostconfig.Devices) != 1 {
|
||||
t.Fatalf("Expected 1 devices, got %v", hostconfig.Devices)
|
||||
args = append(args, "img", "cmd")
|
||||
|
||||
_, hostconfig, _, err := parseRun(args)
|
||||
|
||||
assert.NilError(t, err)
|
||||
|
||||
if tc.deviceMapping != nil {
|
||||
if assert.Check(t, is.Len(hostconfig.Devices, 1)) {
|
||||
assert.Check(t, is.DeepEqual(*tc.deviceMapping, hostconfig.Devices[0]))
|
||||
}
|
||||
if hostconfig.Devices[0] != deviceMapping {
|
||||
t.Fatalf("Expected %v, got %v", deviceMapping, hostconfig.Devices)
|
||||
} else {
|
||||
assert.Check(t, is.Len(hostconfig.Devices, 0))
|
||||
}
|
||||
|
||||
assert.Check(t, is.DeepEqual(tc.deviceRequests, hostconfig.DeviceRequests))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue