@@ -175,7 +175,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
175175
176176 @pytest .mark .parametrize ("center" , (True , False ))
177177 @pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
178- def test_rolling_construct (self , center , window ) -> None :
178+ def test_rolling_construct (self , center : bool , window : int ) -> None :
179179 s = pd .Series (np .arange (10 ))
180180 da = DataArray .from_series (s )
181181
@@ -610,7 +610,7 @@ def test_rolling_pandas_compat(self, center, window, min_periods) -> None:
610610
611611 @pytest .mark .parametrize ("center" , (True , False ))
612612 @pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
613- def test_rolling_construct (self , center , window ) -> None :
613+ def test_rolling_construct (self , center : bool , window : int ) -> None :
614614 df = pd .DataFrame (
615615 {
616616 "x" : np .random .randn (20 ),
@@ -627,19 +627,58 @@ def test_rolling_construct(self, center, window) -> None:
627627 np .testing .assert_allclose (df_rolling ["x" ].values , ds_rolling_mean ["x" ].values )
628628 np .testing .assert_allclose (df_rolling .index , ds_rolling_mean ["index" ])
629629
630- # with stride
631- ds_rolling_mean = ds_rolling .construct ("window" , stride = 2 ).mean ("window" )
632- np .testing .assert_allclose (
633- df_rolling ["x" ][::2 ].values , ds_rolling_mean ["x" ].values
634- )
635- np .testing .assert_allclose (df_rolling .index [::2 ], ds_rolling_mean ["index" ])
636630 # with fill_value
637631 ds_rolling_mean = ds_rolling .construct ("window" , stride = 2 , fill_value = 0.0 ).mean (
638632 "window"
639633 )
640634 assert (ds_rolling_mean .isnull ().sum () == 0 ).to_array (dim = "vars" ).all ()
641635 assert (ds_rolling_mean ["x" ] == 0.0 ).sum () >= 0
642636
637+ @pytest .mark .parametrize ("center" , (True , False ))
638+ @pytest .mark .parametrize ("window" , (1 , 2 , 3 , 4 ))
639+ def test_rolling_construct_stride (self , center : bool , window : int ) -> None :
640+ df = pd .DataFrame (
641+ {
642+ "x" : np .random .randn (20 ),
643+ "y" : np .random .randn (20 ),
644+ "time" : np .linspace (0 , 1 , 20 ),
645+ }
646+ )
647+ ds = Dataset .from_dataframe (df )
648+ df_rolling_mean = df .rolling (window , center = center , min_periods = 1 ).mean ()
649+
650+ # With an index (dimension coordinate)
651+ ds_rolling = ds .rolling (index = window , center = center )
652+ ds_rolling_mean = ds_rolling .construct ("w" , stride = 2 ).mean ("w" )
653+ np .testing .assert_allclose (
654+ df_rolling_mean ["x" ][::2 ].values , ds_rolling_mean ["x" ].values
655+ )
656+ np .testing .assert_allclose (df_rolling_mean .index [::2 ], ds_rolling_mean ["index" ])
657+
658+ # Without index (https://github.com/pydata/xarray/issues/7021)
659+ ds2 = ds .drop_vars ("index" )
660+ ds2_rolling = ds2 .rolling (index = window , center = center )
661+ ds2_rolling_mean = ds2_rolling .construct ("w" , stride = 2 ).mean ("w" )
662+ np .testing .assert_allclose (
663+ df_rolling_mean ["x" ][::2 ].values , ds2_rolling_mean ["x" ].values
664+ )
665+
666+ # Mixed coordinates, indexes and 2D coordinates
667+ ds3 = xr .Dataset (
668+ {"x" : ("t" , range (20 )), "x2" : ("y" , range (5 ))},
669+ {
670+ "t" : range (20 ),
671+ "y" : ("y" , range (5 )),
672+ "t2" : ("t" , range (20 )),
673+ "y2" : ("y" , range (5 )),
674+ "yt" : (["t" , "y" ], np .ones ((20 , 5 ))),
675+ },
676+ )
677+ ds3_rolling = ds3 .rolling (t = window , center = center )
678+ ds3_rolling_mean = ds3_rolling .construct ("w" , stride = 2 ).mean ("w" )
679+ for coord in ds3 .coords :
680+ assert coord in ds3_rolling_mean .coords
681+
643682 @pytest .mark .slow
644683 @pytest .mark .parametrize ("ds" , (1 , 2 ), indirect = True )
645684 @pytest .mark .parametrize ("center" , (True , False ))
0 commit comments