Skip to content

Commit f483afb

Browse files
committed
Verify err msg matches
1 parent 7749f05 commit f483afb

File tree

4 files changed

+36
-12
lines changed

4 files changed

+36
-12
lines changed

sdv/multi_table/base.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -696,12 +696,19 @@ def load(cls, filepath):
696696
with open(filepath, 'rb') as f:
697697
try:
698698
synthesizer = cloudpickle.load(f)
699-
except RuntimeError:
700-
raise SamplingError(
701-
'This synthesizer was created on a machine with GPU but the current machine is'
702-
' CPU-only. This feature is currently unsupported. We recommend sampling on '
703-
'the same GPU-enabled machine.'
699+
except RuntimeError as e:
700+
err_msg = (
701+
'Attempting to deserialize object on a CUDA device but '
702+
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
703+
" please use torch.load with map_location=torch.device('cpu') "
704+
'to map your storages to the CPU.'
704705
)
706+
if str(e) == err_msg:
707+
raise SamplingError(
708+
'This synthesizer was created on a machine with GPU but the current machine is'
709+
' CPU-only. This feature is currently unsupported. We recommend sampling on '
710+
'the same GPU-enabled machine.'
711+
)
705712

706713
check_synthesizer_version(synthesizer)
707714
check_sdv_versions_and_warn(synthesizer)

sdv/single_table/base.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -494,12 +494,19 @@ def load(cls, filepath):
494494
with open(filepath, 'rb') as f:
495495
try:
496496
synthesizer = cloudpickle.load(f)
497-
except RuntimeError:
498-
raise SamplingError(
499-
'This synthesizer was created on a machine with GPU but the current machine is'
500-
' CPU-only. This feature is currently unsupported. We recommend sampling on '
501-
'the same GPU-enabled machine.'
497+
except RuntimeError as e:
498+
err_msg = (
499+
'Attempting to deserialize object on a CUDA device but '
500+
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
501+
" please use torch.load with map_location=torch.device('cpu') "
502+
'to map your storages to the CPU.'
502503
)
504+
if str(e) == err_msg:
505+
raise SamplingError(
506+
'This synthesizer was created on a machine with GPU but the current machine is'
507+
' CPU-only. This feature is currently unsupported. We recommend sampling on '
508+
'the same GPU-enabled machine.'
509+
)
503510

504511
check_synthesizer_version(synthesizer)
505512
check_sdv_versions_and_warn(synthesizer)

tests/unit/multi_table/test_base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1609,7 +1609,12 @@ def test_load(self, mock_file, cloudpickle_mock,
16091609
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
16101610
"""Test that the synthesizer's load method errors with the correct message."""
16111611
# Setup
1612-
cloudpickle_mock.load.side_effect = RuntimeError
1612+
cloudpickle_mock.load.side_effect = RuntimeError((
1613+
'Attempting to deserialize object on a CUDA device but '
1614+
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
1615+
" please use torch.load with map_location=torch.device('cpu') "
1616+
'to map your storages to the CPU.'
1617+
))
16131618

16141619
# Run and Assert
16151620
err_msg = re.escape(

tests/unit/single_table/test_base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1919,7 +1919,12 @@ def test_load_custom_constraint_classes(self):
19191919
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
19201920
"""Test that the synthesizer's load method errors with the correct message."""
19211921
# Setup
1922-
cloudpickle_mock.load.side_effect = RuntimeError
1922+
cloudpickle_mock.load.side_effect = RuntimeError((
1923+
'Attempting to deserialize object on a CUDA device but '
1924+
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
1925+
" please use torch.load with map_location=torch.device('cpu') "
1926+
'to map your storages to the CPU.'
1927+
))
19231928

19241929
# Run and Assert
19251930
err_msg = re.escape(

0 commit comments

Comments
 (0)