@@ -36,16 +36,14 @@ def get_cuda_version():
3636 """Get CUDA version"""
3737 if sys .platform == 'win32' :
3838 raise NotImplementedError ("Implement this!" )
39- elif sys .platform == 'linux' :
39+ elif sys .platform == 'linux' or sys . platform == 'darwin' :
4040 path = '/usr/local/cuda/version.txt'
4141 if os .path .isfile (path ):
4242 with open (path , 'r' ) as f :
4343 data = f .read ().replace ('\n ' ,'' )
4444 return data
4545 else :
4646 return "No CUDA in this machine"
47- elif sys .platform == 'darwin' :
48- raise NotImplementedError ("Find a Mac with GPU and implement this!" )
4947 else :
5048 raise ValueError ("Not in Windows, Linux or Mac" )
5149
@@ -78,7 +76,27 @@ def get_cudnn_version():
7876 else :
7977 return "No CUDNN in this machine"
8078 elif sys .platform == 'darwin' :
81- raise NotImplementedError ("Find a Mac with GPU and implement this!" )
79+ candidates = ['/usr/local/cuda/include/cudnn.h' ,
80+ '/usr/include/cudnn.h' ]
81+ for c in candidates :
82+ file = glob .glob (c )
83+ if file : break
84+ if file :
85+ with open (file [0 ], 'r' ) as f :
86+ version = ''
87+ for line in f :
88+ if "#define CUDNN_MAJOR" in line :
89+ version = line .split ()[- 1 ]
90+ if "#define CUDNN_MINOR" in line :
91+ version += '.' + line .split ()[- 1 ]
92+ if "#define CUDNN_PATCHLEVEL" in line :
93+ version += '.' + line .split ()[- 1 ]
94+ if version :
95+ return version
96+ else :
97+ return "Cannot find CUDNN version"
98+ else :
99+ return "No CUDNN in this machine"
82100 else :
83101 raise ValueError ("Not in Windows, Linux or Mac" )
84102
@@ -351,4 +369,4 @@ def get_train_valid_test_split(n, train=0.7, valid=0.1, test=0.2, shuffle=False)
351369 test_size = test / other_split ,
352370 shuffle = False )
353371 print ("train:{} valid:{} test:{}" .format (len (train_set ), len (valid_set ), len (test_set )))
354- return train_set , valid_set , test_set
372+ return train_set , valid_set , test_set
0 commit comments