File tree Expand file tree Collapse file tree 3 files changed +21
-9
lines changed Expand file tree Collapse file tree 3 files changed +21
-9
lines changed Original file line number Diff line number Diff line change @@ -619,9 +619,7 @@ def quantile(
619
619
620
620
def zip_with (self : Self , mask : Any , other : Any ) -> PandasLikeSeries :
621
621
ser = self ._native_series
622
- mask = validate_column_comparand (
623
- ser .index , mask , treat_length_one_as_scalar = False
624
- )
622
+ mask = validate_column_comparand (ser .index , mask )
625
623
other = validate_column_comparand (ser .index , other )
626
624
res = ser .where (mask , other )
627
625
return self ._from_native_series (res )
Original file line number Diff line number Diff line change 32
32
}
33
33
34
34
35
- def validate_column_comparand (
36
- index : Any , other : Any , * , treat_length_one_as_scalar : bool = True
37
- ) -> Any :
35
+ def validate_column_comparand (index : Any , other : Any ) -> Any :
38
36
"""Validate RHS of binary operation.
39
37
40
38
If the comparison isn't supported, return `NotImplemented` so that the
@@ -55,9 +53,10 @@ def validate_column_comparand(
55
53
if isinstance (other , PandasLikeDataFrame ):
56
54
return NotImplemented
57
55
if isinstance (other , PandasLikeSeries ):
58
- if other .len () == 1 and treat_length_one_as_scalar :
56
+ if other .len () == 1 :
59
57
# broadcast
60
- return other .item ()
58
+ s = other ._native_series
59
+ return s .__class__ (s .iloc [0 ], index = index , dtype = s .dtype )
61
60
if other ._native_series .index is not index :
62
61
return set_axis (
63
62
other ._native_series ,
@@ -83,7 +82,8 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any:
83
82
if isinstance (other , PandasLikeSeries ):
84
83
if other .len () == 1 :
85
84
# broadcast
86
- return other ._native_series .iloc [0 ]
85
+ s = other ._native_series
86
+ return s .__class__ (s .iloc [0 ], index = index , dtype = s .dtype )
87
87
if other ._native_series .index is not index :
88
88
return set_axis (
89
89
other ._native_series ,
Original file line number Diff line number Diff line change 1
1
import numpy as np
2
2
import pandas as pd
3
+ import pyarrow as pa
4
+ import pytest
3
5
4
6
import narwhals .stable .v1 as nw
7
+ from narwhals .utils import parse_version
5
8
from tests .utils import Constructor
6
9
from tests .utils import compare_dicts
7
10
@@ -40,3 +43,14 @@ def test_with_columns_order_single_row(constructor: Constructor) -> None:
40
43
assert result .collect_schema ().names () == ["a" , "b" , "z" , "d" ]
41
44
expected = {"a" : [2 ], "b" : [4 ], "z" : [7.0 ], "d" : [0 ]}
42
45
compare_dicts (result , expected )
46
+
47
+
48
+ def test_with_columns_dtypes_single_row (
49
+ constructor : Constructor , request : pytest .FixtureRequest
50
+ ) -> None :
51
+ if "pyarrow_table" in str (constructor ) and parse_version (pa .__version__ ) < (15 ,):
52
+ request .applymarker (pytest .mark .xfail )
53
+ data = {"a" : ["foo" ]}
54
+ df = nw .from_native (constructor (data )).with_columns (nw .col ("a" ).cast (nw .Categorical ))
55
+ result = df .with_columns (nw .col ("a" ))
56
+ assert result .collect_schema () == {"a" : nw .Categorical }
You can’t perform that action at this time.
0 commit comments