@@ -487,7 +487,7 @@ def test__fit_context_model_with_datetime_context_column(self, gaussian_copula_m
487
487
par = PARSynthesizer (metadata , context_columns = ['time' ])
488
488
initial_synthesizer = Mock ()
489
489
context_metadata = SingleTableMetadata .load_from_dict ({
490
- 'columns' : {'time' : {'sdtype' : 'datetime ' }, 'name' : {'sdtype' : 'id' }}
490
+ 'columns' : {'time' : {'sdtype' : 'numerical ' }, 'name' : {'sdtype' : 'id' }}
491
491
})
492
492
par ._context_synthesizer = initial_synthesizer
493
493
par ._get_context_metadata = Mock ()
@@ -934,6 +934,8 @@ def test_sample_sequential_columns(self):
934
934
"""Test that the method uses the provided context columns to sample."""
935
935
# Setup
936
936
par = PARSynthesizer (metadata = self .get_metadata (), context_columns = ['gender' ])
937
+ par ._process_context_columns = Mock ()
938
+ par ._process_context_columns .side_effect = lambda value : value
937
939
par ._context_synthesizer = Mock ()
938
940
par ._context_synthesizer ._model .columns = ['gender' , 'extra_col' ]
939
941
par ._context_synthesizer .sample_from_conditions .return_value = pd .DataFrame ({
@@ -970,6 +972,7 @@ def test_sample_sequential_columns(self):
970
972
call_args , _ = par ._sample .call_args
971
973
pd .testing .assert_frame_equal (call_args [0 ], expected_call_arg )
972
974
assert call_args [1 ] == 5
975
+ par ._process_context_columns .assert_called_once_with (context_columns )
973
976
974
977
def test_sample_sequential_columns_no_context_columns (self ):
975
978
"""Test that the method raises an error if the synthesizer has no context columns.
@@ -1083,3 +1086,67 @@ def test___init__with_unified_metadata(self):
1083
1086
1084
1087
with pytest .raises (InvalidMetadataError , match = error_msg ):
1085
1088
PARSynthesizer (multi_metadata )
1089
+
1090
def test_sample_sequential_columns_with_datetime_values(self):
    """Check that datetime context values are transformed before conditioning.

    The conditions passed to the context synthesizer are compared against the
    values the fitted data processor produces for the same ``time`` strings.
    """
    # Setup
    dates = ['2020-01-01', '2020-01-02', '2020-01-03']
    par = PARSynthesizer(metadata=self.get_metadata(), context_columns=['time'])
    par.fit(self.get_data())

    # Swap in a mocked context synthesizer only after fitting, so the
    # data processor stays fitted while context sampling is controlled.
    context_synthesizer = Mock()
    context_synthesizer._model.columns = ['time', 'extra_col']
    context_synthesizer.sample_from_conditions.return_value = pd.DataFrame({
        'id_col': ['A', 'A', 'A'],
        'time': list(dates),
        'extra_col': [0, 1, 1],
    })
    par._context_synthesizer = context_synthesizer
    par._sample = Mock()
    context_columns = pd.DataFrame({
        'id_col': ['ID-1', 'ID-2', 'ID-3'],
        'time': list(dates),
    })

    # Run
    par.sample_sequential_columns(context_columns, 5)

    # Assert
    transformed = par._data_processor.transform(pd.DataFrame({'time': list(dates)}))
    transformed_times = transformed['time'].tolist()
    expected_conditions = [
        Condition({'time': transformed_times[0]}),
        Condition({'time': transformed_times[1]}),
        Condition({'time': transformed_times[2]}),
    ]
    positional_args, _ = context_synthesizer.sample_from_conditions.call_args
    received_conditions = positional_args[0]

    assert len(received_conditions) == len(expected_conditions)
    for received, expected in zip(received_conditions, expected_conditions):
        assert received.column_values == expected.column_values
        assert received.num_rows == expected.num_rows
1130
+
1131
def test__process_context_columns(self):
    """Check that only the columns selected for processing are replaced.

    With ``_get_context_columns_for_processing`` mocked to return
    ``datetime_col`` and the data processor mocked to return numeric values
    for it, the result is expected to carry those values while ``col2`` is
    left as passed in.
    """
    # Setup
    untouched_values = [4, 5, 6]
    transformed_values = [1, 2, 3]

    instance = Mock()
    instance._get_context_columns_for_processing.return_value = ['datetime_col']
    instance._data_processor.transform.return_value = pd.DataFrame({
        'datetime_col': list(transformed_values),
    })

    context_columns = pd.DataFrame({
        'datetime_col': ['2021-01-01', '2022-01-01', '2023-01-01'],
        'col2': list(untouched_values),
    })
    expected_result = pd.DataFrame({
        'datetime_col': list(transformed_values),
        'col2': list(untouched_values),
    })

    # Run
    result = PARSynthesizer._process_context_columns(instance, context_columns)

    # Assert
    pd.testing.assert_frame_equal(result, expected_result)
0 commit comments