Skip to content

Commit c123a8d

Browse files
committed
Address comments
1 parent 27a2ca4 commit c123a8d

File tree

2 files changed

+42
-7
lines changed

2 files changed

+42
-7
lines changed

sdv/sequential/par.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -545,15 +545,28 @@ def sample_sequential_columns(self, context_columns, sequence_length=None):
545545
set(context_columns.columns), set(self._context_synthesizer._model.columns)
546546
)
547547
)
548-
549-
datetime_columns = self._get_context_datetime_columns()
550-
if datetime_columns:
551-
context_columns[datetime_columns] = self._data_processor.transform(
552-
context_columns[datetime_columns]
553-
)
548+
context_columns = self._process_datetime_columns_in_context_columns(context_columns)
554549

555550
condition_columns = context_columns[condition_columns].to_dict('records')
556551
synthesizer_conditions = [Condition(conditions) for conditions in condition_columns]
557552
context = self._context_synthesizer.sample_from_conditions(synthesizer_conditions)
558553
context.update(context_columns)
559554
return self._sample(context, sequence_length)
555+
556+
def _process_datetime_columns_in_context_columns(self, context_columns):
557+
"""Process datetime columns by transforming them using the data processor.
558+
559+
Args:
560+
context_columns (pandas.DataFrame):
561+
Context values containing potential datetime columns.
562+
563+
Returns:
564+
context_columns (pandas.DataFrame):
565+
Updated context columns with transformed datetime values.
566+
"""
567+
datetime_columns = self._get_context_datetime_columns()
568+
if datetime_columns:
569+
transformed = self._data_processor.transform(context_columns[datetime_columns])
570+
context_columns[datetime_columns] = transformed[datetime_columns]
571+
572+
return context_columns

tests/unit/sequential/test_par.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,7 @@ def test___init__with_unified_metadata(self):
10871087
PARSynthesizer(multi_metadata)
10881088

10891089
def test_sample_sequential_columns_with_datetime_values(self):
1090-
"""Test that the method uses converts datetime values to numerical space before sampling."""
1090+
"""Test that the method converts datetime values to numerical space before sampling."""
10911091
# Setup
10921092
par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['time'])
10931093
data = self.get_data()
@@ -1126,3 +1126,25 @@ def test_sample_sequential_columns_with_datetime_values(self):
11261126
for arg, expected in zip(call_args[0], expected_conditions):
11271127
assert arg.column_values == expected.column_values
11281128
assert arg.num_rows == expected.num_rows
1129+
1130+
def test__process_datetime_columns_in_context_columns(self):
1131+
"""Test that the method converts datetime columns into numerical space."""
1132+
# Setup
1133+
instance = Mock()
1134+
instance._get_context_datetime_columns.return_value = ['Date']
1135+
instance._data_processor.transform.return_value = pd.DataFrame({'datetime_col': [1, 2, 3]})
1136+
instance._get_context_datetime_columns.return_value = ['datetime_col']
1137+
1138+
context_columns = pd.DataFrame({
1139+
'datetime_col': ['2021-01-01', '2022-01-01', '2023-01-01'],
1140+
'col2': [4, 5, 6],
1141+
})
1142+
1143+
# Run
1144+
result = PARSynthesizer._process_datetime_columns_in_context_columns(
1145+
instance, context_columns
1146+
)
1147+
1148+
# Assert
1149+
expected_result = pd.DataFrame({'datetime_col': [1, 2, 3], 'col2': [4, 5, 6]})
1150+
pd.testing.assert_frame_equal(result, expected_result)

0 commit comments

Comments
 (0)