diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py
index fa77a7fd1d..826b5949f2 100644
--- a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py
+++ b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py
@@ -204,9 +204,9 @@ def _prepare_inner_outer_maps(
         inner_label = f"inner_{outer_map.label}"
         inner_range = {
             self.blocking_parameter: dace_subsets.Range.from_string(
-                f"({coarse_block_var} * {self.blocking_size} + {rng_start})"
+                f"(({rng_start}) + ({coarse_block_var}) * ({self.blocking_size}))"
                 + ":"
-                + f"min(({rng_start} + {coarse_block_var} + 1) * {self.blocking_size}, {rng_stop} + 1)"
+                + f"min(({rng_start}) + ({coarse_block_var} + 1) * ({self.blocking_size}), ({rng_stop}) + 1)"
             )
         }
         inner_entry, inner_exit = state.add_map(
@@ -219,7 +219,7 @@ def _prepare_inner_outer_maps(
 
         # Now we modify the properties of the outer map.
         coarse_block_range = dace_subsets.Range.from_string(
-            f"0:int_ceil(({rng_stop} + 1) - {rng_start}, {self.blocking_size})"
+            f"0:int_ceil((({rng_stop}) + 1) - ({rng_start}), ({self.blocking_size}))"
         ).ranges[0]
         outer_map.params[blocking_parameter_dim] = coarse_block_var
         outer_map.range[blocking_parameter_dim] = coarse_block_range
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py
index 86136994dc..3b41da6336 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py
@@ -927,3 +927,68 @@ def test_loop_blocking_no_independent_nodes():
         validate_all=True,
     )
     assert count == 1
+
+
+def _make_only_last_two_elements_sdfg() -> dace.SDFG:
+    sdfg = dace.SDFG(util.unique_name("simple_block_sdfg"))
+    state = sdfg.add_state("state", is_start_block=True)
+    sdfg.add_symbol("N", dace.int32)
+    sdfg.add_symbol("B", dace.int32)
+    sdfg.add_symbol("M", dace.int32)
+
+    for name in "acb":
+        sdfg.add_array(
+            name,
+            shape=(20, 10),
+            dtype=dace.float64,
+        )
+
+    state.add_mapped_tasklet(
+        "computation",
+        map_ranges={"i": "B:N", "k": "(M-2):M"},
+        inputs={
+            "__in1": dace.Memlet("a[i, k]"),
+            "__in2": dace.Memlet("b[i, k]"),
+        },
+        code="__out = __in1 + __in2",
+        outputs={"__out": dace.Memlet("c[i, k]")},
+        external_edges=True,
+    )
+    sdfg.validate()
+
+    return sdfg
+
+
+def test_only_last_two_elements_sdfg():
+    sdfg = _make_only_last_two_elements_sdfg()
+
+    def ref_comp(a, b, c, B, N, M):
+        for i in range(B, N):
+            for k in range(M - 2, M):
+                c[i, k] = a[i, k] + b[i, k]
+
+    count = sdfg.apply_transformations_repeated(
+        gtx_transformations.LoopBlocking(
+            blocking_size=1,
+            blocking_parameter="k",
+            require_independent_nodes=False,
+        ),
+        validate=True,
+        validate_all=True,
+    )
+    assert count == 1
+
+    ref = {
+        "a": np.array(np.random.rand(20, 10), dtype=np.float64),
+        "b": np.array(np.random.rand(20, 10), dtype=np.float64),
+        "c": np.zeros((20, 10), dtype=np.float64),
+        "B": 0,
+        "N": 20,
+        "M": 6,
+    }
+    res = copy.deepcopy(ref)
+
+    ref_comp(**ref)
+    sdfg(**res)
+
+    assert np.allclose(ref["c"], res["c"])