Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

narwhalified BaseMappingTransformer #368

Merged
merged 2 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ Changed
- fixed issues with all null and nullable-bool column handling in dataframe_init_dispatch
- added NaN error handling to WeightColumnMixin
- narwhalified MeanImputer `#344 <https://github.com/lvgig/tubular/issues/344>`_
- narwhalified BaseMappingTransformer `#367 <https://github.com/lvgig/tubular/issues/367>`_
- placeholder
- placeholder
- placeholder
- placeholder
Expand Down
22 changes: 12 additions & 10 deletions tests/mapping/test_BaseMappingTransformer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re

import polars as pl
import pytest
import test_aide as ta

Expand Down Expand Up @@ -101,13 +102,15 @@ class BaseMappingTransformerTransformTests(GenericTransformTests):
Note this deliberately avoids starting with "Tests" so that the tests are not run on import.
"""

@pytest.mark.parametrize("library", ["pandas", "polars"])
def test_mappings_unchanged(
self,
minimal_attribute_dict,
uninitialized_transformers,
library,
):
"""Test that mappings is unchanged in transform."""
df = d.create_df_3()
df = d.create_df_3(library=library)

mapping = {
"b": {1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7},
Expand All @@ -116,18 +119,17 @@ def test_mappings_unchanged(
args = minimal_attribute_dict[self.transformer_name].copy()
args["mappings"] = mapping

x = uninitialized_transformers[self.transformer_name](**args)
transformer = uninitialized_transformers[self.transformer_name](**args)

x.transform(df)

ta.equality.assert_equal_dispatch(
expected=mapping,
actual=x.mappings,
msg=f"{self.transformer_name}.transform has changed self.mappings unexpectedly",
)
# if transformer is not yet polars compatible, skip this test
if not transformer.polars_compatible and isinstance(df, pl.DataFrame):
return

transformer.transform(df)

# Running the BaseMappingTransformerTestSuite
assert (
mapping == transformer.mappings
), f"{self.transformer_name}.transform has changed self.mappings unexpectedly, expected {mapping} but got {transformer.mappings}"


class TestInit(BaseMappingTransformerInitTests):
Expand Down
14 changes: 10 additions & 4 deletions tubular/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@

import warnings
from collections import OrderedDict
from typing import TYPE_CHECKING

import narwhals as nw
import numpy as np
import pandas as pd
from pandas.api.types import is_categorical_dtype

from tubular.base import BaseTransformer

if TYPE_CHECKING:
from narwhals.typing import FrameT


class BaseMappingTransformer(BaseTransformer):
"""Base Transformer Extension for mapping transformers.
Expand All @@ -37,7 +42,7 @@ class attribute, indicates whether transformer has been converted to polars/pand

"""

polars_compatible = False
polars_compatible = True

def __init__(self, mappings: dict[str, dict], **kwargs: dict[str, bool]) -> None:
if isinstance(mappings, dict):
Expand All @@ -60,18 +65,19 @@ def __init__(self, mappings: dict[str, dict], **kwargs: dict[str, bool]) -> None

super().__init__(columns=columns, **kwargs)

def transform(self, X: pd.DataFrame) -> pd.DataFrame:
@nw.narwhalify
def transform(self, X: FrameT) -> FrameT:
"""Base mapping transformer transform method. Checks that the mappings
dict has been fitted and calls the BaseTransformer transform method.

Parameters
----------
X : pd.DataFrame
X : pd/pl.DataFrame
Data to apply mappings to.

Returns
-------
X : pd.DataFrame
X : pd/pl.DataFrame
Input X, copied if specified by user.

"""
Expand Down