Skip to content

Commit eb7e112

Browse files
ghislainpkeewis
andauthored
Fix DataArray.to_dataframe when the array has MultiIndex (pydata#4442)
Co-authored-by: Keewis <[email protected]>
1 parent c4ad6f1 commit eb7e112

File tree

3 files changed

+74
-2
lines changed

3 files changed

+74
-2
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,9 @@ Bug fixes
144144
a float64 array (:issue:`4898`, :pull:`4911`). By `Blair Bonnett <https://github.com/bcbnz>`_.
145145
- Fix decoding of vlen strings using h5py versions greater than 3.0.0 with h5netcdf backend (:issue:`4570`, :pull:`4893`).
146146
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
147+
- Allow converting :py:class:`Dataset` or :py:class:`DataArray` objects with a ``MultiIndex``
148+
and at least one other dimension to a ``pandas`` object (:issue:`3008`, :pull:`4442`).
149+
By `ghislainp <https://github.com/ghislainp>`_.
147150

148151
Documentation
149152
~~~~~~~~~~~~~

xarray/core/coordinates.py

+44-2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
cast,
1414
)
1515

16+
import numpy as np
1617
import pandas as pd
1718

1819
from . import formatting, indexing
@@ -107,8 +108,49 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
107108
return self._data.get_index(dim) # type: ignore
108109
else:
109110
indexes = [self._data.get_index(k) for k in ordered_dims] # type: ignore
110-
names = list(ordered_dims)
111-
return pd.MultiIndex.from_product(indexes, names=names)
111+
112+
# compute the sizes of the repeat and tile for the cartesian product
113+
# (taken from pandas.core.reshape.util)
114+
index_lengths = np.fromiter(
115+
(len(index) for index in indexes), dtype=np.intp
116+
)
117+
cumprod_lengths = np.cumproduct(index_lengths)
118+
119+
if cumprod_lengths[-1] != 0:
120+
# sizes of the repeats
121+
repeat_counts = cumprod_lengths[-1] / cumprod_lengths
122+
else:
123+
# if any factor is empty, the cartesian product is empty
124+
repeat_counts = np.zeros_like(cumprod_lengths)
125+
126+
# sizes of the tiles
127+
tile_counts = np.roll(cumprod_lengths, 1)
128+
tile_counts[0] = 1
129+
130+
# loop over the indexes
131+
# for each MultiIndex or Index compute the cartesian product of the codes
132+
133+
code_list = []
134+
level_list = []
135+
names = []
136+
137+
for i, index in enumerate(indexes):
138+
if isinstance(index, pd.MultiIndex):
139+
codes, levels = index.codes, index.levels
140+
else:
141+
code, level = pd.factorize(index)
142+
codes = [code]
143+
levels = [level]
144+
145+
# compute the cartesian product
146+
code_list += [
147+
np.tile(np.repeat(code, repeat_counts[i]), tile_counts[i])
148+
for code in codes
149+
]
150+
level_list += levels
151+
names += index.names
152+
153+
return pd.MultiIndex(level_list, code_list, names=names)
112154

113155
def update(self, other: Mapping[Hashable, Any]) -> None:
114156
other_vars = getattr(other, "variables", other)

xarray/tests/test_dataarray.py

+27
Original file line numberDiff line numberDiff line change
@@ -3635,6 +3635,33 @@ def test_to_dataframe(self):
36353635
with raises_regex(ValueError, "unnamed"):
36363636
arr.to_dataframe()
36373637

3638+
def test_to_dataframe_multiindex(self):
3639+
# regression test for #3008
3640+
arr_np = np.random.randn(4, 3)
3641+
3642+
mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"])
3643+
3644+
arr = DataArray(arr_np, [("MI", mindex), ("C", [5, 6, 7])], name="foo")
3645+
3646+
actual = arr.to_dataframe()
3647+
assert_array_equal(actual["foo"].values, arr_np.flatten())
3648+
assert_array_equal(actual.index.names, list("ABC"))
3649+
assert_array_equal(actual.index.levels[0], [1, 2])
3650+
assert_array_equal(actual.index.levels[1], ["a", "b"])
3651+
assert_array_equal(actual.index.levels[2], [5, 6, 7])
3652+
3653+
def test_to_dataframe_0length(self):
3654+
# regression test for #3008
3655+
arr_np = np.random.randn(4, 0)
3656+
3657+
mindex = pd.MultiIndex.from_product([[1, 2], list("ab")], names=["A", "B"])
3658+
3659+
arr = DataArray(arr_np, [("MI", mindex), ("C", [])], name="foo")
3660+
3661+
actual = arr.to_dataframe()
3662+
assert len(actual) == 0
3663+
assert_array_equal(actual.index.names, list("ABC"))
3664+
36383665
def test_to_pandas_name_matches_coordinate(self):
36393666
# coordinate with same name as array
36403667
arr = DataArray([1, 2, 3], dims="x", name="x")

0 commit comments

Comments
 (0)