
Commit 64b90dc

fix[next][dace]: fix map fusion and loop blocking (#1856)
Improve the optimization of the icon4py stencil `apply_diffusion_to_vn` by means of two changes:
- Ignore the check of the dynamic volume property on Memlets that prevented serial map fusion.
- Add support for NestedSDFG nodes in the loop blocking transformation.
1 parent c96e19e commit 64b90dc
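Note: the sketch below is illustrative only and shows how the two affected transformations would typically be driven through DaCe's generic pattern-matching machinery; the class names `LoopBlocking` and `MapFusionSerial` are assumptions inferred from the module names touched by this commit, not verified exports of the package.

```python
# Illustrative sketch only: the transformation class names are assumed from the
# module names in this commit and may differ from the actual exports.
import dace

from gt4py.next.program_processors.runners.dace.transformations.loop_blocking import (
    LoopBlocking,  # assumed name
)
from gt4py.next.program_processors.runners.dace.transformations.map_fusion_serial import (
    MapFusionSerial,  # assumed name
)


def fuse_and_block(sdfg: dace.SDFG) -> None:
    # DaCe repeatedly matches and applies the given transformations until no
    # further match is found.
    sdfg.apply_transformations_repeated([MapFusionSerial])
    sdfg.apply_transformations_repeated([LoopBlocking])
```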

File tree

3 files changed (+175, -99 lines)

src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py

Lines changed: 18 additions & 10 deletions
@@ -380,6 +380,24 @@ def _classify_node(
             ):
                 return False
 
+            # Test if the body of the Tasklet depends on the block variable.
+            if self.blocking_parameter in node_to_classify.free_symbols:
+                return False
+
+        elif isinstance(node_to_classify, dace.nodes.NestedSDFG):
+            # Same check as for Tasklets applies to the outputs of a nested SDFG node
+            if not all(
+                isinstance(out_edge.dst, dace_nodes.AccessNode)
+                for out_edge in state.out_edges(node_to_classify)
+                if not out_edge.data.is_empty()
+            ):
+                return False
+
+            # Additionally, test if the symbol mapping depends on the block variable.
+            for v in node_to_classify.symbol_mapping.values():
+                if self.blocking_parameter in v.free_symbols:
+                    return False
+
         elif isinstance(node_to_classify, dace_nodes.AccessNode):
             # AccessNodes need to have some special properties.
             node_desc: dace.data.Data = node_to_classify.desc(sdfg)
@@ -422,16 +440,6 @@ def _classify_node(
             if out_edge.dst is outer_exit:
                 return False
 
-        # Now we have ensured that the partition exists, thus we will now evaluate
-        # if the node is independent or dependent.
-
-        # Test if the body of the Tasklet depends on the block variable.
-        if (
-            isinstance(node_to_classify, dace_nodes.Tasklet)
-            and self.blocking_parameter in node_to_classify.free_symbols
-        ):
-            return False
-
         # Now we have to look at incoming edges individually.
         # We will inspect the subset of the Memlet to see if they depend on the
         # block variable. If this loop ends normally, then we classify the node
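For readers unfamiliar with the transformation, the new NestedSDFG branch above amounts to the following checks, shown here as a self-contained sketch; the helper name and the explicit `blocking_parameter` argument are illustrative, not part of the transformation's API.

```python
import dace
from dace.sdfg import nodes as dace_nodes


def nested_sdfg_is_independent(
    state: dace.SDFGState,
    nsdfg: dace_nodes.NestedSDFG,
    blocking_parameter: dace.symbol,
) -> bool:
    # All non-empty outputs of the nested SDFG must go straight into AccessNodes,
    # mirroring the requirement already imposed on Tasklet outputs.
    if not all(
        isinstance(out_edge.dst, dace_nodes.AccessNode)
        for out_edge in state.out_edges(nsdfg)
        if not out_edge.data.is_empty()
    ):
        return False
    # The symbol mapping into the nested SDFG must not depend on the blocking
    # variable, otherwise the node has to stay inside the blocked (inner) loop.
    return all(
        blocking_parameter not in v.free_symbols for v in nsdfg.symbol_mapping.values()
    )
```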

src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py

Lines changed: 57 additions & 64 deletions
@@ -125,8 +125,8 @@ def can_be_applied(
         output_partition = self.partition_first_outputs(
             state=graph,
             sdfg=sdfg,
-            map_exit_1=map_exit_1,
-            map_entry_2=map_entry_2,
+            first_map_exit=map_exit_1,
+            second_map_entry=map_entry_2,
         )
         if output_partition is None:
             return False
@@ -375,8 +375,8 @@ def partition_first_outputs(
         self,
         state: SDFGState,
         sdfg: SDFG,
-        map_exit_1: nodes.MapExit,
-        map_entry_2: nodes.MapEntry,
+        first_map_exit: nodes.MapExit,
+        second_map_entry: nodes.MapEntry,
     ) -> Union[
         Tuple[
             Set[graph.MultiConnectorEdge[dace.Memlet]],
@@ -385,19 +385,19 @@ def partition_first_outputs(
         ],
         None,
     ]:
-        """Partition the output edges of `map_exit_1` for serial map fusion.
+        """Partition the output edges of `first_map_exit` for serial map fusion.
 
         The output edges of the first map are partitioned into three distinct sets,
         defined as follows:
-        - Pure Output Set `\mathbb{P}`:
+        * Pure Output Set `\mathbb{P}`:
             These edges exits the first map and does not enter the second map. These
             outputs will be simply be moved to the output of the second map.
-        - Exclusive Intermediate Set `\mathbb{E}`:
+        * Exclusive Intermediate Set `\mathbb{E}`:
             Edges in this set leaves the first map exit, enters an access node, from
             where a Memlet then leads immediately to the second map. The memory
             referenced by this access node is not used anywhere else, thus it can
             be removed.
-        - Shared Intermediate Set `\mathbb{S}`:
+        * Shared Intermediate Set `\mathbb{S}`:
             These edges are very similar to the one in `\mathbb{E}` except that they
             are used somewhere else, thus they can not be removed and have to be
             recreated as output of the second map.
@@ -406,17 +406,14 @@ def partition_first_outputs(
         output can be added to either intermediate set and might fail to compute
         the partition, even if it would exist.
 
-        Returns:
-            If such a decomposition exists the function will return the three sets
-            mentioned above in the same order.
-            In case the decomposition does not exist, i.e. the maps can not be fused
-            the function returns `None`.
+        :return: If such a decomposition exists the function will return the three sets
+            mentioned above in the same order. In case the decomposition does not exist,
+            i.e. the maps can not be fused the function returns `None`.
 
-        Args:
-            state: The in which the two maps are located.
-            sdfg: The full SDFG in whcih we operate.
-            map_exit_1: The exit node of the first map.
-            map_entry_2: The entry node of the second map.
+        :param state: The in which the two maps are located.
+        :param sdfg: The full SDFG in whcih we operate.
+        :param first_map_exit: The exit node of the first map.
+        :param second_map_entry: The entry node of the second map.
         """
         # The three outputs set.
         pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set()
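To make the three-set partition described in the docstring more tangible, here is a deliberately simplified classifier based only on reachability and a rough "used elsewhere" proxy; the real `partition_first_outputs` additionally applies all the rejection rules in the hunks that follow, and every name below is illustrative rather than part of the class.

```python
from typing import Dict, Set

import dace
from dace.sdfg import nodes as dace_nodes


def classify_first_map_outputs(
    state: dace.SDFGState,
    sdfg: dace.SDFG,
    first_map_exit: dace_nodes.MapExit,
    second_map_entry: dace_nodes.MapEntry,
) -> Dict[str, Set]:
    """Toy partition of the first map's output edges into pure/exclusive/shared."""

    def reaches(begin: dace_nodes.Node, end: dace_nodes.Node) -> bool:
        # Plain DFS over the state's dataflow edges.
        stack, seen = [begin], set()
        while stack:
            node = stack.pop()
            if node is end:
                return True
            if node in seen:
                continue
            seen.add(node)
            stack.extend(edge.dst for edge in state.out_edges(node))
        return False

    partition: Dict[str, Set] = {"pure": set(), "exclusive": set(), "shared": set()}
    for out_edge in state.out_edges(first_map_exit):
        intermediate = out_edge.dst
        if not reaches(intermediate, second_map_entry):
            # Pure set: never feeds the second map, can simply be relocated.
            partition["pure"].add(out_edge)
        elif (
            isinstance(intermediate, dace_nodes.AccessNode)
            and intermediate.desc(sdfg).transient
            and state.out_degree(intermediate) == 1
        ):
            # Exclusive set (rough proxy): a transient with no other consumer.
            partition["exclusive"].add(out_edge)
        else:
            # Shared set: used elsewhere, must be recreated behind the fused map.
            partition["shared"].add(out_edge)
    return partition
```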
@@ -425,28 +422,17 @@ def partition_first_outputs(
 
         # Compute the renaming that for translating the parameter of the _second_
         # map to the ones used by the first map.
-        repl_dict: Dict[str, str] = self.find_parameter_remapping(  # type: ignore[assignment]
-            first_map=map_exit_1.map,
-            second_map=map_entry_2.map,
+        param_repl: Dict[str, str] = self.find_parameter_remapping(  # type: ignore[assignment]
+            first_map=first_map_exit.map,
+            second_map=second_map_entry.map,
         )
-        assert repl_dict is not None
+        assert param_repl is not None
 
         # Set of intermediate nodes that we have already processed.
         processed_inter_nodes: Set[nodes.Node] = set()
 
-        # These are the data that is written to multiple times in _this_ state.
-        # If a data is written to multiple time in a state, it could be
-        # classified as shared. However, it might happen that the node has zero
-        # degree. This is not a problem as the maps also induced a before-after
-        # relationship. But some DaCe transformations do not catch this.
-        # Thus we will never modify such intermediate nodes and fail instead.
-        if self.strict_dataflow:
-            multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg)
-        else:
-            multi_write_data = set()
-
         # Now scan all output edges of the first exit and classify them
-        for out_edge in state.out_edges(map_exit_1):
+        for out_edge in state.out_edges(first_map_exit):
             intermediate_node: nodes.Node = out_edge.dst
 
             # We already processed the node, this should indicate that we should
@@ -469,7 +455,7 @@ def partition_first_outputs(
             if not self.is_node_reachable_from(
                 graph=state,
                 begin=intermediate_node,
-                end=map_entry_2,
+                end=second_map_entry,
             ):
                 pure_outputs.add(out_edge)
                 continue
@@ -479,6 +465,12 @@ def partition_first_outputs(
             # cases, as handling them is essentially rerouting an edge, whereas
             # handling intermediate nodes is much more complicated.
 
+            # Empty Memlets are only allowed if they are in `\mathbb{P}`, which
+            # is also the only place they really make sense (for a map exit).
+            # Thus if we now found an empty Memlet we reject it.
+            if out_edge.data.is_empty():
+                return None
+
             # For us an intermediate node must always be an access node, because
             # everything else we do not know how to handle. It is important that
             # we do not test for non transient data here, because they can be
@@ -488,30 +480,14 @@ def partition_first_outputs(
             if self.is_view(intermediate_node, sdfg):
                 return None
 
-            # Checks if the intermediate node refers to data that is accessed by
-            # _other_ access nodes in _this_ state. If this is the case then never
-            # touch this intermediate node.
-            # TODO(phimuell): Technically it would be enough to turn the node into
-            #   a shared output node, because this will still fulfil the dependencies.
-            #   However, some DaCe transformation can not handle this properly, so we
-            #   are _forced_ to reject this node.
-            if intermediate_node.data in multi_write_data:
-                return None
-
-            # Empty Memlets are only allowed if they are in `\mathbb{P}`, which
-            # is also the only place they really make sense (for a map exit).
-            # Thus if we now found an empty Memlet we reject it.
-            if out_edge.data.is_empty():
-                return None
-
             # It can happen that multiple edges converges at the `IN_` connector
             # of the first map exit, but there is only one edge leaving the exit.
             # It is complicate to handle this, so for now we ignore it.
             # TODO(phimuell): Handle this case properly.
             # To handle this we need to associate a consumer edge (the outgoing edges
             # of the second map) with exactly one producer.
             producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list(
-                state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])
+                state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:])
             )
             if len(producer_edges) > 1:
                 return None
@@ -520,7 +496,7 @@ def partition_first_outputs(
             # - The source of the producer can not be a view (we do not handle this)
             # - The edge shall also not be a reduction edge.
             # - Defined location to where they write.
-            # - No dynamic Memlets.
+            # - No dynamic Melets.
             # Furthermore, we will also extract the subsets, i.e. the location they
             # modify inside the intermediate array.
             # Since we do not allow for WCR, we do not check if the producer subsets intersects.
@@ -531,6 +507,7 @@ def partition_first_outputs(
                 ):
                     return None
                 if producer_edge.data.dynamic:
+                    # TODO(phimuell): Find out if this restriction could be lifted, but it is unlikely.
                     return None
                 if producer_edge.data.wcr is not None:
                     return None
@@ -562,9 +539,9 @@ def partition_first_outputs(
             for intermediate_node_out_edge in state.out_edges(intermediate_node):
                 # If the second map entry is not immediately reachable from the intermediate
                 # node, then ensure that there is not path that goes to it.
-                if intermediate_node_out_edge.dst is not map_entry_2:
+                if intermediate_node_out_edge.dst is not second_map_entry:
                     if self.is_node_reachable_from(
-                        graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2
+                        graph=state, begin=intermediate_node_out_edge.dst, end=second_map_entry
                     ):
                         return None
                     continue
@@ -583,27 +560,28 @@ def partition_first_outputs(
                 # Now we look at all edges that leave the second map entry, i.e. the
                 # edges that feeds the consumer and define what is read inside the map.
                 # We do not check them, but collect them and inspect them.
-                # NOTE: The subset still uses the old iteration variables.
+                # NOTE1: The subset still uses the old iteration variables.
+                # NOTE2: In case of consumer Memlet we explicitly allow dynamic Memlets.
+                #   This is different compared to the producer Memlet. The reason is
+                #   because in a consumer the data is conditionally read, so the data
+                #   has to exists anyway.
                 for inner_consumer_edge in state.out_edges_by_connector(
-                    map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]
+                    second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:]
                 ):
                     if inner_consumer_edge.data.src_subset is None:
                         return None
-                    if inner_consumer_edge.data.dynamic:
-                        # TODO(phimuell): Is this restriction necessary, I am not sure.
-                        return None
                     consumer_subsets.append(inner_consumer_edge.data.src_subset)
             assert (
                 found_second_map
             ), f"Found '{intermediate_node}' which looked like a pure node, but is not one."
             assert len(consumer_subsets) != 0
 
             # The consumer still uses the original symbols of the second map, so we must rename them.
-            if repl_dict:
+            if param_repl:
                 consumer_subsets = copy.deepcopy(consumer_subsets)
                 for consumer_subset in consumer_subsets:
                     symbolic.safe_replace(
-                        mapping=repl_dict, replace_callback=consumer_subset.replace
+                        mapping=param_repl, replace_callback=consumer_subset.replace
                     )
 
             # Now we are checking if a single iteration of the first (top) map
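The producer/consumer asymmetry introduced above is the change that unblocks fusion for `apply_diffusion_to_vn`: dynamic producer Memlets are still rejected, because their volume is only an upper bound and parts of the intermediate could stay unwritten, whereas a dynamic consumer merely reads conditionally from data that has to exist anyway. A minimal construction of such a Memlet, assuming a plain DaCe installation:

```python
import dace

# A dynamic Memlet: the subset is the potential access window, but the number of
# elements actually accessed at runtime is data dependent (volume is an upper bound).
conditional_read = dace.Memlet("tmp[0:10]", dynamic=True)
assert conditional_read.dynamic
```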
@@ -623,6 +601,21 @@ def partition_first_outputs(
             # output of the second map.
             if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg):
                 # The intermediate data is used somewhere else, either in this or another state.
+                # NOTE: If the intermediate is shared, then we will turn it into a
+                #   sink node attached to the combined map exit. Technically this
+                #   should be enough, even if the same data appears again in the
+                #   dataflow down streams. However, some DaCe transformations,
+                #   I am looking at you `auto_optimizer()` do not like that. Thus
+                #   if the intermediate is used further down in the same datadflow,
+                #   then we consider that the maps can not be fused. But we only
+                #   do this in the strict data flow mode.
+                if self.strict_dataflow:
+                    if self._is_data_accessed_downstream(
+                        data=intermediate_node.data,
+                        graph=state,
+                        begin=intermediate_node,  # is ignored itself.
+                    ):
+                        return None
                 shared_outputs.add(out_edge)
             else:
                 # The intermediate can be removed, as it is not used anywhere else.
@@ -669,8 +662,8 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None:
         output_partition = self.partition_first_outputs(
             state=graph,
             sdfg=sdfg,
-            map_exit_1=map_exit_1,
-            map_entry_2=map_entry_2,
+            first_map_exit=map_exit_1,
+            second_map_entry=map_entry_2,
         )
         assert output_partition is not None  # Make MyPy happy.
         pure_outputs, exclusive_outputs, shared_outputs = output_partition
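The strict-dataflow guard added in the last hunk of `partition_first_outputs` relies on the transformation's internal helper `_is_data_accessed_downstream`. As a rough mental model only (a sketch, not the actual implementation), such a check walks the dataflow after the intermediate node and looks for another AccessNode that refers to the same data container:

```python
import dace
from dace.sdfg import nodes as dace_nodes


def data_accessed_downstream(
    graph: dace.SDFGState, data: str, begin: dace_nodes.Node
) -> bool:
    # DFS over everything strictly after `begin`; `begin` itself is skipped,
    # matching the "is ignored itself." comment in the diff above.
    stack = [edge.dst for edge in graph.out_edges(begin)]
    seen: set = set()
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if isinstance(node, dace_nodes.AccessNode) and node.data == data:
            return True
        stack.extend(edge.dst for edge in graph.out_edges(node))
    return False
```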
