Skip to content

Commit ca7892c

Browse files
Fix: PARSynthesizer not being able to conditionally sample with datetime as context (#2347)
1 parent 0ddaf6a commit ca7892c

File tree

3 files changed

+133
-14
lines changed

3 files changed

+133
-14
lines changed

sdv/sequential/par.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import logging
55
import uuid
66
import warnings
7+
from copy import deepcopy
78

89
import numpy as np
910
import pandas as pd
@@ -74,15 +75,44 @@ def _get_context_metadata(self):
7475
if self._sequence_key:
7576
context_columns += self._sequence_key
7677

77-
for column in context_columns:
78-
context_columns_dict[column] = self.metadata.columns[column]
79-
8078
for column, column_metadata in self._extra_context_columns.items():
8179
context_columns_dict[column] = column_metadata
8280

81+
for column in context_columns:
82+
context_columns_dict[column] = self.metadata.columns[column]
83+
84+
context_columns_dict = self._update_context_column_dict(context_columns_dict)
8385
context_metadata_dict = {'columns': context_columns_dict}
8486
return SingleTableMetadata.load_from_dict(context_metadata_dict)
8587

88+
def _update_context_column_dict(self, context_columns_dict):
89+
"""Update context column dictionary based on available transformers.
90+
91+
Args:
92+
context_columns_dict (dict):
93+
Dictionary of context columns.
94+
95+
Returns:
96+
dict:
97+
Updated context column metadata.
98+
"""
99+
default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype)
100+
for column in self.context_columns:
101+
column_metadata = self.metadata.columns[column]
102+
if default_transformers_by_sdtype.get(column_metadata['sdtype']):
103+
context_columns_dict[column] = {'sdtype': 'numerical'}
104+
105+
return context_columns_dict
106+
107+
def _get_context_columns_for_processing(self):
108+
columns_to_be_processed = []
109+
default_transformers_by_sdtype = deepcopy(self._data_processor._transformers_by_sdtype)
110+
for column in self.context_columns:
111+
if default_transformers_by_sdtype.get(self.metadata.columns[column]['sdtype']):
112+
columns_to_be_processed.append(column)
113+
114+
return columns_to_be_processed
115+
86116
def __init__(
87117
self,
88118
metadata,
@@ -352,12 +382,6 @@ def _fit_context_model(self, transformed):
352382
context[constant_column] = 0
353383
context_metadata.add_column(constant_column, sdtype='numerical')
354384

355-
for column in self.context_columns:
356-
# Context datetime SDTypes for PAR have already been converted to float timestamp
357-
if context_metadata.columns[column]['sdtype'] == 'datetime':
358-
if pd.api.types.is_numeric_dtype(context[column]):
359-
context_metadata.update_column(column, sdtype='numerical')
360-
361385
with warnings.catch_warnings():
362386
warnings.filterwarnings('ignore', message=".*The 'SingleTableMetadata' is deprecated.*")
363387
self._context_synthesizer = GaussianCopulaSynthesizer(
@@ -540,9 +564,30 @@ def sample_sequential_columns(self, context_columns, sequence_length=None):
540564
set(context_columns.columns), set(self._context_synthesizer._model.columns)
541565
)
542566
)
567+
context_columns = self._process_context_columns(context_columns)
568+
543569
condition_columns = context_columns[condition_columns].to_dict('records')
544-
context = self._context_synthesizer.sample_from_conditions([
545-
Condition(conditions) for conditions in condition_columns
546-
])
570+
synthesizer_conditions = [Condition(conditions) for conditions in condition_columns]
571+
context = self._context_synthesizer.sample_from_conditions(synthesizer_conditions)
547572
context.update(context_columns)
548573
return self._sample(context, sequence_length)
574+
575+
def _process_context_columns(self, context_columns):
576+
"""Process context columns by applying appropriate transformations.
577+
578+
Args:
579+
context_columns (pandas.DataFrame):
580+
Context values containing potential columns for transformation.
581+
582+
Returns:
583+
context_columns (pandas.DataFrame):
584+
Updated context columns with transformed values.
585+
"""
586+
columns_to_be_processed = self._get_context_columns_for_processing()
587+
588+
if columns_to_be_processed:
589+
context_columns[columns_to_be_processed] = self._data_processor.transform(
590+
context_columns[columns_to_be_processed]
591+
)
592+
593+
return context_columns

tests/integration/sequential/test_par.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def _get_par_data_and_metadata():
2020
'column2': ['b', 'a', 'a', 'c'],
2121
'entity': [1, 1, 2, 2],
2222
'context': ['a', 'a', 'b', 'b'],
23+
'context_date': [date, date, date, date],
2324
})
2425
metadata = Metadata.detect_from_dataframes({'table': data})
2526
metadata.update_column('entity', 'table', sdtype='id')
@@ -94,15 +95,21 @@ def test_column_after_date_complex():
9495
data, metadata = _get_par_data_and_metadata()
9596

9697
# Run
97-
model = PARSynthesizer(metadata=metadata, context_columns=['context'], epochs=1)
98+
model = PARSynthesizer(metadata=metadata, context_columns=['context', 'context_date'], epochs=1)
9899
model.fit(data)
99100
sampled = model.sample(2)
101+
context_columns = data[['context', 'context_date']]
102+
sample_with_conditions = model.sample_sequential_columns(context_columns=context_columns)
100103

101104
# Assert
102105
assert sampled.shape == data.shape
103106
assert (sampled.dtypes == data.dtypes).all()
104107
assert (sampled.notna().sum(axis=1) != 0).all()
105108

109+
expected_date = datetime.datetime.strptime('2020-01-01', '%Y-%m-%d')
110+
assert all(sample_with_conditions['context_date'] == expected_date)
111+
assert all(sample_with_conditions['context'].isin(['a', 'b']))
112+
106113

107114
def test_save_and_load(tmp_path):
108115
"""Test that synthesizers can be saved and loaded properly."""

tests/unit/sequential/test_par.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def test__fit_context_model_with_datetime_context_column(self, gaussian_copula_m
487487
par = PARSynthesizer(metadata, context_columns=['time'])
488488
initial_synthesizer = Mock()
489489
context_metadata = SingleTableMetadata.load_from_dict({
490-
'columns': {'time': {'sdtype': 'datetime'}, 'name': {'sdtype': 'id'}}
490+
'columns': {'time': {'sdtype': 'numerical'}, 'name': {'sdtype': 'id'}}
491491
})
492492
par._context_synthesizer = initial_synthesizer
493493
par._get_context_metadata = Mock()
@@ -934,6 +934,8 @@ def test_sample_sequential_columns(self):
934934
"""Test that the method uses the provided context columns to sample."""
935935
# Setup
936936
par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['gender'])
937+
par._process_context_columns = Mock()
938+
par._process_context_columns.side_effect = lambda value: value
937939
par._context_synthesizer = Mock()
938940
par._context_synthesizer._model.columns = ['gender', 'extra_col']
939941
par._context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
@@ -970,6 +972,7 @@ def test_sample_sequential_columns(self):
970972
call_args, _ = par._sample.call_args
971973
pd.testing.assert_frame_equal(call_args[0], expected_call_arg)
972974
assert call_args[1] == 5
975+
par._process_context_columns.assert_called_once_with(context_columns)
973976

974977
def test_sample_sequential_columns_no_context_columns(self):
975978
"""Test that the method raises an error if the synthesizer has no context columns.
@@ -1083,3 +1086,67 @@ def test___init__with_unified_metadata(self):
10831086

10841087
with pytest.raises(InvalidMetadataError, match=error_msg):
10851088
PARSynthesizer(multi_metadata)
1089+
1090+
def test_sample_sequential_columns_with_datetime_values(self):
1091+
"""Test that the method converts datetime values to numerical space before sampling."""
1092+
# Setup
1093+
par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['time'])
1094+
data = self.get_data()
1095+
par.fit(data)
1096+
1097+
par._context_synthesizer = Mock()
1098+
par._context_synthesizer._model.columns = ['time', 'extra_col']
1099+
par._context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
1100+
'id_col': ['A', 'A', 'A'],
1101+
'time': ['2020-01-01', '2020-01-02', '2020-01-03'],
1102+
'extra_col': [0, 1, 1],
1103+
})
1104+
par._sample = Mock()
1105+
context_columns = pd.DataFrame({
1106+
'id_col': ['ID-1', 'ID-2', 'ID-3'],
1107+
'time': ['2020-01-01', '2020-01-02', '2020-01-03'],
1108+
})
1109+
1110+
# Run
1111+
par.sample_sequential_columns(context_columns, 5)
1112+
1113+
# Assert
1114+
time_values = par._data_processor.transform(
1115+
pd.DataFrame({'time': ['2020-01-01', '2020-01-02', '2020-01-03']})
1116+
)
1117+
1118+
time_values = time_values['time'].tolist()
1119+
expected_conditions = [
1120+
Condition({'time': time_values[0]}),
1121+
Condition({'time': time_values[1]}),
1122+
Condition({'time': time_values[2]}),
1123+
]
1124+
call_args, _ = par._context_synthesizer.sample_from_conditions.call_args
1125+
1126+
assert len(call_args[0]) == len(expected_conditions)
1127+
for arg, expected in zip(call_args[0], expected_conditions):
1128+
assert arg.column_values == expected.column_values
1129+
assert arg.num_rows == expected.num_rows
1130+
1131+
def test__process_context_columns(self):
1132+
"""Test that the method processes specified columns using appropriate transformations."""
1133+
# Setup
1134+
instance = Mock()
1135+
instance._get_context_columns_for_processing.return_value = ['datetime_col']
1136+
instance._data_processor.transform.return_value = pd.DataFrame({'datetime_col': [1, 2, 3]})
1137+
1138+
context_columns = pd.DataFrame({
1139+
'datetime_col': ['2021-01-01', '2022-01-01', '2023-01-01'],
1140+
'col2': [4, 5, 6],
1141+
})
1142+
1143+
expected_result = pd.DataFrame({
1144+
'datetime_col': [1, 2, 3],
1145+
'col2': [4, 5, 6],
1146+
})
1147+
1148+
# Run
1149+
result = PARSynthesizer._process_context_columns(instance, context_columns)
1150+
1151+
# Assert
1152+
pd.testing.assert_frame_equal(result, expected_result)

0 commit comments

Comments
 (0)