diff --git a/sdv/single_table/utils.py b/sdv/single_table/utils.py index c11138c33..1f6459735 100644 --- a/sdv/single_table/utils.py +++ b/sdv/single_table/utils.py @@ -103,7 +103,7 @@ def handle_sampling_error(output_file_path, sampling_error): ) if error_msg: - raise type(sampling_error)(error_msg + '\n' + str(sampling_error)) + raise type(sampling_error)(error_msg) from sampling_error raise sampling_error diff --git a/tests/unit/single_table/test_base.py b/tests/unit/single_table/test_base.py index a0f45cbc8..e188e6d39 100644 --- a/tests/unit/single_table/test_base.py +++ b/tests/unit/single_table/test_base.py @@ -1546,15 +1546,17 @@ def test__sample_with_progress_bar_without_output_filepath(self): instance._fitted = True expected_message = re.escape( 'Error: Sampling terminated. No results were saved due to unspecified ' - '"output_file_path".\nMocked Error' + '"output_file_path".' ) instance._sample_in_batches.side_effect = RuntimeError('Mocked Error') # Run and Assert - with pytest.raises(RuntimeError, match=expected_message): + with pytest.raises(RuntimeError, match=expected_message) as exception: BaseSingleTableSynthesizer._sample_with_progress_bar( instance, output_file_path=None, num_rows=10 ) + assert isinstance(exception.value.__cause__, RuntimeError) + assert 'Mocked Error' in str(exception.value.__cause__) @patch('sdv.single_table.base.datetime') def test_sample(self, mock_datetime, caplog): diff --git a/tests/unit/single_table/test_utils.py b/tests/unit/single_table/test_utils.py index dbf84cdaa..1d98e10c6 100644 --- a/tests/unit/single_table/test_utils.py +++ b/tests/unit/single_table/test_utils.py @@ -215,10 +215,13 @@ def test_unflatten_dict(): def test_handle_sampling_error_temp_file(): """Test that an error is raised when temp dir is ``False``.""" # Run and Assert - error_msg = 'Error: Sampling terminated. Partial results are stored in test.csv.\nTest error' - with pytest.raises(ValueError, match=error_msg): + error_msg = 'Error: Sampling terminated. Partial results are stored in test.csv.' + with pytest.raises(ValueError, match=error_msg) as exception: handle_sampling_error('test.csv', ValueError('Test error')) + assert isinstance(exception.value.__cause__, ValueError) + assert 'Test error' in str(exception.value.__cause__) + def test_handle_sampling_error_false_temp_file_none_output_file(): """Test the ``handle_sampling_error`` function. @@ -228,9 +231,12 @@ def test_handle_sampling_error_false_temp_file_none_output_file(): """ # Run and Assert error_msg = 'Test error' - with pytest.raises(ValueError, match=error_msg): + with pytest.raises(ValueError) as exception: handle_sampling_error('test.csv', ValueError('Test error')) + assert isinstance(exception.value.__cause__, ValueError) + assert error_msg in str(exception.value.__cause__) + def test_handle_sampling_error_ignore(): """Test that the error is raised if the error is the no rows error."""