@@ -20,33 +20,38 @@ import (
2020 "log"
2121 "strings"
2222
23- "github.com/NVIDIA/gpu-monitoring-tools/bindings/go /nvml"
23+ "github.com/NVIDIA/go-nvml/pkg /nvml"
2424
2525 "golang.org/x/net/context"
2626 pluginapi "k8s.io/kubernetes/pkg/kubelet/apis/deviceplugin/v1beta1"
2727)
2828
29- func check (err error ) {
30- if err != nil {
31- log .Panicln ("Fatal:" , err )
29+ func check (ret nvml.Return ) bool {
30+ if ret != nvml .SUCCESS {
31+ log .Printf ("Error: %s" , nvml .ErrorString (ret ))
32+ return false
33+ }
34+ return true
35+ }
36+
37+ func checkAndPanic (ret nvml.Return ) {
38+ if ret != nvml .SUCCESS {
39+ log .Panicf ("Fatal: %s" , nvml .ErrorString (ret ))
3240 }
3341}
3442
3543// Instead of returning physical GPU devices, device plugin returns vGPU devices here.
3644// Total number of vGPU depends on the vGPU count user specify.
3745func getVGPUDevices (vGPUCount int ) []* pluginapi.Device {
38- n , err := nvml .GetDeviceCount ()
39- check (err )
40-
4146 var devs []* pluginapi.Device
42- for i := uint ( 0 ) ; i < n ; i ++ {
43- d , err := nvml .NewDevice (i )
44- check ( err )
47+ for i := 0 ; i < getDeviceCount () ; i ++ {
48+ d , ret := nvml .DeviceGetHandleByIndex (i )
49+ checkAndPanic ( ret )
4550
46- log .Printf ("Device Memory: %d, vGPU Count: %d" , uint ( * d . Memory ), vGPUCount )
51+ log .Printf ("Device Memory: %d, vGPU Count: %d" , getDeviceMemory ( d ), vGPUCount )
4752
48- for j := uint ( 0 ) ; j < uint ( vGPUCount ) ; j ++ {
49- vGPUDeviceID := getVGPUID (d . UUID , j )
53+ for j := 0 ; j < vGPUCount ; j ++ {
54+ vGPUDeviceID := getVGPUID (getDeviceUUID ( d ) , j )
5055 dev := pluginapi.Device {
5156 ID : vGPUDeviceID ,
5257 Health : pluginapi .Healthy ,
@@ -70,27 +75,35 @@ func getVGPUDevices(vGPUCount int) []*pluginapi.Device {
7075 return devs
7176}
7277
73- func getDeviceCount () uint {
74- n , err := nvml .GetDeviceCount ()
75- check ( err )
78+ func getDeviceCount () int {
79+ n , ret := nvml .DeviceGetCount ()
80+ checkAndPanic ( ret )
7681 return n
7782}
7883
79- func getPhysicalGPUDevices () []string {
80- n , err := nvml .GetDeviceCount ()
81- check (err )
84+ func getDeviceUUID (device nvml.Device ) string {
85+ uuid , ret := device .GetUUID ()
86+ checkAndPanic (ret )
87+ return uuid
88+ }
8289
90+ func getDeviceMemory (device nvml.Device ) uint64 {
91+ mem , ret := device .GetMemoryInfo ()
92+ checkAndPanic (ret )
93+ return mem .Total
94+ }
95+
96+ func getPhysicalGPUDevices () []string {
8397 var devs []string
84- for i := uint ( 0 ) ; i < n ; i ++ {
85- d , err := nvml .NewDevice (i )
86- check ( err )
87- devs = append (devs , d . UUID )
98+ for i := 0 ; i < getDeviceCount () ; i ++ {
99+ d , ret := nvml .DeviceGetHandleByIndex (i )
100+ checkAndPanic ( ret )
101+ devs = append (devs , getDeviceUUID ( d ) )
88102 }
89-
90103 return devs
91104}
92105
93- func getVGPUID (deviceID string , vGPUIndex uint ) string {
106+ func getVGPUID (deviceID string , vGPUIndex int ) string {
94107 return fmt .Sprintf ("%s-%d" , deviceID , vGPUIndex )
95108}
96109
@@ -118,11 +131,12 @@ func physicialDeviceExists(devs []string, id string) bool {
118131}
119132
120133func watchXIDs (ctx context.Context , devs []* pluginapi.Device , xids chan <- * pluginapi.Device ) {
121- eventSet := nvml .NewEventSet ()
122- defer nvml .DeleteEventSet (eventSet )
134+ eventSet , ret := nvml .EventSetCreate ()
135+ checkAndPanic (ret )
136+ defer nvml .EventSetFree (eventSet )
123137 var physicalDeviceIDs []string
124138
125- // We don't have to loop all virtual GPUs here. Only need to check physical CPUs .
139+ // We don't have to loop all virtual GPUs here. Only need to check physical GPUs .
126140 for _ , d := range devs {
127141 physicalDeviceID := getPhysicalDeviceID (d .ID )
128142 if physicialDeviceExists (physicalDeviceIDs , physicalDeviceID ) {
@@ -131,17 +145,16 @@ func watchXIDs(ctx context.Context, devs []*pluginapi.Device, xids chan<- *plugi
131145 physicalDeviceIDs = append (physicalDeviceIDs , physicalDeviceID )
132146
133147 log .Printf ("virtual id %s physical id %s" , d .ID , physicalDeviceID )
134- err := nvml .RegisterEventForDevice (eventSet , nvml .XidCriticalError , physicalDeviceID )
135- if err != nil && strings .HasSuffix (err .Error (), "Not Supported" ) {
136- log .Printf ("Warning: %s is too old to support healthchecking: %s. Marking it unhealthy." , physicalDeviceID , err )
137148
149+ device , ret := nvml .DeviceGetHandleByUUID (physicalDeviceID )
150+ checkAndPanic (ret )
151+ ret = nvml .DeviceRegisterEvents (device , nvml .EventTypeXidCriticalError , eventSet )
152+ if ret == nvml .ERROR_NOT_SUPPORTED {
153+ log .Printf ("Warning: %s is too old to support healthchecking: %s. Marking it unhealthy." , physicalDeviceID , nvml .ErrorString (ret ))
138154 xids <- d
139155 continue
140156 }
141-
142- if err != nil {
143- log .Panicln ("Fatal:" , err )
144- }
157+ checkAndPanic (ret )
145158 }
146159
147160 for {
@@ -151,30 +164,33 @@ func watchXIDs(ctx context.Context, devs []*pluginapi.Device, xids chan<- *plugi
151164 default :
152165 }
153166
154- e , err := nvml .WaitForEvent (eventSet , 5000 )
155- if err != nil && e .Etype != nvml .XidCriticalError {
167+ e , ret := nvml .EventSetWait (eventSet , 5000 )
168+ checkAndPanic (ret )
169+ if e .EventType != nvml .EventTypeXidCriticalError {
156170 continue
157171 }
158172
159173 // FIXME: formalize the full list and document it.
160174 // http://docs.nvidia.com/deploy/xid-errors/index.html#topic_4
161175 // Application errors: the GPU should still be healthy
162- if e .Edata == 31 || e .Edata == 43 || e .Edata == 45 {
176+ if e .EventData == 31 || e .EventData == 43 || e .EventData == 45 {
163177 continue
164178 }
165179
166- if e .UUID == nil || len (* e .UUID ) == 0 {
180+ uuid , ret := e .Device .GetUUID ()
181+ checkAndPanic (ret )
182+ if len (uuid ) == 0 {
167183 // All devices are unhealthy
168184 for _ , d := range devs {
169- log .Printf ("XidCriticalError: Xid=%d, All devices will go unhealthy." , e .Edata )
185+ log .Printf ("XidCriticalError: Xid=%d, All devices will go unhealthy." , e .EventData )
170186 xids <- d
171187 }
172188 continue
173189 }
174190
175191 for _ , d := range devs {
176- if d .ID == * e . UUID {
177- log .Printf ("XidCriticalError: Xid=%d on GPU=%s, the device will go unhealthy." , e .Edata , d .ID )
192+ if d .ID == uuid {
193+ log .Printf ("XidCriticalError: Xid=%d on GPU=%s, the device will go unhealthy." , e .EventData , d .ID )
178194 xids <- d
179195 }
180196 }
0 commit comments