From ed4ddd46c71a0523749d61fdd0946bb7c076c568 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 7 Oct 2024 11:17:07 -0400 Subject: [PATCH 1/4] Expose `k_start` and `k_end` automatically for any FrozenStencil --- ndsl/dsl/stencil.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 6a70b4d8..bb9f5c70 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -737,6 +737,8 @@ def axis_offsets( "local_js": gtscript.J[0] + self.jsc - origin[1], "j_end": j_end, "local_je": gtscript.J[-1] + self.jec - origin[1] - domain[1] + 1, + "k_start": self.origin[2], + "k_end": self.origin[2] + domain[2] - 1, } def get_origin_domain( From 5b09a676394de64d972ded94e04cabde4c647320 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 7 Oct 2024 11:37:22 -0400 Subject: [PATCH 2/4] Fix k_start + utest --- ndsl/dsl/stencil.py | 14 +++++++------- tests/dsl/test_stencil_factory.py | 17 +++++++++++++++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index bb9f5c70..085e4822 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -349,14 +349,14 @@ def __init__( ): unblock_waiting_tiles(MPI.COMM_WORLD) - self._timing_collector.build_info[ - _stencil_object_name(self.stencil_object) - ] = build_info + self._timing_collector.build_info[_stencil_object_name(self.stencil_object)] = ( + build_info + ) field_info = self.stencil_object.field_info - self._field_origins: Dict[ - str, Tuple[int, ...] - ] = FrozenStencil._compute_field_origins(field_info, self.origin) + self._field_origins: Dict[str, Tuple[int, ...]] = ( + FrozenStencil._compute_field_origins(field_info, self.origin) + ) """mapping from field names to field origins""" self._stencil_run_kwargs: Dict[str, Any] = { @@ -737,7 +737,7 @@ def axis_offsets( "local_js": gtscript.J[0] + self.jsc - origin[1], "j_end": j_end, "local_je": gtscript.J[-1] + self.jec - origin[1] - domain[1] + 1, - "k_start": self.origin[2], + "k_start": origin[2], "k_end": self.origin[2] + domain[2] - 1, } diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index ac189ad8..aa7a23fd 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -107,6 +107,23 @@ def test_get_stencils_with_varied_bounds_and_regions(backend: str): np.testing.assert_array_equal(q_orig.data, q_ref.data) +def test_stencil_vertical_bounds(backend: str): + factory = get_stencil_factory(backend) + origins = [(3, 3, 0), (2, 2, 1)] + domains = [(1, 1, 3), (2, 2, 4)] + stencils = get_stencils_with_varied_bounds( + add_1_in_region_stencil, + origins, + domains, + stencil_factory=factory, + ) + + assert "k_start" in stencils[0].externals and stencils[0].externals["k_start"] == 0 + assert "k_end" in stencils[0].externals and stencils[0].externals["k_end"] == 2 + assert "k_start" in stencils[1].externals and stencils[1].externals["k_start"] == 1 + assert "k_end" in stencils[1].externals and stencils[1].externals["k_end"] == 3 + + @pytest.mark.parametrize("enabled", [True, False]) def test_stencil_factory_numpy_comparison_from_dims_halo(enabled: bool): backend = "numpy" From b0b2940502f36d0f5c8fc18c9becae6546043d7f Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 7 Oct 2024 11:43:02 -0400 Subject: [PATCH 3/4] lint --- ndsl/dsl/stencil.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 085e4822..25b3388a 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -349,14 +349,14 @@ def __init__( ): unblock_waiting_tiles(MPI.COMM_WORLD) - self._timing_collector.build_info[_stencil_object_name(self.stencil_object)] = ( - build_info - ) + self._timing_collector.build_info[ + _stencil_object_name(self.stencil_object) + ] = build_info field_info = self.stencil_object.field_info - self._field_origins: Dict[str, Tuple[int, ...]] = ( - FrozenStencil._compute_field_origins(field_info, self.origin) - ) + self._field_origins: Dict[ + str, Tuple[int, ...] + ] = FrozenStencil._compute_field_origins(field_info, self.origin) """mapping from field names to field origins""" self._stencil_run_kwargs: Dict[str, Any] = { From 0c7c90225d7b6731aa1838113653cc3253aa37eb Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Mon, 7 Oct 2024 12:03:03 -0400 Subject: [PATCH 4/4] Fix for 2d stencils --- ndsl/dsl/stencil.py | 5 +++-- tests/dsl/test_stencil_factory.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ndsl/dsl/stencil.py b/ndsl/dsl/stencil.py index 25b3388a..5e917e66 100644 --- a/ndsl/dsl/stencil.py +++ b/ndsl/dsl/stencil.py @@ -737,8 +737,9 @@ def axis_offsets( "local_js": gtscript.J[0] + self.jsc - origin[1], "j_end": j_end, "local_je": gtscript.J[-1] + self.jec - origin[1] - domain[1] + 1, - "k_start": origin[2], - "k_end": self.origin[2] + domain[2] - 1, + "k_start": origin[2] if len(origin) > 2 else 0, + "k_end": (origin[2] if len(origin) > 2 else 0) + + (domain[2] - 1 if len(domain) > 2 else 0), } def get_origin_domain( diff --git a/tests/dsl/test_stencil_factory.py b/tests/dsl/test_stencil_factory.py index aa7a23fd..2af1218d 100644 --- a/tests/dsl/test_stencil_factory.py +++ b/tests/dsl/test_stencil_factory.py @@ -121,7 +121,7 @@ def test_stencil_vertical_bounds(backend: str): assert "k_start" in stencils[0].externals and stencils[0].externals["k_start"] == 0 assert "k_end" in stencils[0].externals and stencils[0].externals["k_end"] == 2 assert "k_start" in stencils[1].externals and stencils[1].externals["k_start"] == 1 - assert "k_end" in stencils[1].externals and stencils[1].externals["k_end"] == 3 + assert "k_end" in stencils[1].externals and stencils[1].externals["k_end"] == 4 @pytest.mark.parametrize("enabled", [True, False])