-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
/
Copy path test_mask.py
94 lines (70 loc) · 2.34 KB
/
test_mask.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import numpy as np
import pytest
import pandas.util._test_decorators as td
from pandas import Series
import pandas._testing as tm
def test_mask():
    """Series.mask(cond, ...) must agree with Series.where(~cond, ...).

    Exercises an explicit NaN fill, the default fill, a Series replacement
    value, a boolean Series condition covering only the first three labels,
    and the ValueError raised for conditions whose shape does not match.
    """
    # compare with tested results in test_where
    ser = Series(np.random.default_rng(2).standard_normal(5))
    is_pos = ser > 0

    # explicit NaN, default fill, and a Series of replacement values
    tm.assert_series_equal(ser.where(~is_pos, np.nan), ser.mask(is_pos))
    tm.assert_series_equal(ser.where(~is_pos), ser.mask(is_pos))
    tm.assert_series_equal(ser.where(~is_pos, -ser), ser.mask(is_pos, -ser))

    # boolean Series condition holding only the first three labels
    short_cond = Series([True, False, False, True, False], index=ser.index)[:3]
    neg = -(ser.abs())
    tm.assert_series_equal(neg.where(~short_cond), neg.mask(short_cond))
    tm.assert_series_equal(
        neg.where(~short_cond, -neg),
        neg.mask(short_cond, -neg),
    )

    # a scalar or a too-short array condition is rejected
    msg = "Array conditional must be same shape as self"
    with pytest.raises(ValueError, match=msg):
        ser.mask(1)
    with pytest.raises(ValueError, match=msg):
        ser.mask(short_cond.values, -ser)
def test_mask_casts():
    """Masking integer values with NaN must yield the float ``expected``."""
    # dtype changes
    expected = Series([1, 2, np.nan, np.nan])
    ser = Series([1, 2, 3, 4])
    tm.assert_series_equal(ser.mask(ser > 2, np.nan), expected)
def test_mask_casts2():
    """A plain Python list is accepted as the condition; masked slots become NaN."""
    # see gh-21891
    tm.assert_series_equal(
        Series([1, 2]).mask([True, False]),
        Series([np.nan, 2]),
    )
def test_mask_inplace():
    """mask(..., inplace=True) must leave the target equal to the copying form."""
    ser = Series(np.random.default_rng(2).standard_normal(5))
    cond = ser > 0

    # default fill: the surviving (non-NaN) entries are exactly ser[~cond]
    target = ser.copy()
    target.mask(cond, inplace=True)
    tm.assert_series_equal(target.dropna(), ser[~cond])
    tm.assert_series_equal(target, ser.mask(cond))

    # explicit Series of replacement values
    target = ser.copy()
    target.mask(cond, -ser, inplace=True)
    tm.assert_series_equal(target, ser.mask(cond, -ser))
@pytest.mark.parametrize(
"dtype",
[
"Int64",
pytest.param("int64[pyarrow]", marks=td.skip_if_no("pyarrow")),
],
)
def test_mask_na(dtype):
# We should not be filling pd.NA. See GH#60729
series = Series([None, 1, 2, None, 3, 4, None], dtype=dtype)
cond = series <= 2
expected = Series([None, -99, -99, None, 3, 4, None], dtype=dtype)
result = series.mask(cond, -99)
tm.assert_series_equal(result, expected)
result = series.mask(cond.to_list(), -99)
tm.assert_series_equal(result, expected)
result = series.mask(cond.to_numpy(), -99)
tm.assert_series_equal(result, expected)