@@ -7,6 +7,7 @@
 import torch
 from nvfuser_direct import (
     FusionDefinition,
+    IdMappingMode,
     ParallelType,
     TensorView,
     Merge,
@@ -15,8 +16,9 @@
     SqueezeOp,
     ReshapeOp,
 )
+from nvfuser_direct import idm
 
-verbose_ = False
+verbose_ = True
 
 
 def test_tutorial_memcpy():
@@ -508,3 +510,85 @@ def test_tutorial_reshape():
     # Note that all the transformations of squeeze_output are scheduling
     # transformations, thus it should not have a root domain
     assert not squeeze_output.has_root()
+
+
+def test_tutorial_id_model_reshape_analysis():
+    """
+    Demonstrate how to use IdModel to analyze the equivalence of reshape ops.
+    """
+    with FusionDefinition() as fd:
+        # Use static reshapes to avoid reshape concretization.
+        tv0 = fd.define_tensor(shape=[10, 20])
+        tv1 = fd.define_tensor(shape=[10, 20])
+
+        # While the two reshapes are equivalent, we do not know whether the
+        # two inputs are the same: no operation in the fusion, e.g. tv0 + tv1,
+        # allows us to infer their equivalence.
+        tv2 = fd.ops.reshape(tv0, [20, 10])
+        tv3 = fd.ops.reshape(tv1, [20, 10])
+        fd.add_output(tv2)
+        fd.add_output(tv3)
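+
+        # If the fusion did relate the two inputs, e.g. with
+        #   tv4 = fd.ops.add(tv0, tv1)
+        # the exact graph built below would be able to map them on its own.
+        # Since no such op exists here, we add the mapping manually below
+        # with map_vals().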
+
+    id_model = idm.IdModel(fd.fusion)
+    exact_graph = id_model.maybe_build_graph(IdMappingMode.exact)
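+
+    # The exact graph partitions the IterDomains of the fusion into disjoint
+    # sets. IterDomains that land in the same set are exactly mapped, i.e.,
+    # they are guaranteed to have the same extent.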
+
+    if verbose_:
+        print(id_model)
+        print(exact_graph)
+        print(exact_graph.disjoint_val_sets())
+
+    # As mentioned above, we do not know of any relationship between tv0 and
+    # tv1, so they should not be mapped in the exact graph.
+    assert len(tv0.get_logical_domain()) == len(tv1.get_logical_domain())
+    for tv0_id, tv1_id in zip(tv0.get_logical_domain(), tv1.get_logical_domain()):
+        assert not exact_graph.disjoint_val_sets().strict_are_mapped(tv0_id, tv1_id)
+
+    # Thus, the outputs of the reshape ops are not mapped either.
+    assert len(tv2.get_loop_domain()) == len(tv3.get_loop_domain())
+    for tv2_id, tv3_id in zip(tv2.get_loop_domain(), tv3.get_loop_domain()):
+        assert not exact_graph.disjoint_val_sets().strict_are_mapped(tv2_id, tv3_id)
+
+    # Now, suppose we know that the two inputs are exactly mapped. We can
+    # manually add the mappings:
+    for tv0_id, tv1_id in zip(tv0.get_logical_domain(), tv1.get_logical_domain()):
+        exact_graph.map_vals(tv0_id, tv1_id)
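+
+    # map_vals puts the given IterDomains into the same disjoint set, and the
+    # graph propagates the new mapping through the downstream expressions:
+    # mapped inputs transformed by matching merge/split exprs yield mapped
+    # outputs.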
+
+    # Now tv2 and tv3 should be fully mapped, including their root,
+    # intermediate, and loop domains.
+
+    # Check the root domains.
+    assert len(tv2.get_root_domain()) == len(tv3.get_root_domain())
+    for tv2_id, tv3_id in zip(tv2.get_root_domain(), tv3.get_root_domain()):
+        assert exact_graph.disjoint_val_sets().strict_are_mapped(tv2_id, tv3_id)
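+
+    # For tv2 and tv3, the root domain is the pre-reshape [10, 20] view,
+    # while the loop domain is the reshaped [20, 10] view.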
+
+    # Each reshape consists of a merge followed by a split. The outputs of
+    # the merges should be mapped as well.
+    assert exact_graph.disjoint_val_sets().strict_are_mapped(
+        tv2.get_root_domain()[0].uses()[0].output(0),
+        tv3.get_root_domain()[0].uses()[0].output(0),
+    )
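+    # Here, uses()[0] is the merge expression consuming the first root
+    # IterDomain, and output(0) is the merged IterDomain it produces.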
+
+    # The next operation is the split. Its outputs, which are the loop
+    # domains, should be mapped too.
+    for tv2_id, tv3_id in zip(tv2.get_loop_domain(), tv3.get_loop_domain()):
+        assert exact_graph.disjoint_val_sets().strict_are_mapped(tv2_id, tv3_id)