@@ -36,16 +36,14 @@ def get_cuda_version():
36
36
"""Get CUDA version"""
37
37
if sys .platform == 'win32' :
38
38
raise NotImplementedError ("Implement this!" )
39
- elif sys .platform == 'linux' :
39
+ elif sys .platform == 'linux' or sys . platform == 'darwin' :
40
40
path = '/usr/local/cuda/version.txt'
41
41
if os .path .isfile (path ):
42
42
with open (path , 'r' ) as f :
43
43
data = f .read ().replace ('\n ' ,'' )
44
44
return data
45
45
else :
46
46
return "No CUDA in this machine"
47
- elif sys .platform == 'darwin' :
48
- raise NotImplementedError ("Find a Mac with GPU and implement this!" )
49
47
else :
50
48
raise ValueError ("Not in Windows, Linux or Mac" )
51
49
@@ -78,7 +76,27 @@ def get_cudnn_version():
78
76
else :
79
77
return "No CUDNN in this machine"
80
78
elif sys .platform == 'darwin' :
81
- raise NotImplementedError ("Find a Mac with GPU and implement this!" )
79
+ candidates = ['/usr/local/cuda/include/cudnn.h' ,
80
+ '/usr/include/cudnn.h' ]
81
+ for c in candidates :
82
+ file = glob .glob (c )
83
+ if file : break
84
+ if file :
85
+ with open (file [0 ], 'r' ) as f :
86
+ version = ''
87
+ for line in f :
88
+ if "#define CUDNN_MAJOR" in line :
89
+ version = line .split ()[- 1 ]
90
+ if "#define CUDNN_MINOR" in line :
91
+ version += '.' + line .split ()[- 1 ]
92
+ if "#define CUDNN_PATCHLEVEL" in line :
93
+ version += '.' + line .split ()[- 1 ]
94
+ if version :
95
+ return version
96
+ else :
97
+ return "Cannot find CUDNN version"
98
+ else :
99
+ return "No CUDNN in this machine"
82
100
else :
83
101
raise ValueError ("Not in Windows, Linux or Mac" )
84
102
@@ -351,4 +369,4 @@ def get_train_valid_test_split(n, train=0.7, valid=0.1, test=0.2, shuffle=False)
351
369
test_size = test / other_split ,
352
370
shuffle = False )
353
371
print ("train:{} valid:{} test:{}" .format (len (train_set ), len (valid_set ), len (test_set )))
354
- return train_set , valid_set , test_set
372
+ return train_set , valid_set , test_set
0 commit comments