fix[next][dace]: fix map fusion and loop blocking (#1856)
Improve the optimization of the icon4py stencil `apply_diffusion_to_vn` by
means of two changes (both illustrated by the sketches further below):
- Ignore the check of the dynamic volume property on memlets, which previously
  prevented serial map fusion.
- Add support for NestedSDFG nodes in the loop blocking transformation.
edopao authored Feb 11, 2025
1 parent c96e19e commit 64b90dc
Showing 3 changed files with 175 additions and 99 deletions.
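To make the first change concrete, here is a minimal sketch of the memlet property involved. It uses only plain DaCe; the array name `A` and the symbol `N` are illustrative and do not come from the commit.

```python
import dace

# A memlet whose access volume is not statically known is marked "dynamic".
m = dace.Memlet("A[0:N]")
m.dynamic = True  # the number of performed accesses is only decided at runtime

# The commit relaxes where this flag blocks serial map fusion: per the note
# added in the diff below, dynamic *consumer* memlets can be tolerated because
# their data is only read conditionally, while dynamic *producer* memlets are
# still rejected, as the written subset would not be known.
```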
@@ -380,6 +380,24 @@ def _classify_node(
):
return False

+ # Test if the body of the Tasklet depends on the block variable.
+ if self.blocking_parameter in node_to_classify.free_symbols:
+ return False
+
+ elif isinstance(node_to_classify, dace.nodes.NestedSDFG):
+ # The same check as for Tasklets applies to the outputs of a nested SDFG node.
+ if not all(
+ isinstance(out_edge.dst, dace_nodes.AccessNode)
+ for out_edge in state.out_edges(node_to_classify)
+ if not out_edge.data.is_empty()
+ ):
+ return False
+
+ # Additionally, test if the symbol mapping depends on the block variable.
+ for v in node_to_classify.symbol_mapping.values():
+ if self.blocking_parameter in v.free_symbols:
+ return False
+
elif isinstance(node_to_classify, dace_nodes.AccessNode):
# AccessNodes need to have some special properties.
node_desc: dace.data.Data = node_to_classify.desc(sdfg)
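For illustration, a small standalone sketch of the new symbol-mapping check for NestedSDFG nodes. All names (`i_blk`, `outer_idx`, `inner_idx`) are hypothetical; the comparison is done by symbol name to stay independent of DaCe's symbol classes.

```python
import dace

# The blocking parameter introduced by the loop-blocking transformation
# (illustrative name).
BLOCK_PARAM = "i_blk"

# A nested SDFG maps its inner symbols to expressions of the outer scope.
# If any of these expressions mentions the blocking parameter, the node's
# behaviour changes per block iteration, so it must be classified as dependent.
symbol_mapping = {
    "inner_idx": dace.symbolic.pystr_to_symbolic("outer_idx + i_blk"),
}
is_dependent = any(
    BLOCK_PARAM in {str(s) for s in expr.free_symbols}
    for expr in symbol_mapping.values()
)
assert is_dependent
```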
@@ -422,16 +440,6 @@ def _classify_node(
if out_edge.dst is outer_exit:
return False

- # Now we have ensured that the partition exists, thus we will now evaluate
- # if the node is independent or dependent.
-
- # Test if the body of the Tasklet depends on the block variable.
- if (
- isinstance(node_to_classify, dace_nodes.Tasklet)
- and self.blocking_parameter in node_to_classify.free_symbols
- ):
- return False
-
# Now we have to look at incoming edges individually.
# We will inspect the subset of the Memlet to see if they depend on the
# block variable. If this loop ends normally, then we classify the node
@@ -125,8 +125,8 @@ def can_be_applied(
output_partition = self.partition_first_outputs(
state=graph,
sdfg=sdfg,
- map_exit_1=map_exit_1,
- map_entry_2=map_entry_2,
+ first_map_exit=map_exit_1,
+ second_map_entry=map_entry_2,
)
if output_partition is None:
return False
@@ -375,8 +375,8 @@ def partition_first_outputs(
self,
state: SDFGState,
sdfg: SDFG,
- map_exit_1: nodes.MapExit,
- map_entry_2: nodes.MapEntry,
+ first_map_exit: nodes.MapExit,
+ second_map_entry: nodes.MapEntry,
) -> Union[
Tuple[
Set[graph.MultiConnectorEdge[dace.Memlet]],
@@ -385,19 +385,19 @@
],
None,
]:
"""Partition the output edges of `map_exit_1` for serial map fusion.
"""Partition the output edges of `first_map_exit` for serial map fusion.
The output edges of the first map are partitioned into three distinct sets,
defined as follows:
- Pure Output Set `\mathbb{P}`:
* Pure Output Set `\mathbb{P}`:
These edges exits the first map and does not enter the second map. These
outputs will be simply be moved to the output of the second map.
- Exclusive Intermediate Set `\mathbb{E}`:
* Exclusive Intermediate Set `\mathbb{E}`:
Edges in this set leaves the first map exit, enters an access node, from
where a Memlet then leads immediately to the second map. The memory
referenced by this access node is not used anywhere else, thus it can
be removed.
- Shared Intermediate Set `\mathbb{S}`:
* Shared Intermediate Set `\mathbb{S}`:
These edges are very similar to the one in `\mathbb{E}` except that they
are used somewhere else, thus they can not be removed and have to be
recreated as output of the second map.
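As a toy example of these sets (plain DaCe, names illustrative): in the program below the transient `tmp` is produced by the first map and read only by the second, so it would land in the exclusive intermediate set; if `tmp` were additionally copied to an output, it would be shared instead.

```python
import dace

N = dace.symbol("N")

@dace.program
def two_maps(a: dace.float64[N], out: dace.float64[N]):
    tmp = a + 1.0       # first map: writes the intermediate `tmp`
    out[:] = tmp * 2.0  # second map: the only consumer of `tmp`

sdfg = two_maps.to_sdfg(simplify=True)
```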
@@ -406,17 +406,14 @@ def partition_first_outputs(
output can be added to either intermediate set and might fail to compute
the partition, even if it would exist.
- Returns:
- If such a decomposition exists the function will return the three sets
- mentioned above in the same order.
- In case the decomposition does not exist, i.e. the maps can not be fused
- the function returns `None`.
+ :return: If such a decomposition exists, the function returns the three sets
+ mentioned above, in that order. If the decomposition does not exist,
+ i.e. the maps cannot be fused, the function returns `None`.
- Args:
- state: The in which the two maps are located.
- sdfg: The full SDFG in whcih we operate.
- map_exit_1: The exit node of the first map.
- map_entry_2: The entry node of the second map.
+ :param state: The state in which the two maps are located.
+ :param sdfg: The full SDFG in which we operate.
+ :param first_map_exit: The exit node of the first map.
+ :param second_map_entry: The entry node of the second map.
"""
# The three output sets.
pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set()
@@ -425,28 +422,17 @@

# Compute the renaming for translating the parameters of the _second_
# map to the ones used by the first map.
- repl_dict: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment]
- first_map=map_exit_1.map,
- second_map=map_entry_2.map,
+ param_repl: Dict[str, str] = self.find_parameter_remapping( # type: ignore[assignment]
+ first_map=first_map_exit.map,
+ second_map=second_map_entry.map,
)
- assert repl_dict is not None
+ assert param_repl is not None

# Set of intermediate nodes that we have already processed.
processed_inter_nodes: Set[nodes.Node] = set()

- # These are the data that is written to multiple times in _this_ state.
- # If a data is written to multiple time in a state, it could be
- # classified as shared. However, it might happen that the node has zero
- # degree. This is not a problem as the maps also induced a before-after
- # relationship. But some DaCe transformations do not catch this.
- # Thus we will never modify such intermediate nodes and fail instead.
- if self.strict_dataflow:
- multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg)
- else:
- multi_write_data = set()

# Now scan all output edges of the first exit and classify them
- for out_edge in state.out_edges(map_exit_1):
+ for out_edge in state.out_edges(first_map_exit):
intermediate_node: nodes.Node = out_edge.dst

# We already processed the node, this should indicate that we should
@@ -469,7 +455,7 @@
if not self.is_node_reachable_from(
graph=state,
begin=intermediate_node,
- end=map_entry_2,
+ end=second_map_entry,
):
pure_outputs.add(out_edge)
continue
@@ -479,6 +465,12 @@
# cases, as handling them is essentially rerouting an edge, whereas
# handling intermediate nodes is much more complicated.

+ # Empty Memlets are only allowed if they are in `\mathbb{P}`, which
+ # is also the only place they really make sense (for a map exit).
+ # Thus, if we find an empty Memlet here, we reject it.
+ if out_edge.data.is_empty():
+ return None
+
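A quick check of what "empty" means for a Memlet (plain DaCe; the array name is illustrative):

```python
import dace

m_empty = dace.Memlet()   # carries no data, only expresses an ordering
assert m_empty.is_empty()

m_data = dace.Memlet("A[0]")
assert not m_data.is_empty()
```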
# For us an intermediate node must always be an access node, because
# everything else we do not know how to handle. It is important that
# we do not test for non transient data here, because they can be
@@ -488,30 +480,14 @@
if self.is_view(intermediate_node, sdfg):
return None

- # Checks if the intermediate node refers to data that is accessed by
- # _other_ access nodes in _this_ state. If this is the case then never
- # touch this intermediate node.
- # TODO(phimuell): Technically it would be enough to turn the node into
- # a shared output node, because this will still fulfil the dependencies.
- # However, some DaCe transformation can not handle this properly, so we
- # are _forced_ to reject this node.
- if intermediate_node.data in multi_write_data:
- return None
-
- # Empty Memlets are only allowed if they are in `\mathbb{P}`, which
- # is also the only place they really make sense (for a map exit).
- # Thus if we now found an empty Memlet we reject it.
- if out_edge.data.is_empty():
- return None
-
# It can happen that multiple edges converge at the `IN_` connector
# of the first map exit, but there is only one edge leaving the exit.
# It is complicated to handle this, so for now we ignore it.
# TODO(phimuell): Handle this case properly.
# To handle this we need to associate a consumer edge (the outgoing edges
# of the second map) with exactly one producer.
producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list(
- state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])
+ state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:])
)
if len(producer_edges) > 1:
return None
@@ -520,7 +496,7 @@
# - The source of the producer can not be a view (we do not handle this)
# - The edge shall also not be a reduction edge.
# - Defined location to where they write.
# - No dynamic Memlets.
# Furthermore, we will also extract the subsets, i.e. the location they
# modify inside the intermediate array.
# Since we do not allow for WCR, we do not check if the producer subsets intersects.
@@ -531,6 +507,7 @@
):
return None
if producer_edge.data.dynamic:
+ # TODO(phimuell): Find out if this restriction could be lifted, but it is unlikely.
return None
if producer_edge.data.wcr is not None:
return None
@@ -562,9 +539,9 @@
for intermediate_node_out_edge in state.out_edges(intermediate_node):
# If the second map entry is not immediately reachable from the intermediate
# node, then ensure that there is no path that goes to it.
- if intermediate_node_out_edge.dst is not map_entry_2:
+ if intermediate_node_out_edge.dst is not second_map_entry:
if self.is_node_reachable_from(
- graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2
+ graph=state, begin=intermediate_node_out_edge.dst, end=second_map_entry
):
return None
continue
@@ -583,27 +560,28 @@
# Now we look at all edges that leave the second map entry, i.e. the
# edges that feeds the consumer and define what is read inside the map.
# We do not check them, but collect them and inspect them.
- # NOTE: The subset still uses the old iteration variables.
+ # NOTE1: The subset still uses the old iteration variables.
+ # NOTE2: In case of a consumer Memlet we explicitly allow dynamic Memlets.
+ #   This is different from the producer Memlet. The reason is that in a
+ #   consumer the data is only read conditionally, so the data has to
+ #   exist anyway.
for inner_consumer_edge in state.out_edges_by_connector(
- map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]
+ second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:]
):
if inner_consumer_edge.data.src_subset is None:
return None
if inner_consumer_edge.data.dynamic:
+ # TODO(phimuell): Is this restriction necessary? I am not sure.
return None
consumer_subsets.append(inner_consumer_edge.data.src_subset)
assert (
found_second_map
), f"Found '{intermediate_node}' which looked like a pure node, but is not one."
assert len(consumer_subsets) != 0

# The consumer still uses the original symbols of the second map, so we must rename them.
- if repl_dict:
+ if param_repl:
consumer_subsets = copy.deepcopy(consumer_subsets)
for consumer_subset in consumer_subsets:
symbolic.safe_replace(
- mapping=repl_dict, replace_callback=consumer_subset.replace
+ mapping=param_repl, replace_callback=consumer_subset.replace
)

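The renaming above relies on DaCe's `safe_replace` helper driving the subset's `replace` method; a standalone sketch (the subset and mapping are invented for illustration):

```python
from dace import subsets, symbolic

# A consumer subset still written in terms of the second map's parameter `j`.
sub = subsets.Range.from_string("j, 0:N")

# Rename `j` to the first map's parameter `i`, as `param_repl` does above.
symbolic.safe_replace(mapping={"j": "i"}, replace_callback=sub.replace)
print(sub)  # now reads roughly "i, 0:N"
```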
# Now we are checking if a single iteration of the first (top) map
@@ -623,6 +601,21 @@
# output of the second map.
if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg):
# The intermediate data is used somewhere else, either in this or another state.
+ # NOTE: If the intermediate is shared, then we will turn it into a
+ # sink node attached to the combined map exit. Technically this
+ # should be enough, even if the same data appears again in the
+ # dataflow downstream. However, some DaCe transformations,
+ # I am looking at you `auto_optimizer()`, do not like that. Thus,
+ # if the intermediate is used further down in the same dataflow,
+ # we consider that the maps cannot be fused. But we only
+ # do this in strict dataflow mode.
+ if self.strict_dataflow:
+ if self._is_data_accessed_downstream(
+ data=intermediate_node.data,
+ graph=state,
+ begin=intermediate_node, # the node itself is ignored.
+ ):
+ return None
shared_outputs.add(out_edge)
else:
# The intermediate can be removed, as it is not used anywhere else.
@@ -669,8 +662,8 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None:
output_partition = self.partition_first_outputs(
state=graph,
sdfg=sdfg,
- map_exit_1=map_exit_1,
- map_entry_2=map_entry_2,
+ first_map_exit=map_exit_1,
+ second_map_entry=map_entry_2,
)
assert output_partition is not None # Make MyPy happy.
pure_outputs, exclusive_outputs, shared_outputs = output_partition
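To see the overall effect end to end, one can run DaCe's built-in map fusion (used here as a stand-in; the gt4py variant refactored above follows the same partition idea) on the toy `two_maps` SDFG from the earlier sketch:

```python
from dace.transformation.dataflow import MapFusion

# `sdfg` is the SDFG built in the earlier `two_maps` sketch.
applied = sdfg.apply_transformations_repeated(MapFusion)
print(f"MapFusion applied {applied} time(s)")
# The exclusive intermediate `tmp` disappears; a shared intermediate would
# instead be recreated as an output of the fused map.
```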