@@ -125,8 +125,8 @@ def can_be_applied(
         output_partition = self.partition_first_outputs(
             state=graph,
             sdfg=sdfg,
-            map_exit_1=map_exit_1,
-            map_entry_2=map_entry_2,
+            first_map_exit=map_exit_1,
+            second_map_entry=map_entry_2,
         )
         if output_partition is None:
             return False
@@ -375,8 +375,8 @@ def partition_first_outputs(
         self,
         state: SDFGState,
         sdfg: SDFG,
-        map_exit_1: nodes.MapExit,
-        map_entry_2: nodes.MapEntry,
+        first_map_exit: nodes.MapExit,
+        second_map_entry: nodes.MapEntry,
     ) -> Union[
         Tuple[
             Set[graph.MultiConnectorEdge[dace.Memlet]],
@@ -385,19 +385,19 @@ def partition_first_outputs(
         ],
         None,
     ]:
-        """Partition the output edges of `map_exit_1` for serial map fusion.
+        """Partition the output edges of `first_map_exit` for serial map fusion.

         The output edges of the first map are partitioned into three distinct sets,
         defined as follows:
-        - Pure Output Set `\mathbb{P}`:
+        * Pure Output Set `\mathbb{P}`:
             These edges exits the first map and does not enter the second map. These
             outputs will be simply be moved to the output of the second map.
-        - Exclusive Intermediate Set `\mathbb{E}`:
+        * Exclusive Intermediate Set `\mathbb{E}`:
             Edges in this set leaves the first map exit, enters an access node, from
             where a Memlet then leads immediately to the second map. The memory
             referenced by this access node is not used anywhere else, thus it can
             be removed.
-        - Shared Intermediate Set `\mathbb{S}`:
+        * Shared Intermediate Set `\mathbb{S}`:
             These edges are very similar to the one in `\mathbb{E}` except that they
             are used somewhere else, thus they can not be removed and have to be
             recreated as output of the second map.
@@ -406,17 +406,14 @@ def partition_first_outputs(
         output can be added to either intermediate set and might fail to compute
         the partition, even if it would exist.

-        Returns:
-            If such a decomposition exists the function will return the three sets
-            mentioned above in the same order.
-            In case the decomposition does not exist, i.e. the maps can not be fused
-            the function returns `None`.
+        :return: If such a decomposition exists the function will return the three sets
+            mentioned above in the same order. In case the decomposition does not exist,
+            i.e. the maps can not be fused the function returns `None`.

-        Args:
-            state: The in which the two maps are located.
-            sdfg: The full SDFG in whcih we operate.
-            map_exit_1: The exit node of the first map.
-            map_entry_2: The entry node of the second map.
+        :param state: The in which the two maps are located.
+        :param sdfg: The full SDFG in whcih we operate.
+        :param first_map_exit: The exit node of the first map.
+        :param second_map_entry: The entry node of the second map.
         """
         # The three outputs set.
         pure_outputs: Set[graph.MultiConnectorEdge[dace.Memlet]] = set()
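To make the three sets in the docstring concrete, here is a minimal sketch (array names and shapes are invented, only the public DaCe program syntax is assumed) of a state in which each set occurs: `tmp` is a transient read only by the second map, so its edge is an exclusive intermediate; `C` leaves the first map and never reaches the second one, so its edge is a pure output; if `tmp` were additionally read further down, its edge would land in the shared set instead.

```python
import dace

@dace.program
def two_maps(A: dace.float64[10], B: dace.float64[10], C: dace.float64[10]):
    tmp = dace.define_local([10], dace.float64)
    for i in dace.map[0:10]:  # first map: writes the intermediate `tmp` and the pure output `C`
        tmp[i] = A[i] + 1.0
        C[i] = A[i] * 3.0
    for i in dace.map[0:10]:  # second map: the only consumer of `tmp`
        B[i] = tmp[i] * 2.0

sdfg = two_maps.to_sdfg()
```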
@@ -425,28 +422,17 @@ def partition_first_outputs(

         # Compute the renaming that for translating the parameter of the _second_
         # map to the ones used by the first map.
-        repl_dict: Dict[str, str] = self.find_parameter_remapping(  # type: ignore[assignment]
-            first_map=map_exit_1.map,
-            second_map=map_entry_2.map,
+        param_repl: Dict[str, str] = self.find_parameter_remapping(  # type: ignore[assignment]
+            first_map=first_map_exit.map,
+            second_map=second_map_entry.map,
         )
-        assert repl_dict is not None
+        assert param_repl is not None

         # Set of intermediate nodes that we have already processed.
         processed_inter_nodes: Set[nodes.Node] = set()

-        # These are the data that is written to multiple times in _this_ state.
-        # If a data is written to multiple time in a state, it could be
-        # classified as shared. However, it might happen that the node has zero
-        # degree. This is not a problem as the maps also induced a before-after
-        # relationship. But some DaCe transformations do not catch this.
-        # Thus we will never modify such intermediate nodes and fail instead.
-        if self.strict_dataflow:
-            multi_write_data: Set[str] = self._compute_multi_write_data(state, sdfg)
-        else:
-            multi_write_data = set()
-
         # Now scan all output edges of the first exit and classify them
-        for out_edge in state.out_edges(map_exit_1):
+        for out_edge in state.out_edges(first_map_exit):
             intermediate_node: nodes.Node = out_edge.dst

             # We already processed the node, this should indicate that we should
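The remapping computed by `find_parameter_remapping` is a plain `dict` from the second map's iteration variables to the first map's. A hedged illustration of how such a mapping rewrites a subset (the parameter names and the subset are invented; `symbolic.safe_replace` is the same call applied to the consumer subsets further down in this diff):

```python
from dace import subsets, symbolic

param_repl = {"j": "i"}  # hypothetical: the second map iterates over `j`, the first over `i`

consumer_subset = subsets.Range.from_string("j, 0:10")
# Renames `j` to `i` without intermediate symbol clashes; afterwards the subset is
# expressed in the first map's iteration variable.
symbolic.safe_replace(mapping=param_repl, replace_callback=consumer_subset.replace)
```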
@@ -469,7 +455,7 @@ def partition_first_outputs(
             if not self.is_node_reachable_from(
                 graph=state,
                 begin=intermediate_node,
-                end=map_entry_2,
+                end=second_map_entry,
             ):
                 pure_outputs.add(out_edge)
                 continue
@@ -479,6 +465,12 @@ def partition_first_outputs(
             # cases, as handling them is essentially rerouting an edge, whereas
             # handling intermediate nodes is much more complicated.

+            # Empty Memlets are only allowed if they are in `\mathbb{P}`, which
+            # is also the only place they really make sense (for a map exit).
+            # Thus if we now found an empty Memlet we reject it.
+            if out_edge.data.is_empty():
+                return None
+
             # For us an intermediate node must always be an access node, because
             # everything else we do not know how to handle. It is important that
             # we do not test for non transient data here, because they can be
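For reference, an empty Memlet in DaCe transports no data and only encodes ordering, which is why it is only tolerated on pure outputs here; a quick illustration:

```python
import dace

print(dace.Memlet().is_empty())            # True: carries no data, ordering only
print(dace.Memlet("A[0:10]").is_empty())   # False: reads/writes a subset of `A`
```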
@@ -488,30 +480,14 @@ def partition_first_outputs(
             if self.is_view(intermediate_node, sdfg):
                 return None

-            # Checks if the intermediate node refers to data that is accessed by
-            # _other_ access nodes in _this_ state. If this is the case then never
-            # touch this intermediate node.
-            # TODO(phimuell): Technically it would be enough to turn the node into
-            # a shared output node, because this will still fulfil the dependencies.
-            # However, some DaCe transformation can not handle this properly, so we
-            # are _forced_ to reject this node.
-            if intermediate_node.data in multi_write_data:
-                return None
-
-            # Empty Memlets are only allowed if they are in `\mathbb{P}`, which
-            # is also the only place they really make sense (for a map exit).
-            # Thus if we now found an empty Memlet we reject it.
-            if out_edge.data.is_empty():
-                return None
-
             # It can happen that multiple edges converges at the `IN_` connector
             # of the first map exit, but there is only one edge leaving the exit.
             # It is complicate to handle this, so for now we ignore it.
             # TODO(phimuell): Handle this case properly.
             # To handle this we need to associate a consumer edge (the outgoing edges
             # of the second map) with exactly one producer.
             producer_edges: List[graph.MultiConnectorEdge[dace.Memlet]] = list(
-                state.in_edges_by_connector(map_exit_1, "IN_" + out_edge.src_conn[4:])
+                state.in_edges_by_connector(first_map_exit, "IN_" + out_edge.src_conn[4:])
             )
             if len(producer_edges) > 1:
                 return None
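The `[4:]` slicing above relies on DaCe's connector naming convention: data flowing through a map exit enters on an `IN_<name>` connector and leaves on the matching `OUT_<name>` connector, so stripping the prefix recovers the pairing (a trivial sketch with an invented connector name):

```python
out_conn = "OUT_tmp"                       # connector on which `out_edge` leaves the map exit
in_conn = "IN_" + out_conn[len("OUT_"):]   # same as out_conn[4:] -> "IN_tmp"
assert in_conn == "IN_tmp"
```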
@@ -520,7 +496,7 @@ def partition_first_outputs(
             # - The source of the producer can not be a view (we do not handle this)
             # - The edge shall also not be a reduction edge.
             # - Defined location to where they write.
-            # - No dynamic Memlets.
+            # - No dynamic Melets.
             # Furthermore, we will also extract the subsets, i.e. the location they
             # modify inside the intermediate array.
             # Since we do not allow for WCR, we do not check if the producer subsets intersects.
@@ -531,6 +507,7 @@ def partition_first_outputs(
                 ):
                     return None
                 if producer_edge.data.dynamic:
+                    # TODO(phimuell): Find out if this restriction could be lifted, but it is unlikely.
                     return None
                 if producer_edge.data.wcr is not None:
                     return None
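The two producer properties rejected here can be inspected directly on a Memlet: `wcr` holds a write-conflict-resolution lambda (a reduction-style write) and `dynamic` marks a Memlet whose volume is not statically known. A small sketch (invented data names, assuming the standard `dace.Memlet` constructor):

```python
import dace

wcr_memlet = dace.Memlet("A[0]", wcr="lambda old, new: old + new")
dyn_memlet = dace.Memlet("A[0:10]", dynamic=True)

print(wcr_memlet.wcr is not None)  # True  -> such a producer is rejected
print(dyn_memlet.dynamic)          # True  -> such a producer is rejected as well
```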
@@ -562,9 +539,9 @@ def partition_first_outputs(
             for intermediate_node_out_edge in state.out_edges(intermediate_node):
                 # If the second map entry is not immediately reachable from the intermediate
                 # node, then ensure that there is not path that goes to it.
-                if intermediate_node_out_edge.dst is not map_entry_2:
+                if intermediate_node_out_edge.dst is not second_map_entry:
                     if self.is_node_reachable_from(
-                        graph=state, begin=intermediate_node_out_edge.dst, end=map_entry_2
+                        graph=state, begin=intermediate_node_out_edge.dst, end=second_map_entry
                     ):
                         return None
                     continue
@@ -583,27 +560,28 @@ def partition_first_outputs(
                 # Now we look at all edges that leave the second map entry, i.e. the
                 # edges that feeds the consumer and define what is read inside the map.
                 # We do not check them, but collect them and inspect them.
-                # NOTE: The subset still uses the old iteration variables.
+                # NOTE1: The subset still uses the old iteration variables.
+                # NOTE2: In case of consumer Memlet we explicitly allow dynamic Memlets.
+                #   This is different compared to the producer Memlet. The reason is
+                #   because in a consumer the data is conditionally read, so the data
+                #   has to exists anyway.
                 for inner_consumer_edge in state.out_edges_by_connector(
-                    map_entry_2, "OUT_" + intermediate_node_out_edge.dst_conn[3:]
+                    second_map_entry, "OUT_" + intermediate_node_out_edge.dst_conn[3:]
                 ):
                     if inner_consumer_edge.data.src_subset is None:
                         return None
-                    if inner_consumer_edge.data.dynamic:
-                        # TODO(phimuell): Is this restriction necessary, I am not sure.
-                        return None
                     consumer_subsets.append(inner_consumer_edge.data.src_subset)
             assert (
                 found_second_map
             ), f"Found '{intermediate_node}' which looked like a pure node, but is not one."
             assert len(consumer_subsets) != 0

             # The consumer still uses the original symbols of the second map, so we must rename them.
-            if repl_dict:
+            if param_repl:
                 consumer_subsets = copy.deepcopy(consumer_subsets)
                 for consumer_subset in consumer_subsets:
                     symbolic.safe_replace(
-                        mapping=repl_dict, replace_callback=consumer_subset.replace
+                        mapping=param_repl, replace_callback=consumer_subset.replace
                    )

             # Now we are checking if a single iteration of the first (top) map
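The check that starts in the last context line above (whether a single iteration of the first map produces everything a single iteration of the second map reads) amounts to a subset-covering test on the renamed subsets. A hedged sketch of what such a test looks like with DaCe subsets (the concrete ranges are invented; the method's own check may be more involved):

```python
from dace import subsets

producer_subset = subsets.Range.from_string("i, 0:10")  # what one first-map iteration writes
consumer_subset = subsets.Range.from_string("i, 2:5")   # what one second-map iteration reads

print(producer_subset.covers(consumer_subset))  # True: the read is contained in the write
```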
@@ -623,6 +601,21 @@ def partition_first_outputs(
             # output of the second map.
             if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg):
                 # The intermediate data is used somewhere else, either in this or another state.
+                # NOTE: If the intermediate is shared, then we will turn it into a
+                #   sink node attached to the combined map exit. Technically this
+                #   should be enough, even if the same data appears again in the
+                #   dataflow down streams. However, some DaCe transformations,
+                #   I am looking at you `auto_optimizer()` do not like that. Thus
+                #   if the intermediate is used further down in the same datadflow,
+                #   then we consider that the maps can not be fused. But we only
+                #   do this in the strict data flow mode.
+                if self.strict_dataflow:
+                    if self._is_data_accessed_downstream(
+                        data=intermediate_node.data,
+                        graph=state,
+                        begin=intermediate_node,  # is ignored itself.
+                    ):
+                        return None
                 shared_outputs.add(out_edge)
             else:
                 # The intermediate can be removed, as it is not used anywhere else.
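`_is_data_accessed_downstream` itself is not part of this diff; a simplified sketch of the kind of forward traversal it suggests (walk the dataflow after the intermediate and look for another access node referring to the same data) could look as follows. The function name and exact semantics are assumptions for illustration only:

```python
import dace
from dace.sdfg import nodes


def data_accessed_downstream_sketch(state: dace.SDFGState, begin: nodes.Node, data: str) -> bool:
    """Return True if an AccessNode for `data` (other than `begin`) is reachable from `begin`."""
    seen = {begin}
    stack = [edge.dst for edge in state.out_edges(begin)]
    while stack:
        node = stack.pop()
        if node in seen:
            continue
        seen.add(node)
        if isinstance(node, nodes.AccessNode) and node.data == data:
            return True
        stack.extend(edge.dst for edge in state.out_edges(node))
    return False
```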
@@ -669,8 +662,8 @@ def apply(self, graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG) -> None
         output_partition = self.partition_first_outputs(
             state=graph,
             sdfg=sdfg,
-            map_exit_1=map_exit_1,
-            map_entry_2=map_entry_2,
+            first_map_exit=map_exit_1,
+            second_map_entry=map_entry_2,
         )
         assert output_partition is not None  # Make MyPy happy.
         pure_outputs, exclusive_outputs, shared_outputs = output_partition
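For context, this is how a serial map fusion pass is typically driven once `can_be_applied` and `apply` are in place; the snippet uses upstream DaCe's `MapFusion` class as a stand-in, whereas the class this diff belongs to may have a different name and module:

```python
import dace
from dace.transformation.dataflow import MapFusion  # stand-in; the diffed class may live elsewhere

@dace.program
def chained(A: dace.float64[10], B: dace.float64[10]):
    tmp = A + 1.0
    B[:] = tmp * 2.0

sdfg = chained.to_sdfg(simplify=True)
applied = sdfg.apply_transformations_repeated(MapFusion)
print(f"Fused {applied} map pair(s)")
```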