Skip to content

Commit 794a95c

Browse files
committed
Make functions more generalized
1 parent c123a8d commit 794a95c

File tree

2 files changed

+54
-30
lines changed

2 files changed

+54
-30
lines changed

sdv/sequential/par.py

Lines changed: 42 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import logging
55
import uuid
66
import warnings
7+
from copy import deepcopy
78

89
import numpy as np
910
import pandas as pd
@@ -74,25 +75,44 @@ def _get_context_metadata(self):
7475
if self._sequence_key:
7576
context_columns += self._sequence_key
7677

77-
for column in context_columns:
78-
context_columns_dict[column] = self.metadata.columns[column]
79-
# Context datetime SDTypes for PAR have already been converted to float timestamp
80-
if context_columns_dict[column]['sdtype'] == 'datetime':
81-
context_columns_dict[column] = {'sdtype': 'numerical'}
82-
8378
for column, column_metadata in self._extra_context_columns.items():
8479
context_columns_dict[column] = column_metadata
8580

81+
for column in context_columns:
82+
context_columns_dict[column] = self.metadata.columns[column]
83+
84+
context_columns_dict = self._update_context_column_dict(context_columns_dict)
8685
context_metadata_dict = {'columns': context_columns_dict}
8786
return SingleTableMetadata.load_from_dict(context_metadata_dict)
8887

89-
def _get_context_datetime_columns(self):
90-
datetime_columns = []
88+
def _update_context_column_dict(self, context_columns_dict):
89+
"""Update context column dictionary based on available transformers.
90+
91+
Args:
92+
context_columns_dict (dict):
93+
Dictionary of context columns.
94+
95+
Returns:
96+
dict:
97+
Updated context column metadata.
98+
"""
99+
default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype)
91100
for column in self.context_columns:
92-
if self.metadata.columns[column]['sdtype'] == 'datetime':
93-
datetime_columns.append(column)
101+
column_metadata = self.metadata.columns[column]
102+
sdtype = column_metadata['sdtype']
103+
if default_transformers_by_sdtype.get(column_metadata['sdtype']):
104+
context_columns_dict[column] = {'sdtype': 'numerical'}
94105

95-
return datetime_columns
106+
return context_columns_dict
107+
108+
def _get_context_columns_for_processing(self):
109+
columns_to_be_processed = []
110+
default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype)
111+
for column in self.context_columns:
112+
if default_transformers_by_sdtype.get(self.metadata.columns[column]['sdtype']):
113+
columns_to_be_processed.append(column)
114+
115+
return columns_to_be_processed
96116

97117
def __init__(
98118
self,
@@ -545,28 +565,30 @@ def sample_sequential_columns(self, context_columns, sequence_length=None):
545565
set(context_columns.columns), set(self._context_synthesizer._model.columns)
546566
)
547567
)
548-
context_columns = self._process_datetime_columns_in_context_columns(context_columns)
568+
context_columns = self._process_context_columns(context_columns)
549569

550570
condition_columns = context_columns[condition_columns].to_dict('records')
551571
synthesizer_conditions = [Condition(conditions) for conditions in condition_columns]
552572
context = self._context_synthesizer.sample_from_conditions(synthesizer_conditions)
553573
context.update(context_columns)
554574
return self._sample(context, sequence_length)
555575

556-
def _process_datetime_columns_in_context_columns(self, context_columns):
557-
"""Process datetime columns by transforming them using the data processor.
576+
def _process_context_columns(self, context_columns):
577+
"""Process context columns by applying appropriate transformations.
558578
559579
Args:
560580
context_columns (pandas.DataFrame):
561-
Context values containing potential datetime columns.
581+
Context values containing potential columns for transformation.
562582
563583
Returns:
564584
context_columns (pandas.DataFrame):
565-
Updated context columns with transformed datetime values.
585+
Updated context columns with transformed values.
566586
"""
567-
datetime_columns = self._get_context_datetime_columns()
568-
if datetime_columns:
569-
transformed = self._data_processor.transform(context_columns[datetime_columns])
570-
context_columns[datetime_columns] = transformed[datetime_columns]
587+
columns_to_be_processed = self._get_context_columns_for_processing()
588+
589+
if columns_to_be_processed:
590+
context_columns[columns_to_be_processed] = self._data_processor.transform(
591+
context_columns[columns_to_be_processed]
592+
)
571593

572594
return context_columns

tests/unit/sequential/test_par.py

Lines changed: 12 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -934,7 +934,8 @@ def test_sample_sequential_columns(self):
934934
"""Test that the method uses the provided context columns to sample."""
935935
# Setup
936936
par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['gender'])
937-
par._get_context_datetime_columns = Mock(return_value=None)
937+
par._process_context_columns = Mock()
938+
par._process_context_columns.side_effect = lambda value: value
938939
par._context_synthesizer = Mock()
939940
par._context_synthesizer._model.columns = ['gender', 'extra_col']
940941
par._context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
@@ -971,7 +972,7 @@ def test_sample_sequential_columns(self):
971972
call_args, _ = par._sample.call_args
972973
pd.testing.assert_frame_equal(call_args[0], expected_call_arg)
973974
assert call_args[1] == 5
974-
par._get_context_datetime_columns.assert_called_once_with()
975+
par._process_context_columns.assert_called_once_with(context_columns)
975976

976977
def test_sample_sequential_columns_no_context_columns(self):
977978
"""Test that the method raises an error if the synthesizer has no context columns.
@@ -1127,24 +1128,25 @@ def test_sample_sequential_columns_with_datetime_values(self):
11271128
assert arg.column_values == expected.column_values
11281129
assert arg.num_rows == expected.num_rows
11291130

1130-
def test__process_datetime_columns_in_context_columns(self):
1131-
"""Test that the method converts datetime columns into numerical space."""
1131+
def test__process_context_columns(self):
1132+
"""Test that the method processes specified columns using appropriate transformations."""
11321133
# Setup
11331134
instance = Mock()
1134-
instance._get_context_datetime_columns.return_value = ['Date']
1135+
instance._get_context_columns_for_processing.return_value = ['datetime_col']
11351136
instance._data_processor.transform.return_value = pd.DataFrame({'datetime_col': [1, 2, 3]})
1136-
instance._get_context_datetime_columns.return_value = ['datetime_col']
11371137

11381138
context_columns = pd.DataFrame({
11391139
'datetime_col': ['2021-01-01', '2022-01-01', '2023-01-01'],
11401140
'col2': [4, 5, 6],
11411141
})
11421142

1143+
expected_result = pd.DataFrame({
1144+
'datetime_col': [1, 2, 3],
1145+
'col2': [4, 5, 6],
1146+
})
1147+
11431148
# Run
1144-
result = PARSynthesizer._process_datetime_columns_in_context_columns(
1145-
instance, context_columns
1146-
)
1149+
result = PARSynthesizer._process_context_columns(instance, context_columns)
11471150

11481151
# Assert
1149-
expected_result = pd.DataFrame({'datetime_col': [1, 2, 3], 'col2': [4, 5, 6]})
11501152
pd.testing.assert_frame_equal(result, expected_result)

0 commit comments

Comments
 (0)