diff --git a/cli/command/container/opts.go b/cli/command/container/opts.go index a68bbe4833..73f40c8012 100644 --- a/cli/command/container/opts.go +++ b/cli/command/container/opts.go @@ -13,6 +13,7 @@ import ( "strings" "time" + cdi "github.com/container-orchestrated-devices/container-device-interface/pkg/parser" "github.com/docker/cli/cli/compose/loader" "github.com/docker/cli/opts" "github.com/docker/docker/api/types/container" @@ -449,12 +450,17 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con // parsing flags, we haven't yet sent a _ping to the daemon to determine // what operating system it is. deviceMappings := []container.DeviceMapping{} + var cdiDeviceNames []string for _, device := range copts.devices.GetAll() { var ( validated string deviceMapping container.DeviceMapping err error ) + if cdi.IsQualifiedName(device) { + cdiDeviceNames = append(cdiDeviceNames, device) + continue + } validated, err = validateDevice(device, serverOS) if err != nil { return nil, err @@ -559,6 +565,15 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con } } + deviceRequests := copts.gpus.Value() + if len(cdiDeviceNames) > 0 { + cdiDeviceRequest := container.DeviceRequest{ + Driver: "cdi", + DeviceIDs: cdiDeviceNames, + } + deviceRequests = append(deviceRequests, cdiDeviceRequest) + } + resources := container.Resources{ CgroupParent: copts.cgroupParent, Memory: copts.memory.Value(), @@ -589,7 +604,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con Ulimits: copts.ulimits.GetList(), DeviceCgroupRules: copts.deviceCgroupRules.GetAll(), Devices: deviceMappings, - DeviceRequests: copts.gpus.Value(), + DeviceRequests: deviceRequests, } config := &container.Config{ diff --git a/cli/command/container/opts_test.go b/cli/command/container/opts_test.go index 1d2514e3d0..885f693b91 100644 --- a/cli/command/container/opts_test.go +++ b/cli/command/container/opts_test.go @@ -417,39 +417,91 @@ func TestParseWithExpose(t *testing.T) { func TestParseDevice(t *testing.T) { skip.If(t, runtime.GOOS != "linux") // Windows and macOS validate server-side - valids := map[string]container.DeviceMapping{ - "/dev/snd": { - PathOnHost: "/dev/snd", - PathInContainer: "/dev/snd", - CgroupPermissions: "rwm", + testCases := []struct { + devices []string + deviceMapping *container.DeviceMapping + deviceRequests []container.DeviceRequest + }{ + { + devices: []string{"/dev/snd"}, + deviceMapping: &container.DeviceMapping{ + PathOnHost: "/dev/snd", + PathInContainer: "/dev/snd", + CgroupPermissions: "rwm", + }, }, - "/dev/snd:rw": { - PathOnHost: "/dev/snd", - PathInContainer: "/dev/snd", - CgroupPermissions: "rw", + { + devices: []string{"/dev/snd:rw"}, + deviceMapping: &container.DeviceMapping{ + PathOnHost: "/dev/snd", + PathInContainer: "/dev/snd", + CgroupPermissions: "rw", + }, }, - "/dev/snd:/something": { - PathOnHost: "/dev/snd", - PathInContainer: "/something", - CgroupPermissions: "rwm", + { + devices: []string{"/dev/snd:/something"}, + deviceMapping: &container.DeviceMapping{ + PathOnHost: "/dev/snd", + PathInContainer: "/something", + CgroupPermissions: "rwm", + }, }, - "/dev/snd:/something:rw": { - PathOnHost: "/dev/snd", - PathInContainer: "/something", - CgroupPermissions: "rw", + { + devices: []string{"/dev/snd:/something:rw"}, + deviceMapping: &container.DeviceMapping{ + PathOnHost: "/dev/snd", + PathInContainer: "/something", + CgroupPermissions: "rw", + }, + }, + { + devices: []string{"vendor.com/class=name"}, + deviceMapping: nil, + deviceRequests: []container.DeviceRequest{ + { + Driver: "cdi", + DeviceIDs: []string{"vendor.com/class=name"}, + }, + }, + }, + { + devices: []string{"vendor.com/class=name", "/dev/snd:/something:rw"}, + deviceMapping: &container.DeviceMapping{ + PathOnHost: "/dev/snd", + PathInContainer: "/something", + CgroupPermissions: "rw", + }, + deviceRequests: []container.DeviceRequest{ + { + Driver: "cdi", + DeviceIDs: []string{"vendor.com/class=name"}, + }, + }, }, } - for device, deviceMapping := range valids { - _, hostconfig, _, err := parseRun([]string{fmt.Sprintf("--device=%v", device), "img", "cmd"}) - if err != nil { - t.Fatal(err) - } - if len(hostconfig.Devices) != 1 { - t.Fatalf("Expected 1 devices, got %v", hostconfig.Devices) - } - if hostconfig.Devices[0] != deviceMapping { - t.Fatalf("Expected %v, got %v", deviceMapping, hostconfig.Devices) - } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%s", tc.devices), func(t *testing.T) { + var args []string + for _, d := range tc.devices { + args = append(args, fmt.Sprintf("--device=%v", d)) + } + args = append(args, "img", "cmd") + + _, hostconfig, _, err := parseRun(args) + + assert.NilError(t, err) + + if tc.deviceMapping != nil { + if assert.Check(t, is.Len(hostconfig.Devices, 1)) { + assert.Check(t, is.DeepEqual(*tc.deviceMapping, hostconfig.Devices[0])) + } + } else { + assert.Check(t, is.Len(hostconfig.Devices, 0)) + } + + assert.Check(t, is.DeepEqual(tc.deviceRequests, hostconfig.DeviceRequests)) + }) } }