Skip to content

Commit 0ac8308

Browse files
authored
Sampling with HMA Synthesizer generates many SingleTableMetadata deprecation warnings (#2332)
1 parent 974da0e commit 0ac8308

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

Diff for: sdv/multi_table/hma.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Hierarchical Modeling Algorithms."""
22

33
import logging
4+
import warnings
45
from collections import defaultdict
56
from copy import deepcopy
67

@@ -552,7 +553,12 @@ def _recreate_child_synthesizer(self, child_name, parent_name, parent_row):
552553
parameters = self._extract_parameters(parent_row, child_name, foreign_key)
553554
default_parameters = getattr(self, '_default_parameters', {}).get(child_name, {})
554555
table_meta = self.metadata.get_table_metadata(child_name)
555-
synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])
556+
with warnings.catch_warnings():
557+
warnings.filterwarnings(
558+
'ignore', message=".*The 'SingleTableMetadata' is deprecated.*"
559+
)
560+
synthesizer = self._synthesizer(table_meta, **self._table_parameters[child_name])
561+
556562
synthesizer._set_parameters(parameters, default_parameters)
557563
else:
558564
synthesizer = self._null_child_synthesizers[f'__{child_name}__{foreign_key}']

Diff for: tests/integration/multi_table/test_hma.py

+18
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from sdv.datasets.local import load_csvs
2020
from sdv.errors import SamplingError, SynthesizerInputError, VersionError
2121
from sdv.evaluation.multi_table import evaluate_quality, get_column_pair_plot, get_column_plot
22+
from sdv.metadata import MultiTableMetadata
2223
from sdv.metadata.metadata import Metadata
2324
from sdv.multi_table import HMASynthesizer
2425
from tests.integration.single_table.custom_constraints import MyConstraint
@@ -2637,3 +2638,20 @@ def test_column_order():
26372638
assert table_1_column != list(data['table_1'].columns)
26382639
assert table_1_column == ['col_1', 'col_2', 'col_3']
26392640
assert list(synthetic_data['table_2'].columns) == ['col_A', 'col_B', 'col_C']
2641+
2642+
2643+
def test_no_deprecation_warning_single_table_metadata_sampling():
2644+
"""Test that no single-table metadata deprecation warning raises with `MultiTableMetadata`."""
2645+
# Setup
2646+
data, _ = download_demo(modality='multi_table', dataset_name='fake_hotels')
2647+
multi_metadata = MultiTableMetadata()
2648+
multi_metadata.detect_from_dataframes(data)
2649+
synthesizer = HMASynthesizer(multi_metadata)
2650+
synthesizer.fit(data)
2651+
2652+
# Run
2653+
with warnings.catch_warnings(record=True) as captured_warnings:
2654+
synthesizer.sample()
2655+
2656+
# Assert
2657+
assert len(captured_warnings) == 0

0 commit comments

Comments
 (0)