Skip to content

Commit a199a9f

Browse files
committed
Allow default transformers to have specific arguments
1 parent c1acb42 commit a199a9f

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

Diff for: sdv/data_processing/data_processor.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Single table data processing."""
22

3+
import inspect
34
import json
45
import logging
56
import warnings
@@ -494,6 +495,15 @@ def create_regex_generator(self, column_name, sdtype, column_metadata, is_numeri
494495

495496
return transformer
496497

498+
@staticmethod
499+
def _get_transformer_kwargs(transformer):
500+
args = inspect.getfullargspec(transformer.__init__).args[1:]
501+
return {
502+
key: getattr(transformer, key)
503+
for key in args
504+
if key != 'model_missing_values' and hasattr(transformer, key)
505+
}
506+
497507
def _get_transformer_instance(self, sdtype, column_metadata):
498508
transformer = self._transformers_by_sdtype[sdtype]
499509
if isinstance(transformer, AnonymizedFaker):
@@ -512,7 +522,8 @@ def _get_transformer_instance(self, sdtype, column_metadata):
512522

513523
if kwargs and transformer is not None:
514524
transformer_class = transformer.__class__
515-
return transformer_class(**kwargs)
525+
default_transformer_kwargs = self._get_transformer_kwargs(transformer)
526+
return transformer_class(**{**default_transformer_kwargs, **kwargs})
516527

517528
return deepcopy(transformer)
518529

Diff for: tests/unit/data_processing/test_data_processor.py

+24
Original file line numberDiff line numberDiff line change
@@ -1033,6 +1033,30 @@ def test__get_transformer_instance_kwargs(self):
10331033
assert isinstance(result, FloatFormatter)
10341034
assert result.computer_representation == 'Int32'
10351035

1036+
def test__get_transformer_instance_passes_kwargs_from_default(self):
1037+
"""Test the ``_get_transformer_instance`` uses the default transformers kwargs.
1038+
1039+
Test than when the default transformer has custom kwargs, they are also used
1040+
when creating a new instance of a transformer.
1041+
"""
1042+
# Setup
1043+
dp = DataProcessor(SingleTableMetadata())
1044+
dp._transformers_by_sdtype['numerical'] = FloatFormatter(
1045+
missing_value_replacement='random',
1046+
missing_value_generation='from_column',
1047+
learn_rounding_scheme=False
1048+
)
1049+
1050+
# Run
1051+
result = dp._get_transformer_instance(
1052+
'numerical', {'computer_representation': 'Int32'})
1053+
1054+
# Assert
1055+
assert isinstance(result, FloatFormatter)
1056+
assert result.missing_value_replacement == 'random'
1057+
assert result.missing_value_generation == 'from_column'
1058+
assert result.learn_rounding_scheme is False
1059+
10361060
@patch('sdv.data_processing.data_processor.LOGGER')
10371061
@patch('sdv.data_processing.data_processor.rdt')
10381062
def test__update_constraint_transformers(self, mock_rdt, mock_log):

0 commit comments

Comments
 (0)