Support CDI devices in --device flag

Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
Evan Lezar 2023-03-10 08:29:47 +02:00
parent dc2eb3bf7c
commit dad225d1e2
2 changed files with 96 additions and 29 deletions

View File

@ -13,6 +13,7 @@ import (
"strings"
"time"
cdi "github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
"github.com/docker/cli/cli/compose/loader"
"github.com/docker/cli/opts"
"github.com/docker/docker/api/types/container"
@ -449,12 +450,17 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
// parsing flags, we haven't yet sent a _ping to the daemon to determine
// what operating system it is.
deviceMappings := []container.DeviceMapping{}
var cdiDeviceNames []string
for _, device := range copts.devices.GetAll() {
var (
validated string
deviceMapping container.DeviceMapping
err error
)
if cdi.IsQualifiedName(device) {
cdiDeviceNames = append(cdiDeviceNames, device)
continue
}
validated, err = validateDevice(device, serverOS)
if err != nil {
return nil, err
@ -559,6 +565,15 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
}
}
deviceRequests := copts.gpus.Value()
if len(cdiDeviceNames) > 0 {
cdiDeviceRequest := container.DeviceRequest{
Driver: "cdi",
DeviceIDs: cdiDeviceNames,
}
deviceRequests = append(deviceRequests, cdiDeviceRequest)
}
resources := container.Resources{
CgroupParent: copts.cgroupParent,
Memory: copts.memory.Value(),
@ -589,7 +604,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
Ulimits: copts.ulimits.GetList(),
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
Devices: deviceMappings,
DeviceRequests: copts.gpus.Value(),
DeviceRequests: deviceRequests,
}
config := &container.Config{

View File

@ -417,39 +417,91 @@ func TestParseWithExpose(t *testing.T) {
func TestParseDevice(t *testing.T) {
skip.If(t, runtime.GOOS != "linux") // Windows and macOS validate server-side
valids := map[string]container.DeviceMapping{
"/dev/snd": {
PathOnHost: "/dev/snd",
PathInContainer: "/dev/snd",
CgroupPermissions: "rwm",
testCases := []struct {
devices []string
deviceMapping *container.DeviceMapping
deviceRequests []container.DeviceRequest
}{
{
devices: []string{"/dev/snd"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/dev/snd",
CgroupPermissions: "rwm",
},
},
"/dev/snd:rw": {
PathOnHost: "/dev/snd",
PathInContainer: "/dev/snd",
CgroupPermissions: "rw",
{
devices: []string{"/dev/snd:rw"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/dev/snd",
CgroupPermissions: "rw",
},
},
"/dev/snd:/something": {
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rwm",
{
devices: []string{"/dev/snd:/something"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rwm",
},
},
"/dev/snd:/something:rw": {
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rw",
{
devices: []string{"/dev/snd:/something:rw"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rw",
},
},
{
devices: []string{"vendor.com/class=name"},
deviceMapping: nil,
deviceRequests: []container.DeviceRequest{
{
Driver: "cdi",
DeviceIDs: []string{"vendor.com/class=name"},
},
},
},
{
devices: []string{"vendor.com/class=name", "/dev/snd:/something:rw"},
deviceMapping: &container.DeviceMapping{
PathOnHost: "/dev/snd",
PathInContainer: "/something",
CgroupPermissions: "rw",
},
deviceRequests: []container.DeviceRequest{
{
Driver: "cdi",
DeviceIDs: []string{"vendor.com/class=name"},
},
},
},
}
for device, deviceMapping := range valids {
_, hostconfig, _, err := parseRun([]string{fmt.Sprintf("--device=%v", device), "img", "cmd"})
if err != nil {
t.Fatal(err)
}
if len(hostconfig.Devices) != 1 {
t.Fatalf("Expected 1 devices, got %v", hostconfig.Devices)
}
if hostconfig.Devices[0] != deviceMapping {
t.Fatalf("Expected %v, got %v", deviceMapping, hostconfig.Devices)
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("%s", tc.devices), func(t *testing.T) {
var args []string
for _, d := range tc.devices {
args = append(args, fmt.Sprintf("--device=%v", d))
}
args = append(args, "img", "cmd")
_, hostconfig, _, err := parseRun(args)
assert.NilError(t, err)
if tc.deviceMapping != nil {
if assert.Check(t, is.Len(hostconfig.Devices, 1)) {
assert.Check(t, is.DeepEqual(*tc.deviceMapping, hostconfig.Devices[0]))
}
} else {
assert.Check(t, is.Len(hostconfig.Devices, 0))
}
assert.Check(t, is.DeepEqual(tc.deviceRequests, hostconfig.DeviceRequests))
})
}
}