@@ -26,6 +26,58 @@ def get_gpu_name():
2626 print (e )
2727
2828
29+ def get_cuda_version ():
30+ """Get CUDA version"""
31+ if sys .platform == 'win32' :
32+ raise NotImplementedError ("Implement this!" )
33+ elif sys .platform == 'linux' :
34+ path = '/usr/local/cuda/version.txt'
35+ if os .path .isfile (path ):
36+ with open (path , 'r' ) as f :
37+ data = f .read ().replace ('\n ' ,'' )
38+ return data
39+ else :
40+ return "No CUDA in this machine"
41+ elif sys .platform == 'darwin' :
42+ raise NotImplementedError ("Find a Mac with GPU and implement this!" )
43+ else :
44+ raise ValueError ("Not in Windows, Linux or Mac" )
45+
46+
47+ def get_cudnn_version ():
48+ """Get CUDNN version"""
49+ if sys .platform == 'win32' :
50+ raise NotImplementedError ("Implement this!" )
51+ elif sys .platform == 'linux' :
52+ candidates = ['/usr/include/x86_64-linux-gnu/cudnn_v[0-99].h' ,
53+ '/usr/local/cuda/include/cudnn.h' ,
54+ '/usr/include/cudnn.h' ]
55+ for c in candidates :
56+ file = glob .glob (c )
57+ if file : break
58+ if file :
59+ with open (file [0 ], 'r' ) as f :
60+ version = ''
61+ for line in f :
62+ if "#define CUDNN_MAJOR" in line :
63+ version = line .split ()[- 1 ]
64+ if "#define CUDNN_MINOR" in line :
65+ version += '.' + line .split ()[- 1 ]
66+ if "#define CUDNN_PATCHLEVEL" in line :
67+ version += '.' + line .split ()[- 1 ]
68+ if version :
69+ return version
70+ else :
71+ return "Cannot find CUDNN version"
72+ else :
73+ return "No CUDNN in this machine"
74+ elif sys .platform == 'darwin' :
75+ raise NotImplementedError ("Find a Mac with GPU and implement this!" )
76+ else :
77+ raise ValueError ("Not in Windows, Linux or Mac" )
78+
79+
80+
2981def read_batch (src ):
3082 '''Unpack the pickle files
3183 '''
0 commit comments