Skip to content

Commit 3158146

Browse files
committed
Extend the 'runtime.nvidia.com/gpu' CDI device kind to support MIG devices specified by index or UUID
Signed-off-by: Christopher Desiniotis <[email protected]>
1 parent def7d09 commit 3158146

File tree

1 file changed

+55
-7
lines changed

1 file changed

+55
-7
lines changed

pkg/nvcdi/lib-nvml.go

Lines changed: 55 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package nvcdi
1919
import (
2020
"fmt"
2121
"strconv"
22+
"strings"
2223

2324
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
2425
"github.com/NVIDIA/go-nvlib/pkg/nvml"
@@ -79,7 +80,6 @@ func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
7980
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
8081
// the provided identifiers, where an identifier is an index or UUID of a valid
8182
// GPU device.
82-
// TODO: support identifiers that correspond to MIG devices
8383
func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, error) {
8484
for _, id := range identifiers {
8585
if id == "all" {
@@ -104,11 +104,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err
104104
}
105105

106106
for i, nvmlDevice := range nvmlDevices {
107-
nvlibDevice, err := l.devicelib.NewDevice(nvmlDevice)
108-
if err != nil {
109-
return nil, fmt.Errorf("failed to construct device: %w", err)
110-
}
111-
deviceEdits, err := l.GetGPUDeviceEdits(nvlibDevice)
107+
deviceEdits, err := l.getEditsForDevice(nvmlDevice)
112108
if err != nil {
113109
return nil, fmt.Errorf("failed to get CDI device edits for identifier %q: %w", identifiers[i], err)
114110
}
@@ -151,12 +147,64 @@ func (l *nvmllib) getNVMLDeviceByID(id string) (nvml.Device, error) {
151147
}
152148

153149
if devID.isMigIndex() {
154-
return nil, fmt.Errorf("MIG index is not supported")
150+
var gpuIdx, migIdx int
151+
var parent nvml.Device
152+
split := strings.SplitN(id, ":", 2)
153+
if gpuIdx, err = strconv.Atoi(split[0]); err != nil {
154+
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
155+
}
156+
if migIdx, err = strconv.Atoi(split[1]); err != nil {
157+
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
158+
}
159+
if parent, err = l.nvmllib.DeviceGetHandleByIndex(gpuIdx); err != nvml.SUCCESS {
160+
return nil, fmt.Errorf("failed to get parent device handle: %w", err)
161+
}
162+
return parent.GetMigDeviceHandleByIndex(migIdx)
155163
}
156164

157165
return nil, fmt.Errorf("identifier is not a valid UUID or index: %q", id)
158166
}
159167

168+
func (l *nvmllib) getEditsForDevice(nvmlDevice nvml.Device) (*cdi.ContainerEdits, error) {
169+
mig, err := nvmlDevice.IsMigDeviceHandle()
170+
if err != nvml.SUCCESS {
171+
return nil, fmt.Errorf("failed to determine if device handle is a MIG device: %w", err)
172+
}
173+
if mig {
174+
return l.getEditsForMIGDevice(nvmlDevice)
175+
}
176+
return l.getEditsForGPUDevice(nvmlDevice)
177+
}
178+
179+
func (l *nvmllib) getEditsForGPUDevice(nvmlDevice nvml.Device) (*cdi.ContainerEdits, error) {
180+
nvlibDevice, err := l.devicelib.NewDevice(nvmlDevice)
181+
if err != nil {
182+
return nil, fmt.Errorf("failed to construct device: %w", err)
183+
}
184+
deviceEdits, err := l.GetGPUDeviceEdits(nvlibDevice)
185+
if err != nil {
186+
return nil, fmt.Errorf("failed to get GPU device edits: %w", err)
187+
}
188+
189+
return deviceEdits, nil
190+
}
191+
192+
func (l *nvmllib) getEditsForMIGDevice(nvmlDevice nvml.Device) (*cdi.ContainerEdits, error) {
193+
nvmlParentDevice, ret := nvmlDevice.GetDeviceHandleFromMigDeviceHandle()
194+
if ret != nvml.SUCCESS {
195+
return nil, fmt.Errorf("failed to get parent device handle: %w", ret)
196+
}
197+
nvlibMigDevice, err := l.devicelib.NewMigDevice(nvmlDevice)
198+
if err != nil {
199+
return nil, fmt.Errorf("failed to construct device: %w", err)
200+
}
201+
nvlibParentDevice, err := l.devicelib.NewDevice(nvmlParentDevice)
202+
if err != nil {
203+
return nil, fmt.Errorf("failed to construct parent device: %w", err)
204+
}
205+
return l.GetMIGDeviceEdits(nvlibParentDevice, nvlibMigDevice)
206+
}
207+
160208
func (l *nvmllib) getGPUDeviceSpecs() ([]specs.Device, error) {
161209
var deviceSpecs []specs.Device
162210
err := l.devicelib.VisitDevices(func(i int, d device.Device) error {

0 commit comments

Comments
 (0)