@@ -20,33 +20,38 @@ import (
20
20
"log"
21
21
"strings"
22
22
23
- "github.com/NVIDIA/gpu-monitoring-tools/bindings/go /nvml"
23
+ "github.com/NVIDIA/go-nvml/pkg /nvml"
24
24
25
25
"golang.org/x/net/context"
26
26
pluginapi "k8s.io/kubernetes/pkg/kubelet/apis/deviceplugin/v1beta1"
27
27
)
28
28
29
- func check (err error ) {
30
- if err != nil {
31
- log .Panicln ("Fatal:" , err )
29
+ func check (ret nvml.Return ) bool {
30
+ if ret != nvml .SUCCESS {
31
+ log .Printf ("Error: %s" , nvml .ErrorString (ret ))
32
+ return false
33
+ }
34
+ return true
35
+ }
36
+
37
+ func checkAndPanic (ret nvml.Return ) {
38
+ if ret != nvml .SUCCESS {
39
+ log .Panicf ("Fatal: %s" , nvml .ErrorString (ret ))
32
40
}
33
41
}
34
42
35
43
// Instead of returning physical GPU devices, device plugin returns vGPU devices here.
36
44
// Total number of vGPU depends on the vGPU count user specify.
37
45
func getVGPUDevices (vGPUCount int ) []* pluginapi.Device {
38
- n , err := nvml .GetDeviceCount ()
39
- check (err )
40
-
41
46
var devs []* pluginapi.Device
42
- for i := uint ( 0 ) ; i < n ; i ++ {
43
- d , err := nvml .NewDevice (i )
44
- check ( err )
47
+ for i := 0 ; i < getDeviceCount () ; i ++ {
48
+ d , ret := nvml .DeviceGetHandleByIndex (i )
49
+ checkAndPanic ( ret )
45
50
46
- log .Printf ("Device Memory: %d, vGPU Count: %d" , uint ( * d . Memory ), vGPUCount )
51
+ log .Printf ("Device Memory: %d, vGPU Count: %d" , getDeviceMemory ( d ), vGPUCount )
47
52
48
- for j := uint ( 0 ) ; j < uint ( vGPUCount ) ; j ++ {
49
- vGPUDeviceID := getVGPUID (d . UUID , j )
53
+ for j := 0 ; j < vGPUCount ; j ++ {
54
+ vGPUDeviceID := getVGPUID (getDeviceUUID ( d ) , j )
50
55
dev := pluginapi.Device {
51
56
ID : vGPUDeviceID ,
52
57
Health : pluginapi .Healthy ,
@@ -70,27 +75,35 @@ func getVGPUDevices(vGPUCount int) []*pluginapi.Device {
70
75
return devs
71
76
}
72
77
73
- func getDeviceCount () uint {
74
- n , err := nvml .GetDeviceCount ()
75
- check ( err )
78
+ func getDeviceCount () int {
79
+ n , ret := nvml .DeviceGetCount ()
80
+ checkAndPanic ( ret )
76
81
return n
77
82
}
78
83
79
- func getPhysicalGPUDevices () []string {
80
- n , err := nvml .GetDeviceCount ()
81
- check (err )
84
+ func getDeviceUUID (device nvml.Device ) string {
85
+ uuid , ret := device .GetUUID ()
86
+ checkAndPanic (ret )
87
+ return uuid
88
+ }
82
89
90
+ func getDeviceMemory (device nvml.Device ) uint64 {
91
+ mem , ret := device .GetMemoryInfo ()
92
+ checkAndPanic (ret )
93
+ return mem .Total
94
+ }
95
+
96
+ func getPhysicalGPUDevices () []string {
83
97
var devs []string
84
- for i := uint ( 0 ) ; i < n ; i ++ {
85
- d , err := nvml .NewDevice (i )
86
- check ( err )
87
- devs = append (devs , d . UUID )
98
+ for i := 0 ; i < getDeviceCount () ; i ++ {
99
+ d , ret := nvml .DeviceGetHandleByIndex (i )
100
+ checkAndPanic ( ret )
101
+ devs = append (devs , getDeviceUUID ( d ) )
88
102
}
89
-
90
103
return devs
91
104
}
92
105
93
- func getVGPUID (deviceID string , vGPUIndex uint ) string {
106
+ func getVGPUID (deviceID string , vGPUIndex int ) string {
94
107
return fmt .Sprintf ("%s-%d" , deviceID , vGPUIndex )
95
108
}
96
109
@@ -118,11 +131,12 @@ func physicialDeviceExists(devs []string, id string) bool {
118
131
}
119
132
120
133
func watchXIDs (ctx context.Context , devs []* pluginapi.Device , xids chan <- * pluginapi.Device ) {
121
- eventSet := nvml .NewEventSet ()
122
- defer nvml .DeleteEventSet (eventSet )
134
+ eventSet , ret := nvml .EventSetCreate ()
135
+ checkAndPanic (ret )
136
+ defer nvml .EventSetFree (eventSet )
123
137
var physicalDeviceIDs []string
124
138
125
- // We don't have to loop all virtual GPUs here. Only need to check physical CPUs .
139
+ // We don't have to loop all virtual GPUs here. Only need to check physical GPUs .
126
140
for _ , d := range devs {
127
141
physicalDeviceID := getPhysicalDeviceID (d .ID )
128
142
if physicialDeviceExists (physicalDeviceIDs , physicalDeviceID ) {
@@ -131,17 +145,16 @@ func watchXIDs(ctx context.Context, devs []*pluginapi.Device, xids chan<- *plugi
131
145
physicalDeviceIDs = append (physicalDeviceIDs , physicalDeviceID )
132
146
133
147
log .Printf ("virtual id %s physical id %s" , d .ID , physicalDeviceID )
134
- err := nvml .RegisterEventForDevice (eventSet , nvml .XidCriticalError , physicalDeviceID )
135
- if err != nil && strings .HasSuffix (err .Error (), "Not Supported" ) {
136
- log .Printf ("Warning: %s is too old to support healthchecking: %s. Marking it unhealthy." , physicalDeviceID , err )
137
148
149
+ device , ret := nvml .DeviceGetHandleByUUID (physicalDeviceID )
150
+ checkAndPanic (ret )
151
+ ret = nvml .DeviceRegisterEvents (device , nvml .EventTypeXidCriticalError , eventSet )
152
+ if ret == nvml .ERROR_NOT_SUPPORTED {
153
+ log .Printf ("Warning: %s is too old to support healthchecking: %s. Marking it unhealthy." , physicalDeviceID , nvml .ErrorString (ret ))
138
154
xids <- d
139
155
continue
140
156
}
141
-
142
- if err != nil {
143
- log .Panicln ("Fatal:" , err )
144
- }
157
+ checkAndPanic (ret )
145
158
}
146
159
147
160
for {
@@ -151,30 +164,33 @@ func watchXIDs(ctx context.Context, devs []*pluginapi.Device, xids chan<- *plugi
151
164
default :
152
165
}
153
166
154
- e , err := nvml .WaitForEvent (eventSet , 5000 )
155
- if err != nil && e .Etype != nvml .XidCriticalError {
167
+ e , ret := nvml .EventSetWait (eventSet , 5000 )
168
+ checkAndPanic (ret )
169
+ if e .EventType != nvml .EventTypeXidCriticalError {
156
170
continue
157
171
}
158
172
159
173
// FIXME: formalize the full list and document it.
160
174
// http://docs.nvidia.com/deploy/xid-errors/index.html#topic_4
161
175
// Application errors: the GPU should still be healthy
162
- if e .Edata == 31 || e .Edata == 43 || e .Edata == 45 {
176
+ if e .EventData == 31 || e .EventData == 43 || e .EventData == 45 {
163
177
continue
164
178
}
165
179
166
- if e .UUID == nil || len (* e .UUID ) == 0 {
180
+ uuid , ret := e .Device .GetUUID ()
181
+ checkAndPanic (ret )
182
+ if len (uuid ) == 0 {
167
183
// All devices are unhealthy
168
184
for _ , d := range devs {
169
- log .Printf ("XidCriticalError: Xid=%d, All devices will go unhealthy." , e .Edata )
185
+ log .Printf ("XidCriticalError: Xid=%d, All devices will go unhealthy." , e .EventData )
170
186
xids <- d
171
187
}
172
188
continue
173
189
}
174
190
175
191
for _ , d := range devs {
176
- if d .ID == * e . UUID {
177
- log .Printf ("XidCriticalError: Xid=%d on GPU=%s, the device will go unhealthy." , e .Edata , d .ID )
192
+ if d .ID == uuid {
193
+ log .Printf ("XidCriticalError: Xid=%d on GPU=%s, the device will go unhealthy." , e .EventData , d .ID )
178
194
xids <- d
179
195
}
180
196
}
0 commit comments