Skip to content

Commit eb2ff69

Browse files
authored
Fixes dimension order in xarray.Dataset.to_stacked_array (#10205)
* Fixes dimension order in xarray.Dataset.to_stacked_array * corrected dummy variable name to satisfy mypy * added type annotation to satisfy mypy * corrected type annotation to satisfy mypy
1 parent 05072ed commit eb2ff69

File tree

3 files changed

+37
-1
lines changed

3 files changed

+37
-1
lines changed

doc/whats-new.rst

+3
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ Deprecations
3636
Bug fixes
3737
~~~~~~~~~
3838

39+
- :py:meth:`~xarray.Dataset.to_stacked_array` now uses dimensions in order of appearance.
40+
This fixes the issue where using :py:meth:`~xarray.Dataset.transpose` before :py:meth:`~xarray.Dataset.to_stacked_array`
41+
had no effect. (Mentioned in :issue:`9921`)
3942

4043
Documentation
4144
~~~~~~~~~~~~~

xarray/core/dataset.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -5246,7 +5246,13 @@ def to_stacked_array(
52465246
"""
52475247
from xarray.structure.concat import concat
52485248

5249-
stacking_dims = tuple(dim for dim in self.dims if dim not in sample_dims)
5249+
# add stacking dims by order of appearance
5250+
stacking_dims_list: list[Hashable] = []
5251+
for da in self.data_vars.values():
5252+
for dim in da.dims:
5253+
if dim not in sample_dims and dim not in stacking_dims_list:
5254+
stacking_dims_list.append(dim)
5255+
stacking_dims = tuple(stacking_dims_list)
52505256

52515257
for key, da in self.data_vars.items():
52525258
missing_sample_dims = set(sample_dims) - set(da.dims)

xarray/tests/test_dataset.py

+27
Original file line numberDiff line numberDiff line change
@@ -4098,6 +4098,33 @@ def test_to_stacked_array_preserves_dtype(self) -> None:
40984098
expected_stacked_variable,
40994099
)
41004100

4101+
def test_to_stacked_array_transposed(self) -> None:
4102+
# test that to_stacked_array uses updated dim order after transposition
4103+
ds = xr.Dataset(
4104+
data_vars=dict(
4105+
v1=(["d1", "d2"], np.arange(6).reshape((2, 3))),
4106+
),
4107+
coords=dict(
4108+
d1=(["d1"], np.arange(2)),
4109+
d2=(["d2"], np.arange(3)),
4110+
),
4111+
)
4112+
da = ds.to_stacked_array(
4113+
new_dim="new_dim",
4114+
sample_dims=[],
4115+
variable_dim="variable",
4116+
)
4117+
dsT = ds.transpose()
4118+
daT = dsT.to_stacked_array(
4119+
new_dim="new_dim",
4120+
sample_dims=[],
4121+
variable_dim="variable",
4122+
)
4123+
v1 = np.arange(6)
4124+
v1T = np.arange(6).reshape((2, 3)).T.flatten()
4125+
np.testing.assert_equal(da.to_numpy(), v1)
4126+
np.testing.assert_equal(daT.to_numpy(), v1T)
4127+
41014128
def test_update(self) -> None:
41024129
data = create_test_data(seed=0)
41034130
expected = data.copy()

0 commit comments

Comments
 (0)