Merge branch 'main' into decouple_inferences
SF-N authored Feb 14, 2025
2 parents 9050385 + 8595cfd commit 2a9a25a
Showing 15 changed files with 2,525 additions and 1,934 deletions.
6 changes: 6 additions & 0 deletions src/gt4py/next/common.py
@@ -574,6 +574,12 @@ def replace(self, index: int | Dimension, *named_ranges: NamedRange) -> Domain:

return Domain(dims=dims, ranges=ranges)

def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
# remove cached property
state.pop("slice_at", None)
return state


FiniteDomain: TypeAlias = Domain[FiniteUnitRange]

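The added `__getstate__` drops the cached `slice_at` entry that a `functools.cached_property` stores in the instance `__dict__`, so pickling a `Domain` no longer serializes the cached helper; it is simply recomputed after unpickling. A minimal, self-contained sketch of the same pattern (the `Box` class below is a stand-in invented for this illustration, not the real `Domain`):

import functools
import pickle


class Box:
    """Stand-in with one cached property; not the real Domain."""

    def __init__(self, value: int) -> None:
        self.value = value

    @functools.cached_property
    def expensive(self) -> int:
        # functools.cached_property stores the result in self.__dict__["expensive"].
        return self.value * 2

    def __getstate__(self) -> dict:
        state = self.__dict__.copy()
        # Drop the cached value so it is never serialized; it is recomputed lazily.
        state.pop("expensive", None)
        return state


box = Box(21)
_ = box.expensive                             # populate the cache
restored = pickle.loads(pickle.dumps(box))
assert "expensive" not in restored.__dict__   # cache was not pickled
assert restored.expensive == 42               # recomputed on first access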
6 changes: 4 additions & 2 deletions src/gt4py/next/iterator/transforms/cse.py
@@ -82,11 +82,13 @@ def _is_collectable_expr(node: itir.Node) -> bool:
if isinstance(node, itir.FunCall):
# do not collect (and thus deduplicate in CSE) shift(offsets…) calls. Node must still be
# visited, to ensure symbol dependencies are recognized correctly.
# do also not collect reduce nodes if they are left in the it at this point, this may lead to
# do also not collect reduce nodes if they are left in the IR at this point, this may lead to
# conceptual problems (other parts of the tool chain rely on the arguments being present directly
# on the reduce FunCall node (connectivity deduction)), as well as problems with the imperative backend
# backend (single pass eager depth first visit approach)
if isinstance(node.fun, itir.SymRef) and node.fun.id in ["lift", "shift", "reduce", "map_"]:
# do also not collect lifts or applied lifts as they become invisible to the lift inliner
# otherwise
if cpm.is_call_to(node, ("lift", "shift", "reduce", "map_")) or cpm.is_applied_lift(node):
return False
return True
elif isinstance(node, itir.Lambda):
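For readers unfamiliar with the terminology in the comment above: in `shift(off)(it)` the outer call's `fun` is itself a call to a builtin, which makes it an *applied* shift or lift rather than a direct call. A minimal, self-contained sketch of the two checks (plain dataclasses stand in for the real `itir` nodes and `cpm` helpers, which are not imported here):

from dataclasses import dataclass, field


@dataclass
class SymRef:
    id: str


@dataclass
class FunCall:
    fun: object                      # a SymRef or another FunCall
    args: list = field(default_factory=list)


def is_call_to(node: object, names: tuple) -> bool:
    # The call's `fun` refers directly to one of the named builtins, e.g. shift(off).
    return isinstance(node, FunCall) and isinstance(node.fun, SymRef) and node.fun.id in names


def is_applied_lift(node: object) -> bool:
    # The outer call's `fun` is itself a call to `lift`, e.g. lift(f)(it).
    return isinstance(node, FunCall) and is_call_to(node.fun, ("lift",))


lift_call = FunCall(SymRef("lift"), [SymRef("f")])   # lift(f)
applied_lift = FunCall(lift_call, [SymRef("it")])    # lift(f)(it)

assert is_call_to(lift_call, ("lift", "shift", "reduce", "map_"))
assert not is_call_to(applied_lift, ("lift",))       # its fun is a FunCall, not a SymRef
assert is_applied_lift(applied_lift)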
@@ -21,8 +21,7 @@
)
from .local_double_buffering import gt_create_local_double_buffering
from .loop_blocking import LoopBlocking
from .map_fusion_parallel import MapFusionParallel
from .map_fusion_serial import MapFusionSerial
from .map_fusion import MapFusion, MapFusionParallel, MapFusionSerial
from .map_orderer import MapIterationOrder, gt_set_iteration_order
from .map_promoter import SerialMapPromoter
from .simplify import (
@@ -52,6 +51,7 @@
"GT4PyMapBufferElimination",
"GT4PyMoveTaskletIntoMap",
"LoopBlocking",
"MapFusion",
"MapFusionParallel",
"MapFusionSerial",
"MapIterationOrder",
@@ -13,6 +13,7 @@
import dace
from dace.transformation import dataflow as dace_dataflow
from dace.transformation.auto import auto_optimize as dace_aoptimize
from dace.transformation.passes import analysis as dace_analysis

from gt4py.next import common as gtx_common
from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations
@@ -328,22 +329,28 @@ def gt_auto_fuse_top_level_maps(
# after the other, thus new opportunities might arise in the next round.
# We use the hash of the SDFG to detect if we have reached a fix point.
for _ in range(max_optimization_rounds):
# Use map fusion to reduce their number and to create big kernels
# TODO(phimuell): Use a cost measurement to decide if fusion should be done.
# TODO(phimuell): Add parallel fusion transformation. Should it run after
# or with the serial one?
# TODO(phimuell): Switch to `FullMapFusion` once DaCe has parallel map fusion
# and [issue#1911](https://github.com/spcl/dace/issues/1911) has been solved.

# First we scan the entire SDFG to figure out which data is only
# used once and can be deleted. MapFusion could do this on its own but
# it is more efficient to do it once and then reuse it.
find_single_use_data = dace_analysis.FindSingleUseData()
single_use_data = find_single_use_data.apply_pass(sdfg, None)

fusion_transformation = gtx_transformations.MapFusion(
only_toplevel_maps=True,
allow_parallel_map_fusion=True,
allow_serial_map_fusion=True,
only_if_common_ancestor=False,
)
fusion_transformation._single_use_data = single_use_data

sdfg.apply_transformations_repeated(
[
gtx_transformations.MapFusionSerial(
only_toplevel_maps=True,
),
gtx_transformations.MapFusionParallel(
only_toplevel_maps=True,
# This will lead to the creation of big probably unrelated maps.
# However, it might be good.
only_if_common_ancestor=False,
),
],
fusion_transformation,
validate=validate,
validate_all=validate_all,
)
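Taken together, the hunk amounts to one round of a fix-point loop: recompute the single-use-data analysis once, hand it to a single combined `MapFusion` instance, and repeat until the SDFG stops changing. A condensed sketch of that loop, under the assumption that DaCe's `SDFG.hash_sdfg()` is used for the fix-point check mentioned in the surrounding comments (the hash handling itself is not part of this hunk):

previous_hash = sdfg.hash_sdfg()
for _ in range(max_optimization_rounds):
    # Run the analysis once per round and reuse it; this is cheaper than letting
    # MapFusion recompute the single-use data for every match it inspects.
    single_use_data = dace_analysis.FindSingleUseData().apply_pass(sdfg, None)

    fusion = gtx_transformations.MapFusion(
        only_toplevel_maps=True,
        allow_parallel_map_fusion=True,
        allow_serial_map_fusion=True,
        only_if_common_ancestor=False,
    )
    fusion._single_use_data = single_use_data

    sdfg.apply_transformations_repeated(fusion, validate=validate, validate_all=validate_all)

    # Fix point: nothing changed the SDFG in this round, so stop early.
    current_hash = sdfg.hash_sdfg()
    if current_hash == previous_hash:
        break
    previous_hash = current_hash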
@@ -171,7 +171,7 @@ def gt_gpu_transform_non_standard_memlet(
# This function allows to restrict any fusion operation to the maps
# that we have just created.
def restrict_fusion_to_newly_created_maps(
self: gtx_transformations.map_fusion_helper.MapFusionHelper,
self: gtx_transformations.MapFusion,
map_entry_1: dace_nodes.MapEntry,
map_entry_2: dace_nodes.MapEntry,
graph: Union[dace.SDFGState, dace.SDFG],
@@ -690,9 +690,9 @@ def can_be_applied(
self._promote_map(graph, replace_trivail_map_parameter=False)
if not gtx_transformations.MapFusionSerial.can_be_applied_to(
sdfg=sdfg,
map_exit_1=trivial_map_exit,
intermediate_access_node=self.access_node,
map_entry_2=self.second_map_entry,
first_map_exit=trivial_map_exit,
array=self.access_node,
second_map_entry=self.second_map_entry,
):
return False
finally:
@@ -204,9 +204,9 @@ def _prepare_inner_outer_maps(
inner_label = f"inner_{outer_map.label}"
inner_range = {
self.blocking_parameter: dace_subsets.Range.from_string(
f"({coarse_block_var} * {self.blocking_size} + {rng_start})"
f"(({rng_start}) + ({coarse_block_var}) * ({self.blocking_size}))"
+ ":"
+ f"min(({rng_start} + {coarse_block_var} + 1) * {self.blocking_size}, {rng_stop} + 1)"
+ f"min(({rng_start}) + ({coarse_block_var} + 1) * ({self.blocking_size}), ({rng_stop}) + 1)"
)
}
inner_entry, inner_exit = state.add_map(
@@ -219,7 +219,7 @@

# Now we modify the properties of the outer map.
coarse_block_range = dace_subsets.Range.from_string(
f"0:int_ceil(({rng_stop} + 1) - {rng_start}, {self.blocking_size})"
f"0:int_ceil((({rng_stop}) + 1) - ({rng_start}), ({self.blocking_size}))"
).ranges[0]
outer_map.params[blocking_parameter_dim] = coarse_block_var
outer_map.range[blocking_parameter_dim] = coarse_block_range
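The extra parentheses make the range strings robust when `rng_start`, `rng_stop`, or `blocking_size` are compound symbolic expressions rather than single symbols. Spelled out with plain integers (a hypothetical map over `0:9` inclusive with `blocking_size = 4`), the two formulas yield the expected decomposition into blocks, with the last block truncated:

import math

rng_start, rng_stop, blocking_size = 0, 9, 4      # 10 iterations, blocks of 4

# Outer (coarse) range: one iteration per block; mirrors the int_ceil expression.
num_blocks = math.ceil(((rng_stop + 1) - rng_start) / blocking_size)   # -> 3

inner_ranges = []
for coarse_block_var in range(num_blocks):
    begin = rng_start + coarse_block_var * blocking_size
    end = min(rng_start + (coarse_block_var + 1) * blocking_size, rng_stop + 1)
    inner_ranges.append((begin, end))

assert inner_ranges == [(0, 4), (4, 8), (8, 10)]   # last block holds only 2 iterations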
@@ -0,0 +1,189 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause
"""An interface between DaCe's MapFusion and the one of GT4Py."""

# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.

from typing import Any, Callable, Optional, TypeAlias, TypeVar, Union

import dace
from dace import nodes as dace_nodes, properties as dace_properties

from gt4py.next.program_processors.runners.dace.transformations import (
map_fusion_dace as dace_map_fusion,
)


_MapFusionType = TypeVar("_MapFusionType", bound="dace_map_fusion.MapFusion")

FusionTestCallback: TypeAlias = Callable[
[_MapFusionType, dace_nodes.MapEntry, dace_nodes.MapEntry, dace.SDFGState, dace.SDFG, int], bool
]
"""Callback for the map fusion transformation to check if a fusion should be performed.
The callback returns `True` if the fusion should be performed and `False` if it
should be rejected. See also the description of GT4Py's MapFusion transformation for
more information.
The arguments are as follows:
- The transformation object that is active.
- The MapEntry node of the first map; exact meaning depends on if parallel or
serial map fusion is performed.
- The MapEntry node of the second map; exact meaning depends on if parallel or
serial map fusion is performed.
- The SDFGState that contains the data flow.
- The SDFG that is processed.
- The expression index, see `expr_index` in `can_be_applied()`; it is `0` for serial map fusion and `1` for parallel map fusion.
"""


@dace_properties.make_properties
class MapFusion(dace_map_fusion.MapFusion):
"""GT4Py's MapFusion transformation.
It is a wrapper that adds some functionality to the transformation that is not
present in the DaCe version of this transformation.
There are three important differences when compared with DaCe's MapFusion:
- In DaCe strict data flow is enabled by default, in GT4Py it is disabled by default.
- In DaCe `MapFusion` only performs the fusion of serial maps by default. In GT4Py
`MapFusion` will also perform parallel map fusion by default.
- GT4Py accepts an additional argument `apply_fusion_callback`. This is a
function that is called by the transformation, at the _beginning_ of
`self.can_be_applied()`, i.e. before the transformation does any check if
the maps can be fused. If this function returns `False`, `self.can_be_applied()`
ends and returns `False`. In case the callback returns `True` the transformation
will perform the usual steps to check if the transformation can apply or not.
For the signature see `FusionTestCallback`.
Args:
only_inner_maps: Only match Maps that are internal, i.e. inside another Map.
only_toplevel_maps: Only consider Maps that are at the top.
strict_dataflow: Strict dataflow mode should be used, it is disabled by default.
assume_always_shared: Assume that all intermediates are shared.
allow_serial_map_fusion: Allow serial map fusion, by default `True`.
allow_parallel_map_fusion: Allow merging parallel maps, by default `True`.
only_if_common_ancestor: In parallel map fusion mode, only fuse if both maps
have a common direct ancestor.
apply_fusion_callback: The callback function that is used.
Todo:
Investigate ways of how to remove this intermediate layer. The main reason
why we need it is the callback functionality, but it is not needed often
and in these cases it might be solved differently.
"""

_apply_fusion_callback: Optional[FusionTestCallback]

def __init__(
self,
strict_dataflow: bool = False,
allow_serial_map_fusion: bool = True,
allow_parallel_map_fusion: bool = True,
apply_fusion_callback: Optional[FusionTestCallback] = None,
**kwargs: Any,
) -> None:
self._apply_fusion_callback = None
super().__init__(
strict_dataflow=strict_dataflow,
allow_serial_map_fusion=allow_serial_map_fusion,
allow_parallel_map_fusion=allow_parallel_map_fusion,
**kwargs,
)
if apply_fusion_callback is not None:
self._apply_fusion_callback = apply_fusion_callback

def can_be_applied(
self,
graph: Union[dace.SDFGState, dace.SDFG],
expr_index: int,
sdfg: dace.SDFG,
permissive: bool = False,
) -> bool:
"""Performs basic checks if the maps can be fused.
Args:
graph: The SDFGState in which the maps are located.
expr_index: The index of the matched expression; `0` for serial and `1` for parallel map fusion.
sdfg: The SDFG itself.
permissive: Currently unused.
"""
assert expr_index in [0, 1]

# If the call back is given then proceed with it.
if self._apply_fusion_callback is not None:
if expr_index == 0: # Serial MapFusion.
first_map_entry: dace_nodes.MapEntry = graph.entry_node(self.first_map_exit)
second_map_entry: dace_nodes.MapEntry = self.second_map_entry
elif expr_index == 1: # Parallel MapFusion
first_map_entry = self.first_parallel_map_entry
second_map_entry = self.second_parallel_map_entry
else:
raise NotImplementedError(f"Not implemented expression: {expr_index}")

# Apply the call back.
if not self._apply_fusion_callback(
self,
first_map_entry,
second_map_entry,
graph,
sdfg,
expr_index,
):
return False

# Now forward to the underlying implementation.
return super().can_be_applied(
graph=graph,
expr_index=expr_index,
sdfg=sdfg,
permissive=permissive,
)


@dace_properties.make_properties
class MapFusionSerial(MapFusion):
"""Wrapper around `MapFusion` that only supports serial map fusion.
Note:
This class exists only for the transition period.
"""

def __init__(
self,
**kwargs: Any,
) -> None:
assert "allow_serial_map_fusion" not in kwargs
assert "allow_parallel_map_fusion" not in kwargs
super().__init__(
allow_serial_map_fusion=True,
allow_parallel_map_fusion=False,
**kwargs,
)


@dace_properties.make_properties
class MapFusionParallel(MapFusion):
"""Wrapper around `MapFusion` that only supports parallel map fusion.
Note:
This class exists only for the transition period.
"""

def __init__(
self,
**kwargs: Any,
) -> None:
assert "allow_serial_map_fusion" not in kwargs
assert "allow_parallel_map_fusion" not in kwargs
super().__init__(
allow_serial_map_fusion=False,
allow_parallel_map_fusion=True,
**kwargs,
)
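A brief sketch of how the wrapper might be driven, assuming an existing `sdfg` and the hypothetical callback from the earlier sketch; `only_toplevel_maps`, `only_if_common_ancestor`, and `apply_fusion_callback` are the documented arguments, and `apply_transformations_repeated` is DaCe's standard driver:

fusion = MapFusion(
    only_toplevel_maps=True,
    only_if_common_ancestor=False,
    apply_fusion_callback=fuse_only_marked_maps,   # hypothetical callback from the sketch above
)
applied = sdfg.apply_transformations_repeated(fusion, validate=True)
print(f"MapFusion applied {applied} times")

# The transition wrappers behave like the combined transformation with one
# fusion direction switched off.
serial_only = MapFusionSerial(only_toplevel_maps=True)
parallel_only = MapFusionParallel(only_toplevel_maps=True, only_if_common_ancestor=True)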