Skip to content

Commit c83ac58

Browse files
authored
Store metadata as Metadata for BaseSynthesizer (#2422)
1 parent 6e18d29 commit c83ac58

File tree

20 files changed

+170
-128
lines changed

20 files changed

+170
-128
lines changed

sdv/metadata/multi_table.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -776,7 +776,7 @@ def _validate_all_tables(self, data):
776776
self.tables[table_name].validate_data(table_data, table_sdtype_warnings)
777777

778778
except InvalidDataError as error:
779-
error_msg = f"Table: '{table_name}'"
779+
error_msg = f'Errors in {table_name}:'
780780
for _error in error.errors:
781781
error_msg += f'\nError: {_error}'
782782

sdv/sampling/independent_sampler.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -106,8 +106,7 @@ def _finalize(self, sampled_data):
106106
try:
107107
table_rows[name] = table_rows[name].dropna().astype(dtype)
108108
except ValueError as e:
109-
column_metadata = metadata.columns.get(name)
110-
sdtype = column_metadata.get('sdtype')
109+
sdtype = metadata.columns.get(name).get('sdtype')
111110
if sdtype not in dtypes_to_sdtype.values():
112111
LOGGER.info(
113112
f"The real data in '{table_name}' and column '{name}' was stored as "

sdv/sequential/par.py

+12-8
Original file line number | Diff line number | Diff line change
@@ -34,8 +34,8 @@ class PARSynthesizer(LossValuesMixin, BaseSynthesizer):
3434
to be passed into PAR.
3535
3636
Args:
37-
metadata (sdv.metadata.SingleTableMetadata):
38-
Single table metadata representing the data that this synthesizer will be used for.
37+
metadata (sdv.metadata.Metadata):
38+
Metadata representing the data that this synthesizer will be used for.
3939
enforce_min_max_values (bool):
4040
Specify whether or not to clip the data returned by ``reverse_transform`` of
4141
the numerical transformer, ``FloatFormatter``, to the min and max values seen
@@ -78,8 +78,9 @@ def _get_context_metadata(self):
7878
for column, column_metadata in self._extra_context_columns.items():
7979
context_columns_dict[column] = column_metadata
8080

81+
table_metadata = self._get_table_metadata()
8182
for column in context_columns:
82-
context_columns_dict[column] = self.metadata.columns[column]
83+
context_columns_dict[column] = table_metadata.columns[column]
8384

8485
context_columns_dict = self._update_context_column_dict(context_columns_dict)
8586
context_metadata_dict = {'columns': context_columns_dict}
@@ -97,8 +98,9 @@ def _update_context_column_dict(self, context_columns_dict):
9798
Updated context column metadata.
9899
"""
99100
default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype)
101+
table_metadata = self._get_table_metadata()
100102
for column in self.context_columns:
101-
column_metadata = self.metadata.columns[column]
103+
column_metadata = table_metadata.columns[column]
102104
if default_transformers_by_sdtype.get(column_metadata['sdtype']):
103105
context_columns_dict[column] = {'sdtype': 'numerical'}
104106

@@ -107,8 +109,9 @@ def _update_context_column_dict(self, context_columns_dict):
107109
def _get_context_columns_for_processing(self):
108110
columns_to_be_processed = []
109111
default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype)
112+
table_metadata = self._get_table_metadata()
110113
for column in self.context_columns:
111-
if default_transformers_by_sdtype.get(self.metadata.columns[column]['sdtype']):
114+
if default_transformers_by_sdtype.get(table_metadata.columns[column]['sdtype']):
112115
columns_to_be_processed.append(column)
113116

114117
return columns_to_be_processed
@@ -136,15 +139,15 @@ def __init__(
136139
locales=locales,
137140
)
138141

139-
sequence_key = self.metadata.sequence_key
142+
sequence_key = self._get_table_metadata().sequence_key
140143
self._sequence_key = list(_cast_to_iterable(sequence_key)) if sequence_key else None
141144
if not self._sequence_key:
142145
raise SynthesizerInputError(
143146
'The PARSythesizer is designed for multi-sequence data, identifiable through a '
144147
'sequence key. Your metadata does not include a sequence key.'
145148
)
146149

147-
self._sequence_index = self.metadata.sequence_index
150+
self._sequence_index = self._get_table_metadata().sequence_index
148151
self.context_columns = context_columns or []
149152
self._validate_sequence_key_and_context_columns()
150153
self._extra_context_columns = {}
@@ -415,13 +418,14 @@ def _fit_sequence_columns(self, timeseries_data):
415418
)
416419
data_types = []
417420
context_types = []
421+
table_metadata = self._get_table_metadata()
418422
for field in self._output_columns:
419423
dtype = timeseries_data[field].dtype
420424
kind = dtype.kind
421425
if kind in ('i', 'f'):
422426
data_type = 'continuous'
423427
# Check if metadata overrides this data type
424-
if self.metadata.columns.get(field, {}).get('sdtype', None) == 'categorical':
428+
if table_metadata.columns.get(field, {}).get('sdtype', None) == 'categorical':
425429
data_type = 'categorical'
426430
elif kind in ('O', 'b'):
427431
data_type = 'categorical'

sdv/single_table/base.py

+24-17
Original file line number | Diff line number | Diff line change
@@ -99,8 +99,8 @@ def _update_default_transformers(self):
9999
self._data_processor._update_transformers_by_sdtypes(sdtype, transformer)
100100

101101
def _check_metadata_updated(self):
102-
if self.metadata._updated:
103-
self.metadata._updated = False
102+
if self.metadata._check_updated_flag():
103+
self.metadata._reset_updated_flag()
104104
warnings.warn(
105105
"We strongly recommend saving the metadata using 'save_to_json' for replicability"
106106
' in future SDV versions.'
@@ -110,22 +110,22 @@ def __init__(
110110
self, metadata, enforce_min_max_values=True, enforce_rounding=True, locales=['en_US']
111111
):
112112
self._validate_inputs(enforce_min_max_values, enforce_rounding)
113-
self.metadata = metadata
114-
self._table_name = Metadata.DEFAULT_SINGLE_TABLE_NAME
115113
if isinstance(metadata, Metadata):
116114
self._table_name = metadata._get_single_table_name()
117-
self.metadata = metadata._convert_to_single_table()
118-
elif isinstance(metadata, SingleTableMetadata):
115+
self.metadata = metadata
116+
else:
119117
warnings.warn(DEPRECATION_MSG, FutureWarning)
118+
self._table_name = Metadata.DEFAULT_SINGLE_TABLE_NAME
119+
self.metadata = Metadata.load_from_dict(metadata.to_dict(), self._table_name)
120+
self.metadata.tables[self._table_name]._updated = metadata._updated
120121

121-
self._validate_inputs(enforce_min_max_values, enforce_rounding)
122122
self.metadata.validate()
123123
self._check_metadata_updated()
124124
self.enforce_min_max_values = enforce_min_max_values
125125
self.enforce_rounding = enforce_rounding
126126
self.locales = locales
127127
self._data_processor = DataProcessor(
128-
metadata=self.metadata,
128+
metadata=self.metadata._convert_to_single_table(),
129129
enforce_rounding=self.enforce_rounding,
130130
enforce_min_max_values=self.enforce_min_max_values,
131131
locales=self.locales,
@@ -158,7 +158,7 @@ def _validate_metadata(self, data):
158158
"""Validate that the data follows the metadata."""
159159
errors = []
160160
try:
161-
self.metadata.validate_data(data)
161+
self.metadata.validate_data({self._table_name: data})
162162
except InvalidDataError as error:
163163
errors += error.errors
164164

@@ -183,10 +183,13 @@ def _validate(self, data):
183183
"""
184184
return []
185185

186+
def _get_table_metadata(self):
187+
return self.metadata.tables.get(self._table_name, SingleTableMetadata())
188+
186189
def _validate_primary_key(self, data):
187-
primary_key = self.metadata.primary_key
190+
primary_key = self._get_table_metadata().primary_key
188191
is_int = primary_key and pd.api.types.is_integer_dtype(data[primary_key])
189-
regex = self.metadata.columns.get(primary_key, {}).get('regex_format')
192+
regex = self._get_table_metadata().columns.get(primary_key, {}).get('regex_format')
190193
if is_int and regex:
191194
possible_characters = get_possible_chars(regex, 1)
192195
if '0' in possible_characters:
@@ -225,8 +228,8 @@ def validate(self, data):
225228
raise InvalidDataError(synthesizer_errors)
226229

227230
def _validate_transformers(self, column_name_to_transformer):
228-
primary_and_alternate_keys = self.metadata._get_primary_and_alternate_keys()
229-
sequence_keys = self.metadata._get_set_of_sequence_keys()
231+
primary_and_alternate_keys = self._get_table_metadata()._get_primary_and_alternate_keys()
232+
sequence_keys = self._get_table_metadata()._get_set_of_sequence_keys()
230233
keys = primary_and_alternate_keys | sequence_keys
231234
for column, transformer in column_name_to_transformer.items():
232235
if transformer is None:
@@ -251,8 +254,9 @@ def _warn_quality_and_performance(self, column_name_to_transformer):
251254
column_name_to_transformer (dict):
252255
Dict mapping column names to transformers to be used for that column.
253256
"""
257+
table_metadata = self._get_table_metadata()
254258
for column in column_name_to_transformer:
255-
sdtype = self.metadata.columns.get(column, {}).get('sdtype')
259+
sdtype = table_metadata.columns.get(column, {}).get('sdtype')
256260
if sdtype in {'categorical', 'boolean'}:
257261
warnings.warn(
258262
f"Replacing the default transformer for column '{column}' "
@@ -304,8 +308,10 @@ def get_parameters(self):
304308

305309
def get_metadata(self):
306310
"""Return the ``Metadata`` for this synthesizer."""
307-
table_name = getattr(self, '_table_name', None)
308-
return Metadata.load_from_dict(self.metadata.to_dict(), table_name)
311+
if isinstance(self.metadata, SingleTableMetadata):
312+
table_name = getattr(self, '_table_name', None)
313+
return Metadata.load_from_dict(self.metadata.to_dict(), table_name)
314+
return self.metadata
309315

310316
def load_custom_constraint_classes(self, filepath, class_names):
311317
"""Load a custom constraint class for the current synthesizer.
@@ -387,9 +393,10 @@ def get_transformers(self):
387393
)
388394

389395
# Order the output to match metadata
396+
table_metadata = self._get_table_metadata()
390397
ordered_field_transformers = {
391398
column_name: field_transformers.get(column_name)
392-
for column_name in self.metadata.columns
399+
for column_name in table_metadata.columns
393400
if column_name in field_transformers
394401
}
395402

sdv/single_table/copulagan.py

+5-2
Original file line number | Diff line number | Diff line change
@@ -164,7 +164,10 @@ def __init__(
164164
cuda=cuda,
165165
)
166166

167-
validate_numerical_distributions(numerical_distributions, self.metadata.columns)
167+
validate_numerical_distributions(
168+
numerical_distributions,
169+
self._get_table_metadata().columns,
170+
)
168171
self.numerical_distributions = numerical_distributions or {}
169172
self.default_distribution = default_distribution or 'beta'
170173

@@ -177,7 +180,7 @@ def __init__(
177180
}
178181

179182
def _create_gaussian_normalizer_config(self, processed_data):
180-
columns = self.metadata.columns
183+
columns = self._get_table_metadata().columns
181184
transformers = {}
182185
sdtypes = {}
183186
for column in processed_data.columns:

sdv/single_table/copulas.py

+6-2
Original file line number | Diff line number | Diff line change
@@ -110,7 +110,10 @@ def __init__(
110110
enforce_rounding=enforce_rounding,
111111
locales=locales,
112112
)
113-
validate_numerical_distributions(numerical_distributions, self.metadata.columns)
113+
validate_numerical_distributions(
114+
numerical_distributions,
115+
self._get_table_metadata().columns,
116+
)
114117

115118
self.default_distribution = default_distribution or 'beta'
116119
self._default_distribution = self.get_distribution_class(self.default_distribution)
@@ -192,8 +195,9 @@ def _sample(self, num_rows, conditions=None):
192195

193196
def _get_valid_columns_from_metadata(self, columns):
194197
valid_columns = []
198+
table_metadata = self._get_table_metadata()
195199
for column in columns:
196-
for valid_column in self.metadata.columns:
200+
for valid_column in table_metadata.columns:
197201
if column.startswith(valid_column):
198202
valid_columns.append(column)
199203
break

sdv/single_table/utils.py

+5-1
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77

88
from sdv.errors import SynthesizerInputError
9+
from sdv.metadata import Metadata
910

1011
DISABLE_TMP_FILE = 'disable'
1112
IGNORED_DICT_KEYS = ['fitted', 'distribution', 'type']
@@ -19,7 +20,7 @@ def detect_discrete_columns(metadata, data, transformers):
1920
discrete.
2021
2122
Args:
22-
metadata (sdv.metadata.SingleTableMetadata):
23+
metadata (sdv.metadata.Metadata):
2324
Metadata that belongs to the given ``data``.
2425
2526
data (pandas.DataFrame):
@@ -33,6 +34,9 @@ def detect_discrete_columns(metadata, data, transformers):
3334
discrete_columns (list):
3435
A list of discrete columns to be used with some of ``sdv`` synthesizers.
3536
"""
37+
if isinstance(metadata, Metadata):
38+
metadata = metadata._convert_to_single_table()
39+
3640
discrete_columns = []
3741
for column in data.columns:
3842
if column in metadata.columns:

tests/integration/metadata/test_metadata.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -570,7 +570,7 @@ def test_single_table_compatibility(tmp_path):
570570
loaded_synthesizer = GaussianCopulaSynthesizer.load(model_path)
571571
assert isinstance(synthesizer, GaussianCopulaSynthesizer)
572572
assert loaded_synthesizer.get_info() == synthesizer.get_info()
573-
assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict()
573+
assert loaded_synthesizer.metadata._convert_to_single_table().to_dict() == metadata.to_dict()
574574
loaded_sample = loaded_synthesizer.sample(10)
575575
synthesizer.validate(loaded_sample)
576576

tests/integration/sequential/test_par.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,7 @@ def test_save_and_load(tmp_path):
124124

125125
# Assert
126126
assert isinstance(loaded_instance, PARSynthesizer)
127-
assert metadata._convert_to_single_table().to_dict() == instance.metadata.to_dict()
127+
assert metadata.to_dict() == instance.metadata.to_dict()
128128

129129

130130
def test_synthesize_sequences(tmp_path):
@@ -193,7 +193,7 @@ def test_synthesize_sequences(tmp_path):
193193
assert model_path.exists()
194194
assert model_path.is_file()
195195
assert loaded_synthesizer.get_info() == synthesizer.get_info()
196-
assert loaded_synthesizer.metadata.to_dict() == metadata._convert_to_single_table().to_dict()
196+
assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict()
197197
synthesizer.validate(synthetic_data)
198198
synthesizer.validate(custom_synthetic_data)
199199
synthesizer.validate(custom_synthetic_data_conditional)

tests/integration/single_table/test_base.py

+5-14
Original file line number | Diff line number | Diff line change
@@ -526,12 +526,7 @@ def test_save_and_load(tmp_path):
526526

527527
# Assert
528528
assert isinstance(loaded_instance, BaseSingleTableSynthesizer)
529-
assert loaded_instance.metadata.columns == {}
530-
assert loaded_instance.metadata.primary_key is None
531-
assert loaded_instance.metadata.alternate_keys == []
532-
assert loaded_instance.metadata.sequence_key is None
533-
assert loaded_instance.metadata.sequence_index is None
534-
assert loaded_instance.metadata._version == 'SINGLE_TABLE_V1'
529+
assert loaded_instance.metadata.tables == {}
535530
assert instance._synthesizer_id == loaded_instance._synthesizer_id
536531

537532

@@ -550,12 +545,8 @@ def test_save_and_load_no_id(tmp_path):
550545

551546
# Assert
552547
assert isinstance(loaded_instance, BaseSingleTableSynthesizer)
553-
assert loaded_instance.metadata.columns == {}
554-
assert loaded_instance.metadata.primary_key is None
555-
assert loaded_instance.metadata.alternate_keys == []
556-
assert loaded_instance.metadata.sequence_key is None
557-
assert loaded_instance.metadata.sequence_index is None
558-
assert loaded_instance.metadata._version == 'SINGLE_TABLE_V1'
548+
549+
assert loaded_instance.metadata.tables == {}
559550
assert hasattr(instance, '_synthesizer_id') is False
560551
assert hasattr(loaded_instance, '_synthesizer_id') is True
561552
assert isinstance(loaded_instance._synthesizer_id, str) is True
@@ -718,10 +709,10 @@ def test_metadata_updated_warning(method, kwargs):
718709
single_metadata = metadata._convert_to_single_table()
719710
single_metadata.__getattribute__(method)(**kwargs)
720711
with pytest.warns(UserWarning, match=expected_message):
721-
BaseSingleTableSynthesizer(single_metadata)
712+
instance = BaseSingleTableSynthesizer(single_metadata)
722713

723714
# Assert
724-
assert single_metadata._updated is False
715+
assert instance.metadata.tables['table']._updated is False
725716

726717

727718
def test_fit_raises_version_error():

tests/integration/single_table/test_copulas.py

+2-2
Original file line number | Diff line number | Diff line change
@@ -106,7 +106,7 @@ def test_synthesize_table_gaussian_copula(tmp_path):
106106
loaded_synthesizer = GaussianCopulaSynthesizer.load(model_path)
107107
assert isinstance(synthesizer, GaussianCopulaSynthesizer)
108108
assert loaded_synthesizer.get_info() == synthesizer.get_info()
109-
assert loaded_synthesizer.metadata.to_dict() == metadata._convert_to_single_table().to_dict()
109+
assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict()
110110
loaded_synthesizer.sample(20)
111111

112112
# Assert - custom synthesizer
@@ -192,7 +192,7 @@ def test_adding_constraints(tmp_path):
192192

193193
assert isinstance(loaded_synthesizer, GaussianCopulaSynthesizer)
194194
assert loaded_synthesizer.get_info() == synthesizer.get_info()
195-
assert loaded_synthesizer.metadata.to_dict() == metadata._convert_to_single_table().to_dict()
195+
assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict()
196196
sampled_data = loaded_synthesizer.sample(100)
197197
validation = sampled_data[sampled_data['has_rewards']]
198198
assert validation['amenities_fee'].sum() == 0.0

tests/integration/single_table/test_ctgan.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -114,7 +114,7 @@ def test_synthesize_table_ctgan(tmp_path):
114114
loaded_synthesizer = CTGANSynthesizer.load(model_path)
115115
assert isinstance(synthesizer, CTGANSynthesizer)
116116
assert loaded_synthesizer.get_info() == synthesizer.get_info()
117-
assert loaded_synthesizer.metadata.to_dict() == metadata._convert_to_single_table().to_dict()
117+
assert loaded_synthesizer.metadata.to_dict() == metadata.to_dict()
118118
loaded_synthesizer.sample(20)
119119

120120
# Assert - custom synthesizer

0 commit comments

Comments
 (0)