Skip to content

Commit

Permalink
Update RDT version to 1.0 (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
pvk-developer authored Jul 8, 2022
1 parent 715857e commit 4a37508
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 40 deletions.
14 changes: 7 additions & 7 deletions ctgan/data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np
import pandas as pd
from rdt.transformers import BayesGMMTransformer, OneHotEncodingTransformer
from rdt.transformers import ClusterBasedNormalizer, OneHotEncoder

SpanInfo = namedtuple('SpanInfo', ['dim', 'activation_fn'])
ColumnTransformInfo = namedtuple(
Expand Down Expand Up @@ -45,7 +45,7 @@ def _fit_continuous(self, data):
A ``ColumnTransformInfo`` object.
"""
column_name = data.columns[0]
gm = BayesGMMTransformer(max_clusters=min(len(data), 10))
gm = ClusterBasedNormalizer(model_missing_values=True, max_clusters=min(len(data), 10))
gm.fit(data, [column_name])
num_components = sum(gm.valid_component_indicator)

Expand All @@ -66,7 +66,7 @@ def _fit_discrete(self, data):
A ``ColumnTransformInfo`` object.
"""
column_name = data.columns[0]
ohe = OneHotEncodingTransformer()
ohe = OneHotEncoder()
ohe.fit(data, [column_name])
num_categories = len(ohe.dummies)

Expand All @@ -78,8 +78,8 @@ def _fit_discrete(self, data):
def fit(self, raw_data, discrete_columns=()):
"""Fit the ``DataTransformer``.
Fits a ``BayesGMMTransformer`` for continuous columns and a
``OneHotEncodingTransformer`` for discrete columns.
Fits a ``ClusterBasedNormalizer`` for continuous columns and a
``OneHotEncoder`` for discrete columns.
This step also counts the #columns in matrix data and span information.
"""
Expand Down Expand Up @@ -145,7 +145,7 @@ def transform(self, raw_data):

def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st):
gm = column_transform_info.transform
data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_types()))
data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_sdtypes()))
data.iloc[:, 1] = np.argmax(column_data[:, 1:], axis=1)
if sigmas is not None:
selected_normalized_value = np.random.normal(data.iloc[:, 0], sigmas[st])
Expand All @@ -155,7 +155,7 @@ def _inverse_transform_continuous(self, column_transform_info, column_data, sigm

def _inverse_transform_discrete(self, column_transform_info, column_data):
ohe = column_transform_info.transform
data = pd.DataFrame(column_data, columns=list(ohe.get_output_types()))
data = pd.DataFrame(column_data, columns=list(ohe.get_output_sdtypes()))
return ohe.reverse_transform(data)[column_transform_info.column_name]

def inverse_transform(self, data, sigmas=None):
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
'scikit-learn>=0.24,<2',
'torch>=1.8.0,<2',
'torchvision>=0.9.0,<1',
'rdt>=0.6.2,<0.7',
'rdt>=1.1.0,<2.0',
]

setup_requires = [
Expand All @@ -30,6 +30,7 @@
'pytest>=3.4.2',
'pytest-rerunfailures>=9.1.1,<10',
'pytest-cov>=2.6.0',
'rundoc>=0.4.3,<0.5',
]

development_requires = [
Expand Down
64 changes: 32 additions & 32 deletions tests/unit/test_data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@

class TestDataTransformer(TestCase):

@patch('ctgan.data_transformer.BayesGMMTransformer')
def test___fit_continuous(self, MockBGM):
@patch('ctgan.data_transformer.ClusterBasedNormalizer')
def test___fit_continuous(self, MockCBN):
"""Test ``_fit_continuous`` on a simple continuous column.
A ``BayesGMMTransformer`` will be created and fit with some ``data``.
A ``ClusterBasedNormalizer`` will be created and fit with some ``data``.
Setup:
- Mock the ``BayesGMMTransformer`` with ``valid_component_indicator`` as
- Mock the ``ClusterBasedNormalizer`` with ``valid_component_indicator`` as
``[True, False, True]``.
- Initialize a ``DataTransformer``.
Expand All @@ -28,16 +28,16 @@ def test___fit_continuous(self, MockBGM):
Output:
- A ``ColumnTransformInfo`` object where:
- ``column_name`` matches the column of the data.
- ``transform`` is the ``BayesGMMTransformer`` instance.
- ``transform`` is the ``ClusterBasedNormalizer`` instance.
- ``output_dimensions`` is 3 (matches size of ``valid_component_indicator``).
- ``output_info`` assigns the correct activation functions.
Side Effects:
- ``fit`` should be called with the data.
"""
# Setup
bgm_instance = MockBGM.return_value
bgm_instance.valid_component_indicator = [True, False, True]
cbn_instance = MockCBN.return_value
cbn_instance.valid_component_indicator = [True, False, True]
transformer = DataTransformer()
data = pd.DataFrame(np.random.normal((100, 1)), columns=['column'])

Expand All @@ -46,25 +46,25 @@ def test___fit_continuous(self, MockBGM):

# Assert
assert info.column_name == 'column'
assert info.transform == bgm_instance
assert info.transform == cbn_instance
assert info.output_dimensions == 3
assert info.output_info[0].dim == 1
assert info.output_info[0].activation_fn == 'tanh'
assert info.output_info[1].dim == 2
assert info.output_info[1].activation_fn == 'softmax'

@patch('ctgan.data_transformer.BayesGMMTransformer')
def test__fit_continuous_max_clusters(self, MockBGM):
@patch('ctgan.data_transformer.ClusterBasedNormalizer')
def test__fit_continuous_max_clusters(self, MockCBN):
"""Test ``_fit_continuous`` with data that has less than 10 rows.
Expect that a ``BayesGMMTransformer`` is created with the max number of clusters
Expect that a ``ClusterBasedNormalizer`` is created with the max number of clusters
set to the length of the data.
Input:
- Data with less than 10 rows.
Side Effects:
- A ``BayesGMMTransformer`` is created with the max number of clusters set to the
- A ``ClusterBasedNormalizer`` is created with the max number of clusters set to the
length of the data.
"""
# Setup
Expand All @@ -75,16 +75,16 @@ def test__fit_continuous_max_clusters(self, MockBGM):
transformer._fit_continuous(data)

# Assert
MockBGM.assert_called_once_with(max_clusters=len(data))
MockCBN.assert_called_once_with(model_missing_values=True, max_clusters=len(data))

@patch('ctgan.data_transformer.OneHotEncodingTransformer')
@patch('ctgan.data_transformer.OneHotEncoder')
def test___fit_discrete(self, MockOHE):
"""Test ``_fit_discrete_`` on a simple discrete column.
A ``OneHotEncodingTransformer`` will be created and fit with the ``data``.
A ``OneHotEncoder`` will be created and fit with the ``data``.
Setup:
- Mock the ``OneHotEncodingTransformer``.
- Mock the ``OneHotEncoder``.
- Create ``DataTransformer``.
Input:
Expand All @@ -93,7 +93,7 @@ def test___fit_discrete(self, MockOHE):
Output:
- A ``ColumnTransformInfo`` object where:
- ``column_name`` matches the column of the data.
- ``transform`` is the ``OneHotEncodingTransformer`` instance.
- ``transform`` is the ``OneHotEncoder`` instance.
- ``output_dimensions`` is 2.
- ``output_info`` assigns the correct activation function.
Expand Down Expand Up @@ -172,12 +172,12 @@ def test_fit(self):
transformer._fit_continuous.assert_called_once()
assert transformer.output_dimensions == 6

@patch('ctgan.data_transformer.BayesGMMTransformer')
def test__transform_continuous(self, MockBGM):
@patch('ctgan.data_transformer.ClusterBasedNormalizer')
def test__transform_continuous(self, MockCBN):
"""Test ``_transform_continuous``.
Setup:
- Mock the ``BayesGMMTransformer`` with the transform method returning
- Mock the ``ClusterBasedNormalizer`` with the transform method returning
some dataframe.
- Create ``DataTransformer``.
Expand All @@ -191,16 +191,16 @@ def test__transform_continuous(self, MockBGM):
representation of the component part of the mocked transform.
"""
# Setup
bgm_instance = MockBGM.return_value
bgm_instance.transform.return_value = pd.DataFrame({
cbn_instance = MockCBN.return_value
cbn_instance.transform.return_value = pd.DataFrame({
'x.normalized': [0.1, 0.2, 0.3],
'x.component': [0.0, 1.0, 1.0]
})

transformer = DataTransformer()
data = pd.DataFrame({'x': np.array([0.1, 0.3, 0.5])})
column_transform_info = ColumnTransformInfo(
column_name='x', column_type='continuous', transform=bgm_instance,
column_name='x', column_type='continuous', transform=cbn_instance,
output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')],
output_dimensions=1 + 3
)
Expand Down Expand Up @@ -291,14 +291,14 @@ def test_transform(self):
assert (result[:, 1:4] == expected[:, 1:4]).all(), 'continuous-softmax'
assert (result[:, 4:6] == expected[:, 4:6]).all(), 'discrete'

@patch('ctgan.data_transformer.BayesGMMTransformer')
def test__inverse_transform_continuous(self, MockBGM):
@patch('ctgan.data_transformer.ClusterBasedNormalizer')
def test__inverse_transform_continuous(self, MockCBN):
"""Test ``_inverse_transform_continuous``.
Setup:
- Create ``DataTransformer``.
- Mock the ``BayesGMMTransformer`` where:
- ``get_output_types`` returns the appropriate dictionary.
- Mock the ``ClusterBasedNormalizer`` where:
- ``get_output_sdtypes`` returns the appropriate dictionary.
- ``reverse_transform`` returns some dataframe.
Input:
Expand All @@ -317,13 +317,13 @@ def test__inverse_transform_continuous(self, MockBGM):
where the first column are floats and the second is a lable encoding.
"""
# Setup
bgm_instance = MockBGM.return_value
bgm_instance.get_output_types.return_value = {
cbn_instance = MockCBN.return_value
cbn_instance.get_output_sdtypes.return_value = {
'x.normalized': 'numerical',
'x.component': 'numerical'
}

bgm_instance.reverse_transform.return_value = pd.DataFrame({
cbn_instance.reverse_transform.return_value = pd.DataFrame({
'x.normalized': [0.1, 0.2, 0.3],
'x.component': [0.0, 1.0, 1.0]
})
Expand All @@ -336,7 +336,7 @@ def test__inverse_transform_continuous(self, MockBGM):
])

column_transform_info = ColumnTransformInfo(
column_name='x', column_type='continuous', transform=bgm_instance,
column_name='x', column_type='continuous', transform=cbn_instance,
output_info=[SpanInfo(1, 'tanh'), SpanInfo(3, 'softmax')],
output_dimensions=1 + 3
)
Expand All @@ -359,7 +359,7 @@ def test__inverse_transform_continuous(self, MockBGM):
})

pd.testing.assert_frame_equal(
bgm_instance.reverse_transform.call_args[0][0],
cbn_instance.reverse_transform.call_args[0][0],
expected_data
)

Expand Down

0 comments on commit 4a37508

Please sign in to comment.