Skip to content

Commit 80a78e6

Browse files
author
Evan Lezar
committed
Merge branch 'device-namer' into 'main'
Refactor device namer See merge request nvidia/container-toolkit/container-toolkit!453
2 parents 32ec104 + 9f46c34 commit 80a78e6

File tree

6 files changed

+192
-15
lines changed

6 files changed

+192
-15
lines changed

pkg/nvcdi/full-gpu-nvml.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ func (l *nvmllib) GetGPUDeviceSpecs(i int, d device.Device) (*specs.Device, erro
3939
return nil, fmt.Errorf("failed to get edits for device: %v", err)
4040
}
4141

42-
name, err := l.deviceNamer.GetDeviceName(i, d)
42+
name, err := l.deviceNamer.GetDeviceName(i, convert{d})
4343
if err != nil {
4444
return nil, fmt.Errorf("failed to get device name: %v", err)
4545
}

pkg/nvcdi/lib-csv.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,13 @@ func (l *csvlib) GetAllDeviceSpecs() ([]specs.Device, error) {
5353
return nil, fmt.Errorf("failed to create container edits for CSV files: %v", err)
5454
}
5555

56+
name, err := l.deviceNamer.GetDeviceName(0, uuidUnsupported{})
57+
if err != nil {
58+
return nil, fmt.Errorf("failed to get device name: %v", err)
59+
}
60+
5661
deviceSpec := specs.Device{
57-
Name: "all",
62+
Name: name,
5863
ContainerEdits: *e.ContainerEdits,
5964
}
6065
return []specs.Device{deviceSpec}, nil

pkg/nvcdi/mig-device-nvml.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func (l *nvmllib) GetMIGDeviceSpecs(i int, d device.Device, j int, mig device.Mi
3636
return nil, fmt.Errorf("failed to get edits for device: %v", err)
3737
}
3838

39-
name, err := l.deviceNamer.GetMigDeviceName(i, d, j, mig)
39+
name, err := l.deviceNamer.GetMigDeviceName(i, convert{d}, j, convert{mig})
4040
if err != nil {
4141
return nil, fmt.Errorf("failed to get device name: %v", err)
4242
}

pkg/nvcdi/namer.go

Lines changed: 45 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,16 +17,21 @@
1717
package nvcdi
1818

1919
import (
20+
"errors"
2021
"fmt"
2122

22-
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvlib/device"
2323
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
2424
)
2525

26+
// UUIDer is an interface for getting UUIDs.
27+
type UUIDer interface {
28+
GetUUID() (string, error)
29+
}
30+
2631
// DeviceNamer is an interface for getting device names
2732
type DeviceNamer interface {
28-
GetDeviceName(int, device.Device) (string, error)
29-
GetMigDeviceName(int, device.Device, int, device.MigDevice) (string, error)
33+
GetDeviceName(int, UUIDer) (string, error)
34+
GetMigDeviceName(int, UUIDer, int, UUIDer) (string, error)
3035
}
3136

3237
// Supported device naming strategies
@@ -61,29 +66,57 @@ func NewDeviceNamer(strategy string) (DeviceNamer, error) {
6166
}
6267

6368
// GetDeviceName returns the name for the specified device based on the naming strategy
64-
func (s deviceNameIndex) GetDeviceName(i int, d device.Device) (string, error) {
69+
func (s deviceNameIndex) GetDeviceName(i int, _ UUIDer) (string, error) {
6570
return fmt.Sprintf("%s%d", s.gpuPrefix, i), nil
6671
}
6772

6873
// GetMigDeviceName returns the name for the specified device based on the naming strategy
69-
func (s deviceNameIndex) GetMigDeviceName(i int, d device.Device, j int, mig device.MigDevice) (string, error) {
74+
func (s deviceNameIndex) GetMigDeviceName(i int, _ UUIDer, j int, _ UUIDer) (string, error) {
7075
return fmt.Sprintf("%s%d:%d", s.migPrefix, i, j), nil
7176
}
7277

7378
// GetDeviceName returns the name for the specified device based on the naming strategy
74-
func (s deviceNameUUID) GetDeviceName(i int, d device.Device) (string, error) {
75-
uuid, ret := d.GetUUID()
76-
if ret != nvml.SUCCESS {
77-
return "", fmt.Errorf("failed to get device UUID: %v", ret)
79+
func (s deviceNameUUID) GetDeviceName(i int, d UUIDer) (string, error) {
80+
uuid, err := d.GetUUID()
81+
if err != nil {
82+
return "", fmt.Errorf("failed to get device UUID: %v", err)
7883
}
7984
return uuid, nil
8085
}
8186

8287
// GetMigDeviceName returns the name for the specified device based on the naming strategy
83-
func (s deviceNameUUID) GetMigDeviceName(i int, d device.Device, j int, mig device.MigDevice) (string, error) {
84-
uuid, ret := mig.GetUUID()
88+
func (s deviceNameUUID) GetMigDeviceName(i int, _ UUIDer, j int, mig UUIDer) (string, error) {
89+
uuid, err := mig.GetUUID()
90+
if err != nil {
91+
return "", fmt.Errorf("failed to get device UUID: %v", err)
92+
}
93+
return uuid, nil
94+
}
95+
96+
//go:generate moq -stub -out namer_nvml_mock.go . nvmlUUIDer
97+
type nvmlUUIDer interface {
98+
GetUUID() (string, nvml.Return)
99+
}
100+
101+
type convert struct {
102+
nvmlUUIDer
103+
}
104+
105+
type uuidUnsupported struct{}
106+
107+
func (m convert) GetUUID() (string, error) {
108+
if m.nvmlUUIDer == nil {
109+
return uuidUnsupported{}.GetUUID()
110+
}
111+
uuid, ret := m.nvmlUUIDer.GetUUID()
85112
if ret != nvml.SUCCESS {
86-
return "", fmt.Errorf("failed to get device UUID: %v", ret)
113+
return "", ret
87114
}
88115
return uuid, nil
89116
}
117+
118+
var errUUIDUnsupported = errors.New("GetUUID is not supported")
119+
120+
func (m uuidUnsupported) GetUUID() (string, error) {
121+
return "", errUUIDUnsupported
122+
}

pkg/nvcdi/namer_nvml_mock.go

Lines changed: 72 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/nvcdi/namer_test.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
/**
2+
# Copyright (c) NVIDIA CORPORATION. All rights reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
**/
16+
17+
package nvcdi
18+
19+
import (
20+
"testing"
21+
22+
"github.com/stretchr/testify/require"
23+
"gitlab.com/nvidia/cloud-native/go-nvlib/pkg/nvml"
24+
)
25+
26+
func TestConvert(t *testing.T) {
27+
testCases := []struct {
28+
description string
29+
nvml nvmlUUIDer
30+
expectedError error
31+
expecteUUID string
32+
}{
33+
{
34+
description: "empty UUIDer returns error",
35+
expectedError: errUUIDUnsupported,
36+
expecteUUID: "",
37+
},
38+
{
39+
description: "nvmlUUIDer returns UUID",
40+
nvml: &nvmlUUIDerMock{
41+
GetUUIDFunc: func() (string, nvml.Return) {
42+
return "SOME_UUID", nvml.SUCCESS
43+
},
44+
},
45+
expectedError: nil,
46+
expecteUUID: "SOME_UUID",
47+
},
48+
{
49+
description: "nvmlUUIDer returns error",
50+
nvml: &nvmlUUIDerMock{
51+
GetUUIDFunc: func() (string, nvml.Return) {
52+
return "SOME_UUID", nvml.ERROR_UNKNOWN
53+
},
54+
},
55+
expectedError: nvml.ERROR_UNKNOWN,
56+
expecteUUID: "",
57+
},
58+
}
59+
60+
for _, tc := range testCases {
61+
t.Run(tc.description, func(t *testing.T) {
62+
uuid, err := convert{tc.nvml}.GetUUID()
63+
require.ErrorIs(t, err, tc.expectedError)
64+
require.Equal(t, tc.expecteUUID, uuid)
65+
})
66+
}
67+
}

0 commit comments

Comments
 (0)