Skip to content

Commit 6e18d29

Browse files
authored
Add OneHotEncoding CAG (#2414)
1 parent 6a4cd48 commit 6e18d29

File tree

9 files changed

+372
-37
lines changed

9 files changed

+372
-37
lines changed

sdv/cag/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,12 @@
44
from sdv.cag.fixed_increments import FixedIncrements
55
from sdv.cag.inequality import Inequality
66
from sdv.cag.range import Range
7+
from sdv.cag.one_hot_encoding import OneHotEncoding
78

8-
__all__ = ('FixedCombinations', 'FixedIncrements', 'Inequality', 'Range')
9+
__all__ = (
10+
'FixedCombinations',
11+
'FixedIncrements',
12+
'Inequality',
13+
'Range',
14+
'OneHotEncoding',
15+
)

sdv/cag/_utils.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pandas as pd
23

34
from sdv.cag._errors import PatternNotMetError
45
from sdv.metadata import Metadata
@@ -93,9 +94,9 @@ def _remove_columns_from_metadata(metadata, table_name, columns_to_drop):
9394
return Metadata.load_from_dict(metadata)
9495

9596

96-
def _is_list_of_strings(values):
97-
"""Checks that a list contains all strings."""
98-
return isinstance(values, list) and all(isinstance(value, str) for value in values)
97+
def _is_list_of_type(values, type_to_check=str):
98+
"""Checks that 'values' is a list and all elements are of type 'type_to_check'."""
99+
return isinstance(values, list) and all(isinstance(value, type_to_check) for value in values)
99100

100101

101102
def _get_invalid_rows(valid):
@@ -117,3 +118,26 @@ def _get_invalid_rows(valid):
117118
remaining = len(invalid_rows) - 5
118119
invalid_rows_str = f'{first_five}, +{remaining} more'
119120
return invalid_rows_str
121+
122+
123+
def _get_is_valid_dict(data, table_name):
124+
"""Create a dictionary of True values for each table besides table_name.
125+
126+
Besides table_name, all rows of every other table are considered valid,
127+
so the boolean Series will be True for all rows of every other table.
128+
129+
Args:
130+
data (dict):
131+
The data.
132+
table_name (str):
133+
The name of the table to exclude from the dictionary.
134+
135+
Returns:
136+
dict:
137+
Dictionary of table names to boolean Series of True values.
138+
"""
139+
return {
140+
table: pd.Series(True, index=table_data.index)
141+
for table, table_data in data.items()
142+
if table != table_name
143+
}

sdv/cag/fixed_combinations.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from sdv._utils import _create_unique_name
99
from sdv.cag._errors import PatternNotMetError
1010
from sdv.cag._utils import (
11-
_is_list_of_strings,
11+
_get_is_valid_dict,
12+
_is_list_of_type,
1213
_validate_table_and_column_names,
1314
_validate_table_name_if_defined,
1415
)
@@ -41,7 +42,7 @@ class FixedCombinations(BasePattern):
4142

4243
def __init__(self, column_names, table_name=None):
4344
super().__init__()
44-
if not _is_list_of_strings(column_names):
45+
if not _is_list_of_type(column_names):
4546
raise ValueError('`column_names` must be a list of strings.')
4647

4748
if len(column_names) < 2:
@@ -192,11 +193,7 @@ def _reverse_transform(self, data):
192193
def _is_valid(self, data):
193194
"""Determine whether the data matches the pattern."""
194195
table_name = self._get_single_table_name(self.metadata)
195-
is_valid = {
196-
table: pd.Series(True, index=table_data.index)
197-
for table, table_data in data.items()
198-
if table != table_name
199-
}
196+
is_valid = _get_is_valid_dict(data, table_name)
200197
merged = data[table_name].merge(
201198
self._combinations, how='left', on=self.column_names, indicator=self._joint_column
202199
)

sdv/cag/fixed_increments.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from sdv.cag._errors import PatternNotMetError
77
from sdv.cag._utils import (
88
_get_invalid_rows,
9+
_get_is_valid_dict,
910
_remove_columns_from_metadata,
1011
_validate_table_and_column_names,
1112
_validate_table_name_if_defined,
@@ -173,11 +174,7 @@ def _is_valid(self, data):
173174
table names.
174175
"""
175176
table_name = self._get_single_table_name(self.metadata)
176-
is_valid = {
177-
table: pd.Series(True, index=table_data.index)
178-
for table, table_data in data.items()
179-
if table != table_name
180-
}
177+
is_valid = _get_is_valid_dict(data, table_name)
181178
valid = self._check_if_divisible(data, table_name, self.column_name, self.increment_value)
182179
is_valid[table_name] = valid
183180
return is_valid

sdv/cag/inequality.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
from sdv._utils import _convert_to_timedelta, _create_unique_name
77
from sdv.cag._errors import PatternNotMetError
8-
from sdv.cag._utils import _validate_table_and_column_names
8+
from sdv.cag._utils import (
9+
_get_is_valid_dict,
10+
_validate_table_and_column_names,
11+
)
912
from sdv.cag.base import BasePattern
1013
from sdv.constraints.utils import (
1114
cast_to_datetime64,
@@ -293,12 +296,7 @@ def _is_valid(self, data):
293296
Whether each row is valid.
294297
"""
295298
table_name = self._get_single_table_name(self.metadata)
296-
is_valid = {
297-
table: pd.Series(True, index=table_data.index)
298-
for table, table_data in data.items()
299-
if table != table_name
300-
}
301-
299+
is_valid = _get_is_valid_dict(data, table_name)
302300
table_data = data[table_name]
303301
low, high = self._get_data(table_data)
304302
if self._is_datetime and self._dtype == 'O':

sdv/cag/one_hot_encoding.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""One Hot Encoding CAG pattern."""
2+
3+
import numpy as np
4+
5+
from sdv.cag._errors import PatternNotMetError
6+
from sdv.cag._utils import (
7+
_get_invalid_rows,
8+
_get_is_valid_dict,
9+
_is_list_of_type,
10+
_validate_table_and_column_names,
11+
)
12+
from sdv.cag.base import BasePattern
13+
14+
15+
class OneHotEncoding(BasePattern):
    """Ensure the appropriate columns are one hot encoded.

    This constraint allows the user to specify a list of columns where each row
    is a one hot vector. During the reverse transform, the output of the model
    is transformed so that the column with the largest value is set to 1 while
    all other columns are set to 0.

    Args:
        column_names (list[str]):
            Names of the columns containing one hot rows.
        table_name (str, optional):
            The name of the table that contains the columns. Optional if the
            data is only a single table. Defaults to None.
    """

    @staticmethod
    def _validate_init_inputs(column_names, table_name):
        """Validate the constructor arguments.

        Args:
            column_names (list[str]):
                Names of the columns containing one hot rows.
            table_name (str or None):
                The name of the table that contains the columns.

        Raises:
            ValueError:
                If ``column_names`` is not a list of strings, or if ``table_name``
                is neither a string nor ``None``.
        """
        if not _is_list_of_type(column_names):
            raise ValueError('`column_names` must be a list of strings.')

        # Compare against None explicitly: a plain truthiness check would let
        # falsy non-string values (0, False, [], {}) skip validation entirely.
        if table_name is not None and not isinstance(table_name, str):
            raise ValueError('`table_name` must be a string or None.')

    def __init__(self, column_names, table_name=None):
        super().__init__()
        self._validate_init_inputs(column_names, table_name)
        self._column_names = column_names
        self.table_name = table_name

    def _validate_pattern_with_metadata(self, metadata):
        """Validate the pattern is compatible with the provided metadata.

        Validates that:
            - If no table_name is provided the metadata contains a single table
            - All input columns exist in the table in the metadata.

        Args:
            metadata (sdv.metadata.Metadata):
                The metadata to validate against.

        Raises:
            PatternNotMetError:
                If any of the validations fail.
        """
        _validate_table_and_column_names(self.table_name, self._column_names, metadata)

    def _get_valid_table_data(self, table_data):
        """Flag the rows of ``table_data`` that are valid one hot vectors.

        A row is valid when, across the configured columns, the values sum to 1,
        the largest value is 1, the smallest value is 0 and none are missing.

        Args:
            table_data (pd.DataFrame):
                Table containing the one hot columns.

        Returns:
            pd.Series:
                Boolean Series, aligned with ``table_data``, that is ``True`` for
                the rows meeting every condition above.
        """
        one_hot_data = table_data[self._column_names]

        sum_one = one_hot_data.sum(axis=1) == 1.0
        max_one = one_hot_data.max(axis=1) == 1.0
        min_zero = one_hot_data.min(axis=1) == 0.0
        # Missing values are rejected explicitly instead of relying on how the
        # row-wise aggregations above happen to treat them.
        no_nans = ~one_hot_data.isna().any(axis=1)

        return sum_one & max_one & min_zero & no_nans

    def _validate_pattern_with_data(self, data, metadata):
        """Validate the data is compatible with the pattern.

        Args:
            data (dict[str, pd.DataFrame]):
                Table data.
            metadata (sdv.metadata.Metadata):
                Metadata used to resolve the table the pattern applies to.

        Raises:
            PatternNotMetError:
                If any row of the table is not a valid one hot vector.
        """
        table_name = self._get_single_table_name(metadata)
        valid = self._get_valid_table_data(data[table_name])
        if not valid.all():
            invalid_rows_str = _get_invalid_rows(valid)
            raise PatternNotMetError(
                f'The one hot encoding requirement is not met for row indices: [{invalid_rows_str}]'
            )

    def _fit(self, data, metadata):
        """Fit the pattern.

        Nothing is learned from the data; this method is intentionally a no-op.

        Args:
            data (dict[str, pd.DataFrame]):
                Table data.
            metadata (sdv.metadata.Metadata):
                Metadata.
        """
        pass

    def _transform(self, data):
        """Transform the data.

        The data is returned as received; no transformation is applied.

        Args:
            data (dict[str, pd.DataFrame]):
                Table data.

        Returns:
            dict[str, pd.DataFrame]:
                Transformed data.
        """
        return data

    def _reverse_transform(self, data):
        """Reverse transform the table data.

        Set the column with the largest value to one, set all other columns to zero.
        The one hot columns are overwritten on the table held in ``data``.

        Args:
            data (dict[str, pd.DataFrame]):
                Table data.

        Returns:
            dict[str, pd.DataFrame]:
                Transformed data.
        """
        table_name = self._get_single_table_name(self.metadata)
        table_data = data[table_name]

        # Convert to numpy once and reuse it for both the output buffer and argmax.
        values = table_data[self._column_names].to_numpy()
        one_hot = np.zeros_like(values)

        # argmax keeps the first column on ties, so exactly one entry per row is set.
        max_category_indices = np.argmax(values, axis=1)
        one_hot[np.arange(len(values)), max_category_indices] = 1

        table_data[self._column_names] = one_hot
        data[table_name] = table_data

        return data

    def _is_valid(self, data):
        """Check whether the data satisfies the one-hot constraint.

        Args:
            data (dict[str, pd.DataFrame]):
                Table data.

        Returns:
            dict[str, pd.Series]:
                Whether each row is valid.
        """
        table_name = self._get_single_table_name(self.metadata)
        is_valid = _get_is_valid_dict(data, table_name)
        is_valid[table_name] = self._get_valid_table_data(data[table_name])

        return is_valid

sdv/cag/range.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88
from sdv._utils import _convert_to_timedelta, _create_unique_name
99
from sdv.cag._errors import PatternNotMetError
10-
from sdv.cag._utils import _validate_table_and_column_names
10+
from sdv.cag._utils import (
11+
_get_is_valid_dict,
12+
_validate_table_and_column_names,
13+
)
1114
from sdv.cag.base import BasePattern
1215
from sdv.constraints.utils import (
1316
cast_to_datetime64,
@@ -355,11 +358,7 @@ def _is_valid(self, data):
355358
Whether each row is valid.
356359
"""
357360
table_name = self._get_single_table_name(self.metadata)
358-
is_valid = {
359-
table: pd.Series(True, index=table_data.index)
360-
for table, table_data in data.items()
361-
if table != table_name
362-
}
361+
is_valid = _get_is_valid_dict(data, table_name)
363362
is_valid[table_name] = self._get_valid_table_data(data[table_name])
364363

365364
return is_valid

tests/unit/cag/test__utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from sdv.cag._errors import PatternNotMetError
99
from sdv.cag._utils import (
10-
_is_list_of_strings,
10+
_is_list_of_type,
1111
_remove_columns_from_metadata,
1212
_validate_table_and_column_names,
1313
_validate_table_name_if_defined,
@@ -192,10 +192,10 @@ def test__remove_columns_from_metadata_raises_pk():
192192
)
193193

194194

195-
def test__is_list_of_strings():
196-
"""Test `_is_list_of_strings` method"""
197-
assert _is_list_of_strings(['a', 'b'])
198-
assert not _is_list_of_strings(['a', 1])
199-
assert not _is_list_of_strings([1, 2])
200-
assert not _is_list_of_strings(1)
201-
assert not _is_list_of_strings('a')
195+
def test__is_list_of_type():
    """Test `_is_list_of_type` method"""
    # Default `type_to_check` (str).
    assert _is_list_of_type(['a', 'b'])
    assert not _is_list_of_type(['a', 1])
    assert not _is_list_of_type([1, 2])
    assert not _is_list_of_type(1)
    assert not _is_list_of_type('a')

    # Explicit `type_to_check`, the parameter that generalizes the old
    # `_is_list_of_strings` helper.
    assert _is_list_of_type([1, 2], int)
    assert not _is_list_of_type([1, 'a'], int)
    assert not _is_list_of_type(['a', 'b'], int)
    assert not _is_list_of_type(1, int)

0 commit comments

Comments
 (0)