Skip to content

Commit 7749f05

Browse files
committed
Add msg
1 parent c2a1a6a commit 7749f05

File tree

4 files changed

+48
-2
lines changed

4 files changed

+48
-2
lines changed

sdv/multi_table/base.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -694,7 +694,14 @@ def load(cls, filepath):
694694
The loaded synthesizer.
695695
"""
696696
with open(filepath, 'rb') as f:
697-
synthesizer = cloudpickle.load(f)
697+
try:
698+
synthesizer = cloudpickle.load(f)
699+
except RuntimeError:
700+
raise SamplingError(
701+
'This synthesizer was created on a machine with GPU but the current machine is'
702+
' CPU-only. This feature is currently unsupported. We recommend sampling on '
703+
'the same GPU-enabled machine.'
704+
)
698705

699706
check_synthesizer_version(synthesizer)
700707
check_sdv_versions_and_warn(synthesizer)

sdv/single_table/base.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,14 @@ def load(cls, filepath):
492492
The loaded synthesizer.
493493
"""
494494
with open(filepath, 'rb') as f:
495-
synthesizer = cloudpickle.load(f)
495+
try:
496+
synthesizer = cloudpickle.load(f)
497+
except RuntimeError:
498+
raise SamplingError(
499+
'This synthesizer was created on a machine with GPU but the current machine is'
500+
' CPU-only. This feature is currently unsupported. We recommend sampling on '
501+
'the same GPU-enabled machine.'
502+
)
496503

497504
check_synthesizer_version(synthesizer)
498505
check_sdv_versions_and_warn(synthesizer)

tests/unit/multi_table/test_base.py

+16
Original file line numberDiff line numberDiff line change
@@ -1603,3 +1603,19 @@ def test_load(self, mock_file, cloudpickle_mock,
16031603
'SYNTHESIZER CLASS NAME': 'Mock',
16041604
'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
16051605
})
1606+
1607+
@patch('builtins.open')
1608+
@patch('sdv.multi_table.base.cloudpickle')
1609+
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
1610+
"""Test that the synthesizer's load method errors with the correct message."""
1611+
# Setup
1612+
cloudpickle_mock.load.side_effect = RuntimeError
1613+
1614+
# Run and Assert
1615+
err_msg = re.escape(
1616+
'This synthesizer was created on a machine with GPU but the current machine is'
1617+
' CPU-only. This feature is currently unsupported. We recommend sampling on '
1618+
'the same GPU-enabled machine.'
1619+
)
1620+
with pytest.raises(SamplingError, match=err_msg):
1621+
BaseMultiTableSynthesizer.load('synth.pkl')

tests/unit/single_table/test_base.py

+16
Original file line numberDiff line numberDiff line change
@@ -1914,6 +1914,22 @@ def test_load_custom_constraint_classes(self):
19141914
['Custom', 'Constr', 'UpperPlus']
19151915
)
19161916

1917+
@patch('builtins.open')
1918+
@patch('sdv.single_table.base.cloudpickle')
1919+
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
1920+
"""Test that the synthesizer's load method errors with the correct message."""
1921+
# Setup
1922+
cloudpickle_mock.load.side_effect = RuntimeError
1923+
1924+
# Run and Assert
1925+
err_msg = re.escape(
1926+
'This synthesizer was created on a machine with GPU but the current machine is'
1927+
' CPU-only. This feature is currently unsupported. We recommend sampling on '
1928+
'the same GPU-enabled machine.'
1929+
)
1930+
with pytest.raises(SamplingError, match=err_msg):
1931+
BaseSingleTableSynthesizer.load('synth.pkl')
1932+
19171933
def test_add_custom_constraint_class(self):
19181934
"""Test that this method calls the ``DataProcessor``'s method."""
19191935
# Setup

0 commit comments

Comments
 (0)