Skip to content

Commit cf52dec

Browse files
BUG (string dtype): fix where() for string dtype with python storage (#60195)
1 parent 169b00e commit cf52dec

File tree

2 files changed

+12
-12
lines changed

2 files changed

+12
-12
lines changed

pandas/core/arrays/string_.py

+6
Original file line numberDiff line numberDiff line change
@@ -757,6 +757,12 @@ def _putmask(self, mask: npt.NDArray[np.bool_], value) -> None:
757757
# base class implementation that uses __setitem__
758758
ExtensionArray._putmask(self, mask, value)
759759

760+
def _where(self, mask: npt.NDArray[np.bool_], value) -> Self:
761+
# the super() method NDArrayBackedExtensionArray._where uses
762+
# np.putmask which doesn't properly handle None/pd.NA, so using the
763+
# base class implementation that uses __setitem__
764+
return ExtensionArray._where(self, mask, value)
765+
760766
def isin(self, values: ArrayLike) -> npt.NDArray[np.bool_]:
761767
if isinstance(values, BaseStringArray) or (
762768
isinstance(values, ExtensionArray) and is_string_dtype(values.dtype)

pandas/tests/frame/indexing/test_where.py

+6-12
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66

77
from pandas._config import using_string_dtype
88

9-
from pandas.compat import HAS_PYARROW
10-
119
from pandas.core.dtypes.common import is_scalar
1210

1311
import pandas as pd
@@ -940,9 +938,6 @@ def test_where_nullable_invalid_na(frame_or_series, any_numeric_ea_dtype):
940938
obj.mask(mask, null)
941939

942940

943-
@pytest.mark.xfail(
944-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
945-
)
946941
@given(data=OPTIONAL_ONE_OF_ALL)
947942
def test_where_inplace_casting(data):
948943
# GH 22051
@@ -1023,19 +1018,18 @@ def test_where_producing_ea_cond_for_np_dtype():
10231018
tm.assert_frame_equal(result, expected)
10241019

10251020

1026-
@pytest.mark.xfail(
1027-
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)", strict=False
1028-
)
10291021
@pytest.mark.parametrize(
10301022
"replacement", [0.001, True, "snake", None, datetime(2022, 5, 4)]
10311023
)
1032-
def test_where_int_overflow(replacement, using_infer_string, request):
1024+
def test_where_int_overflow(replacement, using_infer_string):
10331025
# GH 31687
10341026
df = DataFrame([[1.0, 2e25, "nine"], [np.nan, 0.1, None]])
10351027
if using_infer_string and replacement not in (None, "snake"):
1036-
request.node.add_marker(
1037-
pytest.mark.xfail(reason="Can't set non-string into string column")
1038-
)
1028+
with pytest.raises(
1029+
TypeError, match="Cannot set non-string value|Scalar must be NA or str"
1030+
):
1031+
df.where(pd.notnull(df), replacement)
1032+
return
10391033
result = df.where(pd.notnull(df), replacement)
10401034
expected = DataFrame([[1.0, 2e25, "nine"], [replacement, 0.1, replacement]])
10411035

0 commit comments

Comments
 (0)