Skip to content

Commit 946dbe7

Browse files
Merge pull request ilkarman#65 from vivgupt/patch-1
Update to utils.py for Macs having CUDA installed Thanks Vivek!!
2 parents 57841da + 8a5f13a commit 946dbe7

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

notebooks/common/utils.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,14 @@ def get_cuda_version():
3636
"""Get CUDA version"""
3737
if sys.platform == 'win32':
3838
raise NotImplementedError("Implement this!")
39-
elif sys.platform == 'linux':
39+
elif sys.platform == 'linux' or sys.platform == 'darwin':
4040
path = '/usr/local/cuda/version.txt'
4141
if os.path.isfile(path):
4242
with open(path, 'r') as f:
4343
data = f.read().replace('\n','')
4444
return data
4545
else:
4646
return "No CUDA in this machine"
47-
elif sys.platform == 'darwin':
48-
raise NotImplementedError("Find a Mac with GPU and implement this!")
4947
else:
5048
raise ValueError("Not in Windows, Linux or Mac")
5149

@@ -78,7 +76,27 @@ def get_cudnn_version():
7876
else:
7977
return "No CUDNN in this machine"
8078
elif sys.platform == 'darwin':
81-
raise NotImplementedError("Find a Mac with GPU and implement this!")
79+
candidates = ['/usr/local/cuda/include/cudnn.h',
80+
'/usr/include/cudnn.h']
81+
for c in candidates:
82+
file = glob.glob(c)
83+
if file: break
84+
if file:
85+
with open(file[0], 'r') as f:
86+
version = ''
87+
for line in f:
88+
if "#define CUDNN_MAJOR" in line:
89+
version = line.split()[-1]
90+
if "#define CUDNN_MINOR" in line:
91+
version += '.' + line.split()[-1]
92+
if "#define CUDNN_PATCHLEVEL" in line:
93+
version += '.' + line.split()[-1]
94+
if version:
95+
return version
96+
else:
97+
return "Cannot find CUDNN version"
98+
else:
99+
return "No CUDNN in this machine"
82100
else:
83101
raise ValueError("Not in Windows, Linux or Mac")
84102

@@ -351,4 +369,4 @@ def get_train_valid_test_split(n, train=0.7, valid=0.1, test=0.2, shuffle=False)
351369
test_size=test/other_split,
352370
shuffle=False)
353371
print("train:{} valid:{} test:{}".format(len(train_set), len(valid_set), len(test_set)))
354-
return train_set, valid_set, test_set
372+
return train_set, valid_set, test_set

0 commit comments

Comments
 (0)