Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored DateTimeTransformer to use narwhals #382

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#
cfgv==3.4.0
# via pre-commit
colorama==0.4.6
# via pytest
coverage[toml]==7.6.0
# via pytest-cov
distlib==0.3.8
Expand Down
57 changes: 35 additions & 22 deletions tests/dates/test_ToDatetimeTransformer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import datetime

import narwhals as nw
import numpy as np
import pandas as pd
import polars as pl
import pytest
import test_aide as ta
from pandas.testing import assert_frame_equal

from tests.base_tests import (
ColumnStrListInitTests,
Expand All @@ -30,15 +32,15 @@ def test_to_datetime_kwargs_type_error(self):
"""Test that an exception is raised if to_datetime_kwargs is not a dict."""
with pytest.raises(
TypeError,
match=r"""ToDatetimeTransformer: to_datetime_kwargs should be a dict but got type \<class 'int'\>""",
match=r"""ToDatetimeTransformer: to_datetime_kwargs should be a dict but got \<class 'int'\>""",
):
ToDatetimeTransformer(column="b", new_column_name="a", to_datetime_kwargs=1)

def test_to_datetime_kwargs_key_type_error(self):
"""Test that an exception is raised if to_datetime_kwargs has keys which are not str."""
with pytest.raises(
TypeError,
match=r"""ToDatetimeTransformer: unexpected type \(\<class 'int'\>\) for to_datetime_kwargs key in position 1, must be str""",
match=r"""ToDatetimeTransformer: unexpected type <class 'int'> for to_datetime_kwargs key, must be str""",
):
ToDatetimeTransformer(
new_column_name="a",
Expand All @@ -48,7 +50,7 @@ def test_to_datetime_kwargs_key_type_error(self):


class TestTransform(GenericTransformTests):
"""Tests for BaseDatetimeTransformer.transform."""
"""Tests for ToDatetimeTransformer.transform."""

@classmethod
def setup_class(cls):
Expand Down Expand Up @@ -77,29 +79,36 @@ def expected_df_1():
pd.NaT,
],
},
)
).astype({"a": "float64", "b": "float64"})

def create_to_datetime_test_df():
"""Create DataFrame to be used in the ToDatetimeTransformer tests."""
def create_to_datetime_test_df_pandas():
"""Create Pandas DataFrame to be used in the ToDatetimeTransformer tests."""
return pd.DataFrame(
{"a": [1950, 1960, 2000, 2001, np.nan, 2010], "b": [1, 2, 3, 4, 5, np.nan]},
)

def create_to_datetime_test_df_polars():
"""Create Polars DataFrame to be used in the ToDatetimeTransformer tests."""
return pl.DataFrame(
{"a": [1950, 1960, 2000, 2001, None, 2010], "b": [1, 2, 3, 4, 5, None]},
)

@pytest.mark.parametrize(
("df", "expected"),
ta.pandas.adjusted_dataframe_params(
create_to_datetime_test_df(),
expected_df_1(),
),
"df",
[create_to_datetime_test_df_pandas(), create_to_datetime_test_df_polars()],
)
def test_expected_output(self, df, expected):
"""Test input data is transformed as expected."""
def test_expected_output(self, df):
"""Test input data is transformed as expected for both Pandas and Polars."""

df = nw.from_native(df)

df = df.with_columns(df["a"].cast(nw.String).alias("a"))

to_dt_1 = ToDatetimeTransformer(
column="a",
new_column_name="a_Y",
to_datetime_kwargs={"format": "%Y", "utc": datetime.timezone.utc},
)

to_dt_2 = ToDatetimeTransformer(
column="b",
new_column_name="b_m",
Expand All @@ -109,21 +118,25 @@ def test_expected_output(self, df, expected):
df_transformed = to_dt_1.transform(df)
df_transformed = to_dt_2.transform(df_transformed)

print(df_transformed)
print(expected)
df_transformed_native = (
df_transformed.to_native()
if hasattr(df_transformed, "to_native")
else df_transformed
)
expected_native = TestTransform.expected_df_1()

ta.equality.assert_equal_dispatch(
expected=expected,
actual=df_transformed,
msg="ToDatetimeTransformer.transform output",
df_transformed_native = df_transformed_native.astype(
{"a": "float64", "b": "float64"},
)

assert_frame_equal(df_transformed_native, expected_native)


class TestOtherBaseBehaviour(OtherBaseBehaviourTests):
"""
Class to run tests for BaseTransformerBehaviour outside the three standard methods.

May need to overwite specific tests in this class if the tested transformer modifies this behaviour.
May need to overwrite specific tests in this class if the tested transformer modifies this behaviour.
"""

@classmethod
Expand Down
69 changes: 52 additions & 17 deletions tubular/dates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
import datetime
import warnings

import narwhals as nw
import numpy as np
import pandas as pd
import polars as pl

from tubular.base import BaseTransformer
from tubular.mixins import DropOriginalMixin, NewColumnNameMixin, TwoColumnMixin
Expand Down Expand Up @@ -478,7 +480,7 @@ def transform(self, X: pd.DataFrame) -> pd.DataFrame:
class ToDatetimeTransformer(BaseGenericDateTransformer):
"""Class to transform convert specified columns to datetime.

Class simply uses the pd.to_datetime method on the specified columns.
Uses the pd.to_datetime method for Pandas or pl.col().str.strptime for Polars.

Parameters
----------
Expand All @@ -505,7 +507,7 @@ class attribute, indicates whether transformer has been converted to polars/pand

"""

polars_compatible = False
polars_compatible = True

def __init__(
self,
Expand All @@ -517,15 +519,15 @@ def __init__(
) -> None:
if to_datetime_kwargs is None:
to_datetime_kwargs = {}
else:
if type(to_datetime_kwargs) is not dict:
msg = f"{self.classname()}: to_datetime_kwargs should be a dict but got type {type(to_datetime_kwargs)}"
raise TypeError(msg)

for i, k in enumerate(to_datetime_kwargs.keys()):
if type(k) is not str:
msg = f"{self.classname()}: unexpected type ({type(k)}) for to_datetime_kwargs key in position {i}, must be str"
raise TypeError(msg)
if not isinstance(to_datetime_kwargs, dict):
msg = f"{self.classname()}: to_datetime_kwargs should be a dict but got {type(to_datetime_kwargs)}"
raise TypeError(msg)

for k in to_datetime_kwargs:
if not isinstance(k, str):
msg = f"{self.classname()}: unexpected type {type(k)} for to_datetime_kwargs key, must be str"
raise TypeError(msg)

self.to_datetime_kwargs = to_datetime_kwargs

Expand All @@ -540,23 +542,56 @@ def __init__(
**kwargs,
)

def transform(self, X: pd.DataFrame) -> pd.DataFrame:
"""Convert specified column to datetime using pd.to_datetime.
@nw.narwhalify
def transform(self, X: nw.DataFrame) -> nw.DataFrame:
"""Convert specified column to datetime using Narwhals.

Parameters
----------
X : pd.DataFrame
Data with column to transform.

"""

# purposely avoid BaseDateTransformer method, as uniquely for this transformer columns
# are not yet date/datetime
X = BaseTransformer.transform(self, X)
X = nw.from_native(BaseTransformer.transform(self, X))

X[self.new_column_name] = pd.to_datetime(
X[self.columns[0]],
**self.to_datetime_kwargs,
)
native_X = X.to_native()
if isinstance(native_X, pd.DataFrame):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

narwhals is designed to avoid having to implement branching polars/pandas handling like this, the idea is to make use of narwhals methods, which will translate between pandas/polars behind the scenes. In this case, would try something like

time_unit=self.to_datetime_kwargs.get('time_unit', None)
time_zone=self.to_datetime_kwargs.get('time_zone', None)
df=df.with_columns(nw.col(columns[0]).cast(nw.Datetime(time_unit=time_unit, time_zone=time_zone)))

native_X[self.columns[0]] = (
native_X[self.columns[0]].astype(str).str.split(".").str[0]
)
native_X[self.new_column_name] = pd.to_datetime(
native_X[self.columns[0]],
**{k: v for k, v in self.to_datetime_kwargs.items() if k != "utc"},
)
if self.to_datetime_kwargs.get("utc", False):
native_X[self.new_column_name] = native_X[
self.new_column_name
].dt.tz_localize("UTC")
X = nw.from_native(native_X)

elif isinstance(native_X, pl.DataFrame):
X = X.with_columns(
nw.col(self.columns[0]).str.replace(".0", "").alias(self.columns[0]),
)

X = X.with_columns(
nw.col(self.columns[0])
.str.strptime(nw.Datetime, "%Y", strict=False)
.alias(self.new_column_name),
)

if self.to_datetime_kwargs.get("utc", False):
X = X.with_columns(
nw.col(self.new_column_name)
.dt.convert_time_zone("UTC")
.alias(self.new_column_name),
)
else:
error_message = f"Unsupported DataFrame type: {type(native_X)}"
raise TypeError(error_message)

# Drop original columns if self.drop_original is True
return DropOriginalMixin.drop_original_column(
Expand Down
Loading