@@ -99,8 +99,8 @@ def _update_default_transformers(self):
99
99
self ._data_processor ._update_transformers_by_sdtypes (sdtype , transformer )
100
100
101
101
def _check_metadata_updated (self ):
102
- if self .metadata ._updated :
103
- self .metadata ._updated = False
102
+ if self .metadata ._check_updated_flag () :
103
+ self .metadata ._reset_updated_flag ()
104
104
warnings .warn (
105
105
"We strongly recommend saving the metadata using 'save_to_json' for replicability"
106
106
' in future SDV versions.'
@@ -110,22 +110,22 @@ def __init__(
110
110
self , metadata , enforce_min_max_values = True , enforce_rounding = True , locales = ['en_US' ]
111
111
):
112
112
self ._validate_inputs (enforce_min_max_values , enforce_rounding )
113
- self .metadata = metadata
114
- self ._table_name = Metadata .DEFAULT_SINGLE_TABLE_NAME
115
113
if isinstance (metadata , Metadata ):
116
114
self ._table_name = metadata ._get_single_table_name ()
117
- self .metadata = metadata . _convert_to_single_table ()
118
- elif isinstance ( metadata , SingleTableMetadata ) :
115
+ self .metadata = metadata
116
+ else :
119
117
warnings .warn (DEPRECATION_MSG , FutureWarning )
118
+ self ._table_name = Metadata .DEFAULT_SINGLE_TABLE_NAME
119
+ self .metadata = Metadata .load_from_dict (metadata .to_dict (), self ._table_name )
120
+ self .metadata .tables [self ._table_name ]._updated = metadata ._updated
120
121
121
- self ._validate_inputs (enforce_min_max_values , enforce_rounding )
122
122
self .metadata .validate ()
123
123
self ._check_metadata_updated ()
124
124
self .enforce_min_max_values = enforce_min_max_values
125
125
self .enforce_rounding = enforce_rounding
126
126
self .locales = locales
127
127
self ._data_processor = DataProcessor (
128
- metadata = self .metadata ,
128
+ metadata = self .metadata . _convert_to_single_table () ,
129
129
enforce_rounding = self .enforce_rounding ,
130
130
enforce_min_max_values = self .enforce_min_max_values ,
131
131
locales = self .locales ,
@@ -158,7 +158,7 @@ def _validate_metadata(self, data):
158
158
"""Validate that the data follows the metadata."""
159
159
errors = []
160
160
try :
161
- self .metadata .validate_data (data )
161
+ self .metadata .validate_data ({ self . _table_name : data } )
162
162
except InvalidDataError as error :
163
163
errors += error .errors
164
164
@@ -183,10 +183,13 @@ def _validate(self, data):
183
183
"""
184
184
return []
185
185
186
+ def _get_table_metadata (self ):
187
+ return self .metadata .tables .get (self ._table_name , SingleTableMetadata ())
188
+
186
189
def _validate_primary_key (self , data ):
187
- primary_key = self .metadata .primary_key
190
+ primary_key = self ._get_table_metadata () .primary_key
188
191
is_int = primary_key and pd .api .types .is_integer_dtype (data [primary_key ])
189
- regex = self .metadata .columns .get (primary_key , {}).get ('regex_format' )
192
+ regex = self ._get_table_metadata () .columns .get (primary_key , {}).get ('regex_format' )
190
193
if is_int and regex :
191
194
possible_characters = get_possible_chars (regex , 1 )
192
195
if '0' in possible_characters :
@@ -225,8 +228,8 @@ def validate(self, data):
225
228
raise InvalidDataError (synthesizer_errors )
226
229
227
230
def _validate_transformers (self , column_name_to_transformer ):
228
- primary_and_alternate_keys = self .metadata ._get_primary_and_alternate_keys ()
229
- sequence_keys = self .metadata ._get_set_of_sequence_keys ()
231
+ primary_and_alternate_keys = self ._get_table_metadata () ._get_primary_and_alternate_keys ()
232
+ sequence_keys = self ._get_table_metadata () ._get_set_of_sequence_keys ()
230
233
keys = primary_and_alternate_keys | sequence_keys
231
234
for column , transformer in column_name_to_transformer .items ():
232
235
if transformer is None :
@@ -251,8 +254,9 @@ def _warn_quality_and_performance(self, column_name_to_transformer):
251
254
column_name_to_transformer (dict):
252
255
Dict mapping column names to transformers to be used for that column.
253
256
"""
257
+ table_metadata = self ._get_table_metadata ()
254
258
for column in column_name_to_transformer :
255
- sdtype = self . metadata .columns .get (column , {}).get ('sdtype' )
259
+ sdtype = table_metadata .columns .get (column , {}).get ('sdtype' )
256
260
if sdtype in {'categorical' , 'boolean' }:
257
261
warnings .warn (
258
262
f"Replacing the default transformer for column '{ column } ' "
@@ -304,8 +308,10 @@ def get_parameters(self):
304
308
305
309
def get_metadata (self ):
306
310
"""Return the ``Metadata`` for this synthesizer."""
307
- table_name = getattr (self , '_table_name' , None )
308
- return Metadata .load_from_dict (self .metadata .to_dict (), table_name )
311
+ if isinstance (self .metadata , SingleTableMetadata ):
312
+ table_name = getattr (self , '_table_name' , None )
313
+ return Metadata .load_from_dict (self .metadata .to_dict (), table_name )
314
+ return self .metadata
309
315
310
316
def load_custom_constraint_classes (self , filepath , class_names ):
311
317
"""Load a custom constraint class for the current synthesizer.
@@ -387,9 +393,10 @@ def get_transformers(self):
387
393
)
388
394
389
395
# Order the output to match metadata
396
+ table_metadata = self ._get_table_metadata ()
390
397
ordered_field_transformers = {
391
398
column_name : field_transformers .get (column_name )
392
- for column_name in self . metadata .columns
399
+ for column_name in table_metadata .columns
393
400
if column_name in field_transformers
394
401
}
395
402
0 commit comments