Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convered DropOriginalMixin to narwhals #354

Merged
merged 6 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Changed
- fixed issues with all null and nullable-bool column handling in dataframe_init_dispatch
- added NaN error handling to WeightColumnMixin
- narwhalified MeanImputer `#344 https://github.com/lvgig/tubular/issues/344_`
- narwhalified DropOriginalMixin `#352 <https://github.com/lvgig/tubular/issues/352>_`
- narwhalified BaseMappingTransformer `#367 <https://github.com/lvgig/tubular/issues/367>_`
- placeholder
- placeholder
Expand Down
38 changes: 30 additions & 8 deletions tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import polars as pl
import pytest
import sklearn.base as b
import test_aide as ta

from tests.utils import assert_frame_equal_dispatch

Expand Down Expand Up @@ -877,6 +876,11 @@ class DropOriginalTransformMixinTests:
Note this deliberately avoids starting with "Tests" so that the tests are not run on import.
"""

@pytest.mark.parametrize(
"minimal_dataframe_lookup",
["pandas", "polars"],
indirect=["minimal_dataframe_lookup"],
)
def test_original_columns_dropped_when_specified(
self,
initialized_transformers,
Expand All @@ -888,15 +892,24 @@ def test_original_columns_dropped_when_specified(

x = initialized_transformers[self.transformer_name]

# skip polars test if not narwhalified
if not x.polars_compatible and isinstance(df, pl.DataFrame):
return

x.drop_original = True

x.fit(df)

df_transformed = x.transform(df)
remaining_cols = df_transformed.columns.to_numpy()
remaining_cols = df_transformed.columns
for col in x.columns:
assert col not in remaining_cols, "original columns not dropped"

@pytest.mark.parametrize(
"minimal_dataframe_lookup",
["pandas", "polars"],
indirect=["minimal_dataframe_lookup"],
)
def test_original_columns_kept_when_specified(
self,
initialized_transformers,
Expand All @@ -908,15 +921,24 @@ def test_original_columns_kept_when_specified(

x = initialized_transformers[self.transformer_name]

# skip polars test if not narwhalified
if not x.polars_compatible and isinstance(df, pl.DataFrame):
return

x.drop_original = False

x.fit(df)

df_transformed = x.transform(df)
remaining_cols = df_transformed.columns.to_numpy()
remaining_cols = df_transformed.columns
for col in x.columns:
assert col in remaining_cols, "original columns not kept"

@pytest.mark.parametrize(
"minimal_dataframe_lookup",
["pandas", "polars"],
indirect=["minimal_dataframe_lookup"],
)
def test_other_columns_not_modified(
self,
initialized_transformers,
Expand All @@ -928,18 +950,18 @@ def test_other_columns_not_modified(

x = initialized_transformers[self.transformer_name]

# skip polars test if not narwhalified
if not x.polars_compatible and isinstance(df, pl.DataFrame):
return

other_columns = list(set(df.columns) - set(x.columns))
x.drop_original = True

x.fit(df)

df_transformed = x.transform(df)

ta.equality.assert_equal_dispatch(
expected=df[other_columns],
actual=df_transformed[other_columns],
msg=f"{self.transformer_name}.transform has changed other columns unexpectedly",
)
assert_frame_equal_dispatch(df[other_columns], df_transformed[other_columns])


class ColumnsCheckTests:
Expand Down
84 changes: 84 additions & 0 deletions tests/mixins/test_DropOriginalMixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import pytest

from tests.test_data import create_df_1
from tests.utils import assert_frame_equal_dispatch
from tubular.mixins import DropOriginalMixin


class TestSetDropOriginalColumn:
"tests for DropOriginalMixin.set_drop_original_column"

@pytest.mark.parametrize("drop_orginal_column", (0, "a", ["a"], {"a": 10}, None))
def test_drop_column_arg_errors(
self,
drop_orginal_column,
):
"""Test that appropriate errors are throwm for non boolean arg."""

obj = DropOriginalMixin()

with pytest.raises(
TypeError,
match="DropOriginalMixin: drop_original should be bool",
):
obj.set_drop_original_column(drop_original=drop_orginal_column)


class TestDropOriginalColumn:
"tests for DropOriginalMixin.drop_original_column"

@pytest.mark.parametrize("library", ["pandas", "polars"])
@pytest.mark.parametrize("drop_original", [True, False])
def test_drop_original_arg_handling(
self,
library,
drop_original,
):
"""Test transformer drops/keeps original columns when specified/not specified."""

df = create_df_1(library=library)

obj = DropOriginalMixin()

columns = list(df.columns)

df_transformed = obj.drop_original_column(
df,
drop_original=drop_original,
columns=columns,
)

remaining_cols = df_transformed.columns

if drop_original:
for col in columns:
assert col not in remaining_cols, "original columns not dropped"

else:
for col in columns:
assert col in remaining_cols, "original columns not kept"

@pytest.mark.parametrize("library", ["pandas", "polars"])
def test_other_columns_not_modified(
self,
library,
):
"""Test transformer does not modify unspecified columns."""

df = create_df_1(library=library)

obj = DropOriginalMixin()

drop_original = False

columns = ["a"]

df_transformed = obj.drop_original_column(
df,
drop_original=drop_original,
columns=columns,
)

other_columns = list(set(df.columns) - set(columns))

assert_frame_equal_dispatch(df[other_columns], df_transformed[other_columns])
4 changes: 1 addition & 3 deletions tubular/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,9 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
)

# Drop original columns if self.drop_original is True
DropOriginalMixin.drop_original_column(
return DropOriginalMixin.drop_original_column(
self,
X,
self.drop_original,
self.columns,
)

return X
4 changes: 1 addition & 3 deletions tubular/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,9 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
X[self.new_column_name] = X[self.columns[0]] == X[self.columns[1]]

# Drop original columns if self.drop_original is True
DropOriginalMixin.drop_original_column(
return DropOriginalMixin.drop_original_column(
self,
X,
self.drop_original,
self.columns,
)

return X
28 changes: 7 additions & 21 deletions tubular/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,15 +377,13 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
X[self.new_column_name] = X.apply(self.calculate_age, axis=1)

# Drop original columns if self.drop_original is True
DropOriginalMixin.drop_original_column(
return DropOriginalMixin.drop_original_column(
self,
X,
self.drop_original,
self.columns,
)

return X


class DateDifferenceTransformer(BaseDateTwoColumnTransformer):
"""Class to transform calculate the difference between 2 date fields in specified units.
Expand Down Expand Up @@ -469,15 +467,13 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
) / np.timedelta64(1, self.units)

# Drop original columns if self.drop_original is True
DropOriginalMixin.drop_original_column(
return DropOriginalMixin.drop_original_column(
self,
X,
self.drop_original,
self.columns,
)

return X


class ToDatetimeTransformer(BaseGenericDateTransformer):
"""Class to transform convert specified columns to datetime.
Expand Down Expand Up @@ -563,15 +559,13 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
)

# Drop original columns if self.drop_original is True
DropOriginalMixin.drop_original_column(
return DropOriginalMixin.drop_original_column(
self,
X,
self.drop_original,
self.columns,
)

return X


class SeriesDtMethodTransformer(BaseDatetimeTransformer):
"""Tranformer that applies a pandas.Series.dt method.
Expand Down Expand Up @@ -733,15 +727,13 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
)

# Drop original columns if self.drop_original is True
DropOriginalMixin.drop_original_column(
return DropOriginalMixin.drop_original_column(
self,
X,
self.drop_original,
self.columns,
)

return X


class BetweenDatesTransformer(BaseGenericDateTransformer):
"""Transformer to generate a boolean column indicating if one date is between two others.
Expand Down Expand Up @@ -886,15 +878,13 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
X[self.new_column_name] = lower_comparison & upper_comparison

# Drop original columns if self.drop_original is True
DropOriginalMixin.drop_original_column(
return DropOriginalMixin.drop_original_column(
self,
X,
self.drop_original,
self.columns,
)

return X


class DatetimeInfoExtractor(BaseDatetimeTransformer):
"""Transformer to extract various features from datetime var.
Expand Down Expand Up @@ -1168,15 +1158,13 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
)

# Drop original columns if self.drop_original is True
DropOriginalMixin.drop_original_column(
return DropOriginalMixin.drop_original_column(
self,
X,
self.drop_original,
self.columns,
)

return X


class DatetimeSinusoidCalculator(BaseDatetimeTransformer):
"""Transformer to derive a feature in a dataframe by calculating the
Expand Down Expand Up @@ -1383,11 +1371,9 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
)

# Drop original columns if self.drop_original is True
DropOriginalMixin.drop_original_column(
return DropOriginalMixin.drop_original_column(
self,
X,
self.drop_original,
self.columns,
)

return X
15 changes: 10 additions & 5 deletions tubular/mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ class DropOriginalMixin:

"""

def classname(self) -> str:
"""Method that returns the name of the current class when called."""

return type(self).__name__

def set_drop_original_column(self, drop_original: bool) -> None:
"""Helper method for validating 'drop_original' argument.

Expand All @@ -66,17 +71,18 @@ def set_drop_original_column(self, drop_original: bool) -> None:

self.drop_original = drop_original

@nw.narwhalify
def drop_original_column(
self,
X: pd.DataFrame,
X: FrameT,
drop_original: bool,
columns: list[str] | str | None,
) -> pd.DataFrame:
"""Method for dropping input columns from X if drop_original set to True.

Parameters
----------
X : pd.DataFrame
X : pd/pl.DataFrame
Data with columns to drop.

drop_original : bool
Expand All @@ -87,14 +93,13 @@ def drop_original_column(

Returns
-------
X : pd.DataFrame
X : pd/pl.DataFrame
Transformed input X with columns dropped.

"""

if drop_original:
for col in columns:
del X[col]
X = X.drop(columns)

return X

Expand Down
2 changes: 1 addition & 1 deletion tubular/nominal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1299,7 +1299,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
)

# Drop original columns if self.drop_original is True
DropOriginalMixin.drop_original_column(
X = DropOriginalMixin.drop_original_column(
self,
X,
self.drop_original,
Expand Down
4 changes: 1 addition & 3 deletions tubular/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,9 +225,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
else:
X[new_column_names] = np.log(X[self.columns]) / np.log(self.base)

self.drop_original_column(X, self.drop_original, self.columns)

return X
return self.drop_original_column(X, self.drop_original, self.columns)


class CutTransformer(BaseNumericTransformer):
Expand Down