@@ -19,6 +19,7 @@ package nvcdi
19
19
import (
20
20
"fmt"
21
21
"strconv"
22
+ "strings"
22
23
23
24
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
24
25
"github.com/NVIDIA/go-nvlib/pkg/nvml"
@@ -79,7 +80,6 @@ func (l *nvmllib) GetCommonEdits() (*cdi.ContainerEdits, error) {
79
80
// GetDeviceSpecsByID returns the CDI device specs for the GPU(s) represented by
80
81
// the provided identifiers, where an identifier is an index or UUID of a valid
81
82
// GPU device.
82
- // TODO: support identifiers that correspond to MIG devices
83
83
func (l * nvmllib ) GetDeviceSpecsByID (identifiers ... string ) ([]specs.Device , error ) {
84
84
for _ , id := range identifiers {
85
85
if id == "all" {
@@ -104,11 +104,7 @@ func (l *nvmllib) GetDeviceSpecsByID(identifiers ...string) ([]specs.Device, err
104
104
}
105
105
106
106
for i , nvmlDevice := range nvmlDevices {
107
- nvlibDevice , err := l .devicelib .NewDevice (nvmlDevice )
108
- if err != nil {
109
- return nil , fmt .Errorf ("failed to construct device: %w" , err )
110
- }
111
- deviceEdits , err := l .GetGPUDeviceEdits (nvlibDevice )
107
+ deviceEdits , err := l .getEditsForDevice (nvmlDevice )
112
108
if err != nil {
113
109
return nil , fmt .Errorf ("failed to get CDI device edits for identifier %q: %w" , identifiers [i ], err )
114
110
}
@@ -151,12 +147,64 @@ func (l *nvmllib) getNVMLDeviceByID(id string) (nvml.Device, error) {
151
147
}
152
148
153
149
if devID .isMigIndex () {
154
- return nil , fmt .Errorf ("MIG index is not supported" )
150
+ var gpuIdx , migIdx int
151
+ var parent nvml.Device
152
+ split := strings .SplitN (id , ":" , 2 )
153
+ if gpuIdx , err = strconv .Atoi (split [0 ]); err != nil {
154
+ return nil , fmt .Errorf ("failed to convert device index to an int: %w" , err )
155
+ }
156
+ if migIdx , err = strconv .Atoi (split [1 ]); err != nil {
157
+ return nil , fmt .Errorf ("failed to convert device index to an int: %w" , err )
158
+ }
159
+ if parent , err = l .nvmllib .DeviceGetHandleByIndex (gpuIdx ); err != nvml .SUCCESS {
160
+ return nil , fmt .Errorf ("failed to get parent device handle: %w" , err )
161
+ }
162
+ return parent .GetMigDeviceHandleByIndex (migIdx )
155
163
}
156
164
157
165
return nil , fmt .Errorf ("identifier is not a valid UUID or index: %q" , id )
158
166
}
159
167
168
+ func (l * nvmllib ) getEditsForDevice (nvmlDevice nvml.Device ) (* cdi.ContainerEdits , error ) {
169
+ mig , err := nvmlDevice .IsMigDeviceHandle ()
170
+ if err != nvml .SUCCESS {
171
+ return nil , fmt .Errorf ("failed to determine if device handle is a MIG device: %w" , err )
172
+ }
173
+ if mig {
174
+ return l .getEditsForMIGDevice (nvmlDevice )
175
+ }
176
+ return l .getEditsForGPUDevice (nvmlDevice )
177
+ }
178
+
179
+ func (l * nvmllib ) getEditsForGPUDevice (nvmlDevice nvml.Device ) (* cdi.ContainerEdits , error ) {
180
+ nvlibDevice , err := l .devicelib .NewDevice (nvmlDevice )
181
+ if err != nil {
182
+ return nil , fmt .Errorf ("failed to construct device: %w" , err )
183
+ }
184
+ deviceEdits , err := l .GetGPUDeviceEdits (nvlibDevice )
185
+ if err != nil {
186
+ return nil , fmt .Errorf ("failed to get GPU device edits: %w" , err )
187
+ }
188
+
189
+ return deviceEdits , nil
190
+ }
191
+
192
+ func (l * nvmllib ) getEditsForMIGDevice (nvmlDevice nvml.Device ) (* cdi.ContainerEdits , error ) {
193
+ nvmlParentDevice , ret := nvmlDevice .GetDeviceHandleFromMigDeviceHandle ()
194
+ if ret != nvml .SUCCESS {
195
+ return nil , fmt .Errorf ("failed to get parent device handle: %w" , ret )
196
+ }
197
+ nvlibMigDevice , err := l .devicelib .NewMigDevice (nvmlDevice )
198
+ if err != nil {
199
+ return nil , fmt .Errorf ("failed to construct device: %w" , err )
200
+ }
201
+ nvlibParentDevice , err := l .devicelib .NewDevice (nvmlParentDevice )
202
+ if err != nil {
203
+ return nil , fmt .Errorf ("failed to construct parent device: %w" , err )
204
+ }
205
+ return l .GetMIGDeviceEdits (nvlibParentDevice , nvlibMigDevice )
206
+ }
207
+
160
208
func (l * nvmllib ) getGPUDeviceSpecs () ([]specs.Device , error ) {
161
209
var deviceSpecs []specs.Device
162
210
err := l .devicelib .VisitDevices (func (i int , d device.Device ) error {
0 commit comments