Skip to content

Commit 9c412b0

Browse files
authored
Make GaussianCopula report the correct distribution name in the case of a fallback (#2401)
1 parent fff948d commit 9c412b0

File tree

5 files changed

+39
-6
lines changed

5 files changed

+39
-6
lines changed

sdv/multi_table/utils.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def _drop_rows(data, metadata, drop_missing_values):
443443

444444
if data[table].empty:
445445
raise InvalidDataError([
446-
f"All references in table '{table}' are unknown and must be dropped."
446+
f"All references in table '{table}' are unknown and must be dropped. "
447447
'Try providing different data for this table.'
448448
])
449449

@@ -558,7 +558,7 @@ def _subsample_parent(
558558
parent_table = parent_table.drop(unreferenced_data_to_drop.index)
559559
if parent_table.empty:
560560
raise InvalidDataError([
561-
f"All references in table '{parent_primary_key}' are unknown and must be dropped."
561+
f"All references in table '{parent_primary_key}' are unknown and must be dropped. "
562562
'Try providing different data for this table.'
563563
])
564564

sdv/single_table/copulas.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -224,10 +224,11 @@ def get_learned_distributions(self):
224224
univariates = deepcopy(parameters['univariates'])
225225
learned_distributions = {}
226226
valid_columns = self._get_valid_columns_from_metadata(columns)
227+
distribution_names = {v.__name__: k for k, v in self._DISTRIBUTIONS.items()}
227228
for column, learned_params in zip(columns, univariates):
228229
if column in valid_columns:
229-
distribution = self.numerical_distributions.get(column, self.default_distribution)
230-
learned_params.pop('type')
230+
distribution_name = learned_params.pop('type').split('.')[-1]
231+
distribution = distribution_names[distribution_name]
231232
learned_distributions[column] = {
232233
'distribution': distribution,
233234
'learned_parameters': learned_params,

tests/integration/single_table/test_copulas.py

+32
Original file line numberDiff line numberDiff line change
@@ -522,3 +522,35 @@ def test_user_warning_for_unused_numerical_distribution():
522522
)
523523
with pytest.warns(UserWarning, match=message):
524524
synthesizer.fit(data)
525+
526+
527+
def test_get_learned_distributions_fallback_distribution():
528+
"""Test it when the fallback distribution is used GH#2394."""
529+
# Setup
530+
data = pd.DataFrame(data={'A': np.concatenate([np.zeros(29), np.ones(21)])})
531+
metadata = Metadata.load_from_dict({
532+
'tables': {
533+
'table': {
534+
'columns': {
535+
'A': {
536+
'sdtype': 'numerical',
537+
},
538+
},
539+
},
540+
},
541+
})
542+
543+
# Run
544+
synthesizer = GaussianCopulaSynthesizer(metadata, default_distribution='beta')
545+
synthesizer.fit(data)
546+
547+
# Assert
548+
assert synthesizer.get_learned_distributions() == {
549+
'A': {
550+
'distribution': 'norm',
551+
'learned_parameters': {
552+
'loc': 0.42,
553+
'scale': 0.4935585071701226,
554+
},
555+
},
556+
}

tests/unit/multi_table/test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop):
469469
# Run and Assert
470470
expected_message = re.escape(
471471
'The provided data does not match the metadata:\n'
472-
"All references in table 'child' are unknown and must be dropped."
472+
"All references in table 'child' are unknown and must be dropped. "
473473
'Try providing different data for this table.'
474474
)
475475
with pytest.raises(InvalidDataError, match=expected_message):

tests/unit/utils/test_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def test_drop_unknown_references_drop_all_rows(mock_get_rows_to_drop):
351351
# Run and Assert
352352
expected_message = re.escape(
353353
'The provided data does not match the metadata:\n'
354-
"All references in table 'child' are unknown and must be dropped."
354+
"All references in table 'child' are unknown and must be dropped. "
355355
'Try providing different data for this table.'
356356
)
357357
with pytest.raises(InvalidDataError, match=expected_message):

0 commit comments

Comments
 (0)