Support CDI devices in --device flag

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2023-03-10 08:29:47 +02:00
parent dc2eb3bf7c
commit dad225d1e2
2 changed files with 96 additions and 29 deletions

View File

@ -13,6 +13,7 @@ import (
"strings" "strings"
"time" "time"
cdi "github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
"github.com/docker/cli/cli/compose/loader" "github.com/docker/cli/cli/compose/loader"
"github.com/docker/cli/opts" "github.com/docker/cli/opts"
"github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/container"
@ -449,12 +450,17 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
// parsing flags, we haven't yet sent a _ping to the daemon to determine // parsing flags, we haven't yet sent a _ping to the daemon to determine
// what operating system it is. // what operating system it is.
deviceMappings := []container.DeviceMapping{} deviceMappings := []container.DeviceMapping{}
var cdiDeviceNames []string
for _, device := range copts.devices.GetAll() { for _, device := range copts.devices.GetAll() {
var ( var (
validated string validated string
deviceMapping container.DeviceMapping deviceMapping container.DeviceMapping
err error err error
) )
if cdi.IsQualifiedName(device) {
cdiDeviceNames = append(cdiDeviceNames, device)
continue
}
validated, err = validateDevice(device, serverOS) validated, err = validateDevice(device, serverOS)
if err != nil { if err != nil {
return nil, err return nil, err
@ -559,6 +565,15 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
} }
} }
deviceRequests := copts.gpus.Value()
if len(cdiDeviceNames) > 0 {
cdiDeviceRequest := container.DeviceRequest{
Driver: "cdi",
DeviceIDs: cdiDeviceNames,
}
deviceRequests = append(deviceRequests, cdiDeviceRequest)
}
resources := container.Resources{ resources := container.Resources{
CgroupParent: copts.cgroupParent, CgroupParent: copts.cgroupParent,
Memory: copts.memory.Value(), Memory: copts.memory.Value(),
@ -589,7 +604,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
Ulimits: copts.ulimits.GetList(), Ulimits: copts.ulimits.GetList(),
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(), DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
Devices: deviceMappings, Devices: deviceMappings,
DeviceRequests: copts.gpus.Value(), DeviceRequests: deviceRequests,
} }
config := &container.Config{ config := &container.Config{

View File

@ -417,39 +417,91 @@ func TestParseWithExpose(t *testing.T) {
func TestParseDevice(t *testing.T) { func TestParseDevice(t *testing.T) {
skip.If(t, runtime.GOOS != "linux") // Windows and macOS validate server-side skip.If(t, runtime.GOOS != "linux") // Windows and macOS validate server-side
valids := map[string]container.DeviceMapping{ testCases := []struct {
"/dev/snd": { devices []string
PathOnHost: "/dev/snd", deviceMapping *container.DeviceMapping
PathInContainer: "/dev/snd", deviceRequests []container.DeviceRequest
CgroupPermissions: "rwm", }{
{
devices: []string{"/dev/snd"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/dev/snd",
CgroupPermissions: "rwm",
},
}, },
"/dev/snd:rw": { {
PathOnHost: "/dev/snd", devices: []string{"/dev/snd:rw"},
PathInContainer: "/dev/snd", deviceMapping: &container.DeviceMapping{
CgroupPermissions: "rw", PathOnHost: "/dev/snd",
PathInContainer: "/dev/snd",
CgroupPermissions: "rw",
},
}, },
"/dev/snd:/something": { {
PathOnHost: "/dev/snd", devices: []string{"/dev/snd:/something"},
PathInContainer: "/something", deviceMapping: &container.DeviceMapping{
CgroupPermissions: "rwm", PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rwm",
},
}, },
"/dev/snd:/something:rw": { {
PathOnHost: "/dev/snd", devices: []string{"/dev/snd:/something:rw"},
PathInContainer: "/something", deviceMapping: &container.DeviceMapping{
CgroupPermissions: "rw", PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rw",
},
},
{
devices: []string{"vendor.com/class=name"},
deviceMapping: nil,
deviceRequests: []container.DeviceRequest{
{
Driver: "cdi",
DeviceIDs: []string{"vendor.com/class=name"},
},
},
},
{
devices: []string{"vendor.com/class=name", "/dev/snd:/something:rw"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rw",
},
deviceRequests: []container.DeviceRequest{
{
Driver: "cdi",
DeviceIDs: []string{"vendor.com/class=name"},
},
},
}, },
} }
for device, deviceMapping := range valids {
_, hostconfig, _, err := parseRun([]string{fmt.Sprintf("--device=%v", device), "img", "cmd"}) for _, tc := range testCases {
if err != nil { t.Run(fmt.Sprintf("%s", tc.devices), func(t *testing.T) {
t.Fatal(err) var args []string
} for _, d := range tc.devices {
if len(hostconfig.Devices) != 1 { args = append(args, fmt.Sprintf("--device=%v", d))
t.Fatalf("Expected 1 devices, got %v", hostconfig.Devices) }
} args = append(args, "img", "cmd")
if hostconfig.Devices[0] != deviceMapping {
t.Fatalf("Expected %v, got %v", deviceMapping, hostconfig.Devices) _, hostconfig, _, err := parseRun(args)
}
assert.NilError(t, err)
if tc.deviceMapping != nil {
if assert.Check(t, is.Len(hostconfig.Devices, 1)) {
assert.Check(t, is.DeepEqual(*tc.deviceMapping, hostconfig.Devices[0]))
}
} else {
assert.Check(t, is.Len(hostconfig.Devices, 0))
}
assert.Check(t, is.DeepEqual(tc.deviceRequests, hostconfig.DeviceRequests))
})
} }
} }