Skip to content

Commit c2b840f

Browse files
MarcoGorellitwiecki
authored andcommitted
let pandas_to_array take pandas Index
1 parent 79245ce commit c2b840f

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

pymc3/aesaraf.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,13 @@ def pandas_to_array(data):
8686
if hasattr(data, "to_numpy") and hasattr(data, "isnull"):
8787
# typically, but not limited to pandas objects
8888
vals = data.to_numpy()
89-
mask = data.isnull().to_numpy()
89+
null_data = data.isnull()
90+
if hasattr(null_data, "to_numpy"):
91+
# pandas Series
92+
mask = null_data.to_numpy()
93+
else:
94+
# pandas Index
95+
mask = null_data
9096
if mask.any():
9197
# there are missing values
9298
ret = np.ma.MaskedArray(vals, mask)

pymc3/tests/test_aesaraf.py

+7
Original file line numberDiff line numberDiff line change
@@ -430,6 +430,13 @@ def test_pandas_to_array(input_dtype):
430430
assert isinstance(wrapped, TensorVariable)
431431

432432

433+
def test_pandas_to_array_pandas_index():
434+
data = pd.Index([1, 2, 3])
435+
result = pandas_to_array(data)
436+
expected = np.array([1, 2, 3])
437+
np.testing.assert_array_equal(result, expected)
438+
439+
433440
def test_walk_model():
434441
d = at.vector("d")
435442
b = at.vector("b")

0 commit comments

Comments
 (0)