
Commit a32f339

Add CAG support to single table synthesizers (#2419)
1 parent c83ac58 commit a32f339
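In short, single-table synthesizers can now accept CAG (constraint-augmented generation) patterns before fitting. A minimal usage sketch pieced together from the integration tests in this commit; the data, metadata, and column names are illustrative, not prescribed by the API:

    import pandas as pd

    from sdv.cag import FixedCombinations
    from sdv.metadata import Metadata
    from sdv.single_table.copulas import GaussianCopulaSynthesizer

    # Toy data: only three distinct (A, B) combinations appear.
    data = pd.DataFrame({
        'A': [1, 2, 3, 1, 2, 1],
        'B': [10, 20, 30, 10, 20, 10],
    })
    metadata = Metadata.load_from_dict({
        'columns': {
            'A': {'sdtype': 'categorical'},
            'B': {'sdtype': 'categorical'},
        }
    })

    synthesizer = GaussianCopulaSynthesizer(metadata)
    synthesizer.add_cag(patterns=[FixedCombinations(['A', 'B'])])  # new in this commit
    synthesizer.fit(data)

    samples = synthesizer.sample(100)     # only (A, B) pairs seen in the data
    synthesizer.get_metadata()            # original metadata (the default)
    synthesizer.get_metadata('modified')  # metadata after patterns, e.g. merged column 'A#B'
    synthesizer.get_cag()                 # deep copy of the applied patterns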

File tree

7 files changed: +650 -11 lines

sdv/evaluation/single_table.py

-3 lines

@@ -31,9 +31,6 @@ def evaluate_quality(real_data, synthetic_data, metadata, verbose=True):
         metadata = metadata._convert_to_single_table()
 
     quality_report = QualityReport()
-    if isinstance(metadata, Metadata):
-        metadata = metadata._convert_to_single_table()
-
     quality_report.generate(real_data, synthetic_data, metadata.to_dict(), verbose)
     return quality_report
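Note: the removed lines duplicated the isinstance conversion performed a few lines earlier in the same function, so the second check was effectively dead code.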

sdv/single_table/base.py

+123 lines
@@ -10,6 +10,7 @@
 import uuid
 import warnings
 from collections import defaultdict
+from copy import deepcopy
 
 import cloudpickle
 import copulas
@@ -26,6 +27,7 @@
     generate_synthesizer_id,
     get_possible_chars,
 )
+from sdv.cag._errors import PatternNotMetError
 from sdv.constraints.errors import AggregateConstraintsError
 from sdv.data_processing.data_processor import DataProcessor
 from sdv.errors import (
@@ -610,6 +612,118 @@ class BaseSingleTableSynthesizer(BaseSynthesizer):
     for all single-table synthesizers.
     """
 
+    def __init__(
+        self,
+        metadata,
+        enforce_min_max_values=True,
+        enforce_rounding=True,
+        locales=['en_US'],
+    ):
+        super().__init__(metadata, enforce_min_max_values, enforce_rounding, locales)
+        self._chained_patterns = []  # chain of patterns used to preprocess the data
+        self._reject_sampling_patterns = []  # patterns used only for reject sampling
+        self._original_metadata = self.metadata
+
+    def add_cag(self, patterns):
+        """Add the list of constraint-augmented generation patterns to the synthesizer.
+
+        Args:
+            patterns (list):
+                A list of CAG patterns to apply to the synthesizer.
+        """
+        for pattern in patterns:
+            try:
+                self.metadata = pattern.get_updated_metadata(self.metadata)
+                self._chained_patterns.append(pattern)
+            except PatternNotMetError as e:
+                LOGGER.info(
+                    'Enforcing pattern %s using reject sampling.', pattern.__class__.__name__
+                )
+
+                try:
+                    pattern.get_updated_metadata(self._original_metadata)
+                    self._reject_sampling_patterns.append(pattern)
+                except PatternNotMetError:
+                    raise e
+
+        self._data_processor = DataProcessor(
+            metadata=self.metadata._convert_to_single_table(),
+            enforce_rounding=self.enforce_rounding,
+            enforce_min_max_values=self.enforce_min_max_values,
+            locales=self.locales,
+        )
+
+    def get_cag(self):
+        """Get a list of constraint-augmented generation patterns applied to the synthesizer."""
+        return deepcopy(self._chained_patterns + self._reject_sampling_patterns)
+
+    def get_metadata(self, version='original'):
+        """Get the metadata, either original or modified after applying CAG patterns.
+
+        Args:
+            version (str, optional):
+                The version of metadata to return, must be one of 'original' or 'modified'. If
+                'original', will return the original metadata used to instantiate the
+                synthesizer. If 'modified', will return the modified metadata after applying this
+                synthesizer's CAG patterns. Defaults to 'original'.
+        """
+        if version not in ('original', 'modified'):
+            raise ValueError(
+                f"Unrecognized version '{version}', please use 'original' or 'modified'."
+            )
+
+        return self._original_metadata if version == 'original' else self.metadata
+
+    def _transform_helper(self, data):
+        """Validate and transform all CAG patterns during preprocessing.
+
+        Args:
+            data (pandas.DataFrame):
+                The data to transform.
+        """
+        if self._fitted:
+            for pattern in self._chained_patterns:
+                data = pattern.transform(data)
+            return data
+
+        metadata = self._original_metadata
+        original_data = data
+        for pattern in self._chained_patterns:
+            pattern.fit(data, metadata)
+            metadata = pattern.get_updated_metadata(metadata)
+            data = pattern.transform(data)
+
+        for pattern in self._reject_sampling_patterns:
+            pattern.fit(original_data, self._original_metadata)
+
+        return data
+
+    def preprocess(self, data):
+        """Transform the raw data to numerical space.
+
+        Args:
+            data (pandas.DataFrame):
+                The raw data to be transformed.
+
+        Returns:
+            pandas.DataFrame:
+                The preprocessed data.
+        """
+        if self._fitted:
+            warnings.warn(
+                'This model has already been fitted. To use the new preprocessed data, '
+                "please refit the model using 'fit' or 'fit_processed_data'."
+            )
+
+        is_converted = self._store_and_convert_original_cols(data)
+        data = self._transform_helper(data)
+        preprocess_data = self._preprocess(data)
+
+        if is_converted:
+            data.columns = self._original_columns
+
+        return preprocess_data
+
     def _set_random_state(self, random_state):
         """Set the random state of the model's random number generator.
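Two details worth calling out in the methods added above. add_cag applies each pattern to the running, already-updated metadata so patterns chain; on PatternNotMetError it re-validates the pattern against the original metadata and, if that passes, defers the pattern to reject sampling at sample time (otherwise the original error is re-raised). The base class also relies only on a small duck-typed pattern interface. A minimal sketch of what it assumes, inferred from the calls above; the IdentityPattern class is illustrative and not part of this commit:

    import pandas as pd


    class IdentityPattern:
        """Hypothetical no-op pattern; method names mirror the calls in base.py."""

        def get_updated_metadata(self, metadata):
            # Return the metadata as it will look after transform(). A real
            # pattern raises PatternNotMetError here when it cannot be applied,
            # which is what triggers add_cag's reject-sampling fallback.
            return metadata

        def fit(self, data, metadata):
            # Learn whatever the pattern needs from the raw training data.
            self._fitted = True

        def transform(self, data):
            # Map the data into a space where the pattern holds by construction.
            return data

        def reverse_transform(self, data):
            # Undo transform() on sampled rows.
            return data

        def is_valid(self, data):
            # Return a boolean mask of rows that satisfy the pattern.
            return pd.Series(True, index=data.index)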
@@ -720,6 +834,15 @@ def _sample_rows(
            )
            sampled = pd.concat([sampled, raw_sampled[missing_cols]], axis=1)
 
+        for pattern in reversed(self._chained_patterns):
+            sampled = pattern.reverse_transform(sampled)
+            valid_rows = pattern.is_valid(sampled)
+            sampled = sampled[valid_rows]
+
+        for pattern in reversed(self._reject_sampling_patterns):
+            valid_rows = pattern.is_valid(sampled)
+            sampled = sampled[valid_rows]
+
         if previous_rows is not None:
             sampled = pd.concat([previous_rows, sampled], ignore_index=True)
         sampled = self._data_processor.filter_valid(sampled)
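During sampling, chained patterns are reverse-transformed and validated in the reverse of the order they were fitted, while reject-sampling patterns only filter; rows failing any is_valid check are dropped, and the surrounding batch-sampling logic (unchanged by this commit) presumably keeps drawing until enough valid rows accumulate.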

tests/integration/cag/test_fixed_combinations.py

+219 lines
@@ -5,6 +5,7 @@
 
 from sdv.cag import FixedCombinations
 from sdv.metadata import Metadata
+from sdv.single_table.copulas import GaussianCopulaSynthesizer
 from tests.utils import run_pattern
 
 
@@ -117,3 +118,221 @@ def test_fixed_null_combinations_with_multi_table():
     assert set(data.keys()) == set(reverse_transformed.keys())
     for table_name, table in data.items():
         pd.testing.assert_frame_equal(table, reverse_transformed[table_name])
+
+
+def test_fixed_combinations_multiple_patterns():
+    """Test that FixedCombinations pattern works with multiple patterns."""
+    # Setup
+    data = pd.DataFrame({
+        'A': [1, 2, 3, 1, 2, 1],
+        'B': [10, 20, 30, 10, 20, 10],
+        'C': [100, 200, 300, 100, 200, 100],
+        'D': [1000, 2000, 3000, 1000, 2000, 1000],
+    })
+    metadata = Metadata.load_from_dict({
+        'columns': {
+            'A': {'sdtype': 'categorical'},
+            'B': {'sdtype': 'categorical'},
+            'C': {'sdtype': 'categorical'},
+            'D': {'sdtype': 'categorical'},
+        }
+    })
+    pattern1 = FixedCombinations(['A', 'B'])
+    pattern2 = FixedCombinations(['C', 'D'])
+
+    # Run
+    synthesizer = GaussianCopulaSynthesizer(metadata)
+    synthesizer.add_cag(patterns=[pattern1, pattern2])
+    synthesizer.fit(data)
+    samples = synthesizer.sample(100)
+    updated_metadata = synthesizer.get_metadata('modified')
+    original_metadata = synthesizer.get_metadata('original')
+
+    # Assert
+    expected_updated_metadata = Metadata.load_from_dict({
+        'columns': {
+            'A#B': {'sdtype': 'categorical'},
+            'C#D': {'sdtype': 'categorical'},
+        }
+    }).to_dict()
+    assert expected_updated_metadata == updated_metadata.to_dict()
+
+    assert original_metadata.to_dict() == metadata.to_dict()
+
+    # Get unique combinations from original data
+    original_ab_combos = set(zip(data['A'], data['B']))
+    original_cd_combos = set(zip(data['C'], data['D']))
+
+    # Get unique combinations from synthetic data
+    synthetic_ab_combos = set(zip(samples['A'], samples['B']))
+    synthetic_cd_combos = set(zip(samples['C'], samples['D']))
+
+    # Assert combinations match
+    assert original_ab_combos == synthetic_ab_combos
+    assert original_cd_combos == synthetic_cd_combos
+
+
+def test_fixed_combinations_multiple_patterns_reject_sampling():
+    """Test that FixedCombinations pattern works with multiple patterns and reject sampling."""
+    # Setup
+    data = pd.DataFrame({
+        'A': [1, 2, 3, 1, 2, 1],
+        'B': [10, 20, 30, 10, 20, 10],
+        'C': [100, 200, 300, 100, 200, 100],
+    })
+    metadata = Metadata.load_from_dict({
+        'columns': {
+            'A': {'sdtype': 'categorical'},
+            'B': {'sdtype': 'categorical'},
+            'C': {'sdtype': 'categorical'},
+        }
+    })
+    pattern1 = FixedCombinations(['A', 'B'])
+    pattern2 = FixedCombinations(['A', 'C'])
+
+    # Run
+    synthesizer = GaussianCopulaSynthesizer(metadata)
+    synthesizer.add_cag(patterns=[pattern1, pattern2])
+    synthesizer.fit(data)
+    samples = synthesizer.sample(100)
+    updated_metadata = synthesizer.get_metadata('modified')
+    original_metadata = synthesizer.get_metadata('original')
+
+    # Assert
+    expected_updated_metadata = Metadata.load_from_dict({
+        'columns': {
+            'A#B': {'sdtype': 'categorical'},
+            'C': {'sdtype': 'categorical'},
+        }
+    }).to_dict()
+    assert expected_updated_metadata == updated_metadata.to_dict()
+
+    assert original_metadata.to_dict() == metadata.to_dict()
+
+    # Get unique combinations from original data
+    original_ab_combos = set(zip(data['A'], data['B']))
+    original_ac_combos = set(zip(data['A'], data['C']))
+
+    # Get unique combinations from synthetic data
+    synthetic_ab_combos = set(zip(samples['A'], samples['B']))
+    synthetic_ac_combos = set(zip(samples['A'], samples['C']))
+
+    # Assert combinations match
+    assert original_ab_combos == synthetic_ab_combos
+    assert original_ac_combos == synthetic_ac_combos
+
+
+def test_fixed_combinations_multiple_patterns_three_patterns():
+    """Test that FixedCombinations pattern works with three patterns."""
+    # Setup
+    data = pd.DataFrame({
+        'A': [1, 2, 3, 1, 2, 1],
+        'B': [10, 20, 30, 10, 20, 10],
+        'C': [100, 200, 300, 100, 200, 100],
+        'D': [1000, 2000, 3000, 1000, 2000, 1000],
+    })
+    metadata = Metadata.load_from_dict({
+        'columns': {
+            'A': {'sdtype': 'categorical'},
+            'B': {'sdtype': 'categorical'},
+            'C': {'sdtype': 'categorical'},
+            'D': {'sdtype': 'categorical'},
+        }
+    })
+    pattern1 = FixedCombinations(['A', 'B'])
+    pattern2 = FixedCombinations(['C', 'D'])
+    pattern3 = FixedCombinations(['A', 'C'])
+
+    # Run
+    synthesizer = GaussianCopulaSynthesizer(metadata)
+    synthesizer.add_cag(patterns=[pattern1, pattern2, pattern3])
+    synthesizer.fit(data)
+    samples = synthesizer.sample(100)
+    updated_metadata = synthesizer.get_metadata('modified')
+    original_metadata = synthesizer.get_metadata('original')
+
+    # Assert
+    expected_updated_metadata = Metadata.load_from_dict({
+        'columns': {
+            'A#B': {'sdtype': 'categorical'},
+            'C#D': {'sdtype': 'categorical'},
+        }
+    }).to_dict()
+    assert expected_updated_metadata == updated_metadata.to_dict()
+
+    assert original_metadata.to_dict() == metadata.to_dict()
+
+    # Get unique combinations from original data
+    original_ab_combos = set(zip(data['A'], data['B']))
+    original_cd_combos = set(zip(data['C'], data['D']))
+    original_ac_combos = set(zip(data['A'], data['C']))
+
+    # Get unique combinations from synthetic data
+    synthetic_ab_combos = set(zip(samples['A'], samples['B']))
+    synthetic_cd_combos = set(zip(samples['C'], samples['D']))
+    synthetic_ac_combos = set(zip(samples['A'], samples['C']))
+
+    # Assert combinations match
+    assert original_ab_combos == synthetic_ab_combos
+    assert original_cd_combos == synthetic_cd_combos
+    assert original_ac_combos == synthetic_ac_combos
+
+
+def test_fixed_combinations_multiple_patterns_three_patterns_reject_sampling():
+    """Test that FixedCombinations pattern works with multiple patterns.
+
+    Test that when the second pattern in the chain fails, the third pattern still works.
+    """
+    # Setup
+    data = pd.DataFrame({
+        'A': [1, 2, 3, 1, 2, 1],
+        'B': [10, 20, 30, 10, 20, 10],
+        'C': [100, 200, 300, 100, 200, 100],
+        'D': [1000, 2000, 3000, 1000, 2000, 1000],
+    })
+    metadata = Metadata.load_from_dict({
+        'columns': {
+            'A': {'sdtype': 'categorical'},
+            'B': {'sdtype': 'categorical'},
+            'C': {'sdtype': 'categorical'},
+            'D': {'sdtype': 'categorical'},
+        }
+    })
+    pattern1 = FixedCombinations(['A', 'B'])
+    pattern2 = FixedCombinations(['C', 'D'])
+    pattern3 = FixedCombinations(['A', 'C'])
+
+    # Run
+    synthesizer = GaussianCopulaSynthesizer(metadata)
+    synthesizer.add_cag(patterns=[pattern1, pattern3, pattern2])
+    synthesizer.fit(data)
+    samples = synthesizer.sample(100)
+    updated_metadata = synthesizer.get_metadata('modified')
+    original_metadata = synthesizer.get_metadata('original')
+
+    # Assert
+    expected_updated_metadata = Metadata.load_from_dict({
+        'columns': {
+            'A#B': {'sdtype': 'categorical'},
+            'C#D': {'sdtype': 'categorical'},
+        }
+    }).to_dict()
+
+    assert expected_updated_metadata == updated_metadata.to_dict()
+
+    assert original_metadata.to_dict() == metadata.to_dict()
+
+    # Get unique combinations from original data
+    original_ab_combos = set(zip(data['A'], data['B']))
+    original_cd_combos = set(zip(data['C'], data['D']))
+    original_ac_combos = set(zip(data['A'], data['C']))
+
+    # Get unique combinations from synthetic data
+    synthetic_ab_combos = set(zip(samples['A'], samples['B']))
+    synthetic_cd_combos = set(zip(samples['C'], samples['D']))
+    synthetic_ac_combos = set(zip(samples['A'], samples['C']))
+
+    # Assert combinations match
+    assert original_ab_combos == synthetic_ab_combos
+    assert original_cd_combos == synthetic_cd_combos
+    assert original_ac_combos == synthetic_ac_combos
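A note on the reject-sampling tests above: FixedCombinations merges its columns in the metadata (for example, 'A' and 'B' become 'A#B'), so a later pattern that references an already-merged column, such as FixedCombinations(['A', 'C']) after FixedCombinations(['A', 'B']), raises PatternNotMetError against the updated metadata. Because that pattern still validates against the original metadata, add_cag enforces it through reject sampling instead, which is why the expected modified metadata keeps 'C' unmerged in the two-pattern test and shows only 'A#B' and 'C#D' in the three-pattern tests.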
