Skip to content

Commit e980483

Browse files
authored
fix: preserve dtypes when using with_columns and length-1 pandas df (#1201)

* fix: preserve dtypes when using with_columns and length-1 pandas df
* pyarrow versions
1 parent 879d3cf commit e980483

File tree

3 files changed: 21 additions and 9 deletions.

3 files changed

+21
-9
lines changed

narwhals/_pandas_like/series.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -619,9 +619,7 @@ def quantile(
619619

620620
def zip_with(self: Self, mask: Any, other: Any) -> PandasLikeSeries:
621621
ser = self._native_series
622-
mask = validate_column_comparand(
623-
ser.index, mask, treat_length_one_as_scalar=False
624-
)
622+
mask = validate_column_comparand(ser.index, mask)
625623
other = validate_column_comparand(ser.index, other)
626624
res = ser.where(mask, other)
627625
return self._from_native_series(res)

narwhals/_pandas_like/utils.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -32,9 +32,7 @@
3232
}
3333

3434

35-
def validate_column_comparand(
36-
index: Any, other: Any, *, treat_length_one_as_scalar: bool = True
37-
) -> Any:
35+
def validate_column_comparand(index: Any, other: Any) -> Any:
3836
"""Validate RHS of binary operation.
3937
4038
If the comparison isn't supported, return `NotImplemented` so that the
@@ -55,9 +53,10 @@ def validate_column_comparand(
5553
if isinstance(other, PandasLikeDataFrame):
5654
return NotImplemented
5755
if isinstance(other, PandasLikeSeries):
58-
if other.len() == 1 and treat_length_one_as_scalar:
56+
if other.len() == 1:
5957
# broadcast
60-
return other.item()
58+
s = other._native_series
59+
return s.__class__(s.iloc[0], index=index, dtype=s.dtype)
6160
if other._native_series.index is not index:
6261
return set_axis(
6362
other._native_series,
@@ -83,7 +82,8 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any:
8382
if isinstance(other, PandasLikeSeries):
8483
if other.len() == 1:
8584
# broadcast
86-
return other._native_series.iloc[0]
85+
s = other._native_series
86+
return s.__class__(s.iloc[0], index=index, dtype=s.dtype)
8787
if other._native_series.index is not index:
8888
return set_axis(
8989
other._native_series,

tests/frame/with_columns_test.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,10 @@
11
import numpy as np
22
import pandas as pd
3+
import pyarrow as pa
4+
import pytest
35

46
import narwhals.stable.v1 as nw
7+
from narwhals.utils import parse_version
58
from tests.utils import Constructor
69
from tests.utils import compare_dicts
710

@@ -40,3 +43,14 @@ def test_with_columns_order_single_row(constructor: Constructor) -> None:
4043
assert result.collect_schema().names() == ["a", "b", "z", "d"]
4144
expected = {"a": [2], "b": [4], "z": [7.0], "d": [0]}
4245
compare_dicts(result, expected)
46+
47+
48+
def test_with_columns_dtypes_single_row(
49+
constructor: Constructor, request: pytest.FixtureRequest
50+
) -> None:
51+
if "pyarrow_table" in str(constructor) and parse_version(pa.__version__) < (15,):
52+
request.applymarker(pytest.mark.xfail)
53+
data = {"a": ["foo"]}
54+
df = nw.from_native(constructor(data)).with_columns(nw.col("a").cast(nw.Categorical))
55+
result = df.with_columns(nw.col("a"))
56+
assert result.collect_schema() == {"a": nw.Categorical}

0 commit comments

Comments
 (0)