Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion sdv/multi_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,7 +694,22 @@ def load(cls, filepath):
The loaded synthesizer.
"""
with open(filepath, 'rb') as f:
synthesizer = cloudpickle.load(f)
try:
synthesizer = cloudpickle.load(f)
except RuntimeError as e:
err_msg = (
'Attempting to deserialize object on a CUDA device but '
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
" please use torch.load with map_location=torch.device('cpu') "
'to map your storages to the CPU.'
)
if str(e) == err_msg:
raise SamplingError(
'This synthesizer was created on a machine with GPU but the current '
'machine is CPU-only. This feature is currently unsupported. We recommend'
' sampling on the same GPU-enabled machine.'
)
raise e

check_synthesizer_version(synthesizer)
check_sdv_versions_and_warn(synthesizer)
Expand Down
17 changes: 16 additions & 1 deletion sdv/single_table/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,22 @@ def load(cls, filepath):
The loaded synthesizer.
"""
with open(filepath, 'rb') as f:
synthesizer = cloudpickle.load(f)
try:
synthesizer = cloudpickle.load(f)
except RuntimeError as e:
err_msg = (
'Attempting to deserialize object on a CUDA device but '
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
" please use torch.load with map_location=torch.device('cpu') "
'to map your storages to the CPU.'
)
if str(e) == err_msg:
raise SamplingError(
'This synthesizer was created on a machine with GPU but the current '
'machine is CPU-only. This feature is currently unsupported. We recommend'
' sampling on the same GPU-enabled machine.'
)
raise e

check_synthesizer_version(synthesizer)
check_sdv_versions_and_warn(synthesizer)
Expand Down
32 changes: 32 additions & 0 deletions tests/unit/multi_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,3 +1603,35 @@ def test_load(self, mock_file, cloudpickle_mock,
'SYNTHESIZER CLASS NAME': 'Mock',
'SYNTHESIZER ID': 'BaseMultiTableSynthesizer_1.0.0_92aff11e9a5649d1a280990d1231a5f5',
})

@patch('builtins.open')
@patch('sdv.multi_table.base.cloudpickle')
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
"""Test that the synthesizer's load method errors with the correct message."""
# Setup
cloudpickle_mock.load.side_effect = RuntimeError((
'Attempting to deserialize object on a CUDA device but '
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
" please use torch.load with map_location=torch.device('cpu') "
'to map your storages to the CPU.'
))

# Run and Assert
err_msg = re.escape(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)
with pytest.raises(SamplingError, match=err_msg):
BaseMultiTableSynthesizer.load('synth.pkl')

@patch('builtins.open')
@patch('sdv.multi_table.base.cloudpickle')
def test_load_runtime_error_no_change(self, cloudpickle_mock, mock_open):
"""Test that the synthesizer's load method errors with the correct message."""
# Setup
cloudpickle_mock.load.side_effect = RuntimeError('Error')

# Run and Assert
with pytest.raises(RuntimeError, match='Error'):
BaseMultiTableSynthesizer.load('synth.pkl')
32 changes: 32 additions & 0 deletions tests/unit/single_table/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1914,6 +1914,38 @@ def test_load_custom_constraint_classes(self):
['Custom', 'Constr', 'UpperPlus']
)

@patch('builtins.open')
@patch('sdv.single_table.base.cloudpickle')
def test_load_runtime_error(self, cloudpickle_mock, mock_open):
"""Test that the synthesizer's load method errors with the correct message."""
# Setup
cloudpickle_mock.load.side_effect = RuntimeError((
'Attempting to deserialize object on a CUDA device but '
'torch.cuda.is_available() is False. If you are running on a CPU-only machine,'
" please use torch.load with map_location=torch.device('cpu') "
'to map your storages to the CPU.'
))

# Run and Assert
err_msg = re.escape(
'This synthesizer was created on a machine with GPU but the current machine is'
' CPU-only. This feature is currently unsupported. We recommend sampling on '
'the same GPU-enabled machine.'
)
with pytest.raises(SamplingError, match=err_msg):
BaseSingleTableSynthesizer.load('synth.pkl')

@patch('builtins.open')
@patch('sdv.single_table.base.cloudpickle')
def test_load_runtime_error_no_change(self, cloudpickle_mock, mock_open):
"""Test that the synthesizer's load method errors with the correct message."""
# Setup
cloudpickle_mock.load.side_effect = RuntimeError('Error')

# Run and Assert
with pytest.raises(RuntimeError, match='Error'):
BaseSingleTableSynthesizer.load('synth.pkl')

def test_add_custom_constraint_class(self):
"""Test that this method calls the ``DataProcessor``'s method."""
# Setup
Expand Down