mirror of https://github.com/docker/cli.git
Support CDI devices in --device flag
Signed-off-by: Evan Lezar <elezar@nvidia.com>
This commit is contained in:
parent
dc2eb3bf7c
commit
dad225d1e2
|
@ -13,6 +13,7 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
cdi "github.com/container-orchestrated-devices/container-device-interface/pkg/parser"
|
||||||
"github.com/docker/cli/cli/compose/loader"
|
"github.com/docker/cli/cli/compose/loader"
|
||||||
"github.com/docker/cli/opts"
|
"github.com/docker/cli/opts"
|
||||||
"github.com/docker/docker/api/types/container"
|
"github.com/docker/docker/api/types/container"
|
||||||
|
@ -449,12 +450,17 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
|
||||||
// parsing flags, we haven't yet sent a _ping to the daemon to determine
|
// parsing flags, we haven't yet sent a _ping to the daemon to determine
|
||||||
// what operating system it is.
|
// what operating system it is.
|
||||||
deviceMappings := []container.DeviceMapping{}
|
deviceMappings := []container.DeviceMapping{}
|
||||||
|
var cdiDeviceNames []string
|
||||||
for _, device := range copts.devices.GetAll() {
|
for _, device := range copts.devices.GetAll() {
|
||||||
var (
|
var (
|
||||||
validated string
|
validated string
|
||||||
deviceMapping container.DeviceMapping
|
deviceMapping container.DeviceMapping
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
if cdi.IsQualifiedName(device) {
|
||||||
|
cdiDeviceNames = append(cdiDeviceNames, device)
|
||||||
|
continue
|
||||||
|
}
|
||||||
validated, err = validateDevice(device, serverOS)
|
validated, err = validateDevice(device, serverOS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -559,6 +565,15 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
deviceRequests := copts.gpus.Value()
|
||||||
|
if len(cdiDeviceNames) > 0 {
|
||||||
|
cdiDeviceRequest := container.DeviceRequest{
|
||||||
|
Driver: "cdi",
|
||||||
|
DeviceIDs: cdiDeviceNames,
|
||||||
|
}
|
||||||
|
deviceRequests = append(deviceRequests, cdiDeviceRequest)
|
||||||
|
}
|
||||||
|
|
||||||
resources := container.Resources{
|
resources := container.Resources{
|
||||||
CgroupParent: copts.cgroupParent,
|
CgroupParent: copts.cgroupParent,
|
||||||
Memory: copts.memory.Value(),
|
Memory: copts.memory.Value(),
|
||||||
|
@ -589,7 +604,7 @@ func parse(flags *pflag.FlagSet, copts *containerOptions, serverOS string) (*con
|
||||||
Ulimits: copts.ulimits.GetList(),
|
Ulimits: copts.ulimits.GetList(),
|
||||||
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
|
DeviceCgroupRules: copts.deviceCgroupRules.GetAll(),
|
||||||
Devices: deviceMappings,
|
Devices: deviceMappings,
|
||||||
DeviceRequests: copts.gpus.Value(),
|
DeviceRequests: deviceRequests,
|
||||||
}
|
}
|
||||||
|
|
||||||
config := &container.Config{
|
config := &container.Config{
|
||||||
|
|
|
@ -417,39 +417,91 @@ func TestParseWithExpose(t *testing.T) {
|
||||||
|
|
||||||
func TestParseDevice(t *testing.T) {
|
func TestParseDevice(t *testing.T) {
|
||||||
skip.If(t, runtime.GOOS != "linux") // Windows and macOS validate server-side
|
skip.If(t, runtime.GOOS != "linux") // Windows and macOS validate server-side
|
||||||
valids := map[string]container.DeviceMapping{
|
testCases := []struct {
|
||||||
"/dev/snd": {
|
devices []string
|
||||||
|
deviceMapping *container.DeviceMapping
|
||||||
|
deviceRequests []container.DeviceRequest
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
devices: []string{"/dev/snd"},
|
||||||
|
deviceMapping: &container.DeviceMapping{
|
||||||
PathOnHost: "/dev/snd",
|
PathOnHost: "/dev/snd",
|
||||||
PathInContainer: "/dev/snd",
|
PathInContainer: "/dev/snd",
|
||||||
CgroupPermissions: "rwm",
|
CgroupPermissions: "rwm",
|
||||||
},
|
},
|
||||||
"/dev/snd:rw": {
|
},
|
||||||
|
{
|
||||||
|
devices: []string{"/dev/snd:rw"},
|
||||||
|
deviceMapping: &container.DeviceMapping{
|
||||||
PathOnHost: "/dev/snd",
|
PathOnHost: "/dev/snd",
|
||||||
PathInContainer: "/dev/snd",
|
PathInContainer: "/dev/snd",
|
||||||
CgroupPermissions: "rw",
|
CgroupPermissions: "rw",
|
||||||
},
|
},
|
||||||
"/dev/snd:/something": {
|
},
|
||||||
|
{
|
||||||
|
devices: []string{"/dev/snd:/something"},
|
||||||
|
deviceMapping: &container.DeviceMapping{
|
||||||
PathOnHost: "/dev/snd",
|
PathOnHost: "/dev/snd",
|
||||||
PathInContainer: "/something",
|
PathInContainer: "/something",
|
||||||
CgroupPermissions: "rwm",
|
CgroupPermissions: "rwm",
|
||||||
},
|
},
|
||||||
"/dev/snd:/something:rw": {
|
},
|
||||||
|
{
|
||||||
|
devices: []string{"/dev/snd:/something:rw"},
|
||||||
|
deviceMapping: &container.DeviceMapping{
|
||||||
PathOnHost: "/dev/snd",
|
PathOnHost: "/dev/snd",
|
||||||
PathInContainer: "/something",
|
PathInContainer: "/something",
|
||||||
CgroupPermissions: "rw",
|
CgroupPermissions: "rw",
|
||||||
},
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
devices: []string{"vendor.com/class=name"},
|
||||||
|
deviceMapping: nil,
|
||||||
|
deviceRequests: []container.DeviceRequest{
|
||||||
|
{
|
||||||
|
Driver: "cdi",
|
||||||
|
DeviceIDs: []string{"vendor.com/class=name"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
devices: []string{"vendor.com/class=name", "/dev/snd:/something:rw"},
|
||||||
|
deviceMapping: &container.DeviceMapping{
|
||||||
|
PathOnHost: "/dev/snd",
|
||||||
|
PathInContainer: "/something",
|
||||||
|
CgroupPermissions: "rw",
|
||||||
|
},
|
||||||
|
deviceRequests: []container.DeviceRequest{
|
||||||
|
{
|
||||||
|
Driver: "cdi",
|
||||||
|
DeviceIDs: []string{"vendor.com/class=name"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for device, deviceMapping := range valids {
|
|
||||||
_, hostconfig, _, err := parseRun([]string{fmt.Sprintf("--device=%v", device), "img", "cmd"})
|
for _, tc := range testCases {
|
||||||
if err != nil {
|
t.Run(fmt.Sprintf("%s", tc.devices), func(t *testing.T) {
|
||||||
t.Fatal(err)
|
var args []string
|
||||||
|
for _, d := range tc.devices {
|
||||||
|
args = append(args, fmt.Sprintf("--device=%v", d))
|
||||||
}
|
}
|
||||||
if len(hostconfig.Devices) != 1 {
|
args = append(args, "img", "cmd")
|
||||||
t.Fatalf("Expected 1 devices, got %v", hostconfig.Devices)
|
|
||||||
|
_, hostconfig, _, err := parseRun(args)
|
||||||
|
|
||||||
|
assert.NilError(t, err)
|
||||||
|
|
||||||
|
if tc.deviceMapping != nil {
|
||||||
|
if assert.Check(t, is.Len(hostconfig.Devices, 1)) {
|
||||||
|
assert.Check(t, is.DeepEqual(*tc.deviceMapping, hostconfig.Devices[0]))
|
||||||
}
|
}
|
||||||
if hostconfig.Devices[0] != deviceMapping {
|
} else {
|
||||||
t.Fatalf("Expected %v, got %v", deviceMapping, hostconfig.Devices)
|
assert.Check(t, is.Len(hostconfig.Devices, 0))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert.Check(t, is.DeepEqual(tc.deviceRequests, hostconfig.DeviceRequests))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue