Skip to content

Commit af515ad

Browse files
authored
Migrate Tutorial.IdModelReshapeAnalysis to direct bindings (#5190)
### Create `idm` submodule in direct bindings for `IdModel`, `DisjointSets`, and `ValGraph`. * Mapped `DisjointSets::strictAreMapped` to python `strict_are_mapped` * Mapped IdModel constructor and `IdModel::maybeBuildGraph` * Mapped `ValGraph::disjointValSets` and `ValGraph::mapVals` * Add `IdMappingMode` enum PR Stack: * #5187 * #5188 * #5189 * #5190 **<<< This PR.**
1 parent d48bbd6 commit af515ad

File tree

10 files changed

+244
-4
lines changed

10 files changed

+244
-4
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,7 @@ if(BUILD_PYTHON)
783783
${NVFUSER_PYTHON_DIRECT_BINDINGS}/cutlass.cpp
784784
${NVFUSER_PYTHON_DIRECT_BINDINGS}/runtime.cpp
785785
${NVFUSER_PYTHON_DIRECT_BINDINGS}/schedule.cpp
786+
${NVFUSER_PYTHON_DIRECT_BINDINGS}/id_model.cpp
786787
${NVFUSER_PYTHON_DIRECT_BINDINGS}/direct_utils.cpp
787788
${NVFUSER_PYTHON_DIRECT_BINDINGS}/python_translate.cpp
788789
)

csrc/disjoint_set.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ class VectorOfUniqueEntries {
293293
//! DisjointSet::*AreMapped(a,b) checks if a and b belong to the same disjoint
294294
//! set
295295
template <typename T, typename Hash = std::hash<T>>
296-
class DisjointSets {
296+
class NVF_API DisjointSets {
297297
public:
298298
using DisjointSet = std::shared_ptr<VectorOfUniqueEntries<T, Hash>>;
299299
using DisjointSetMap = std::unordered_map<T, DisjointSet, Hash>;

csrc/id_model/id_model.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ StatefulInliningInfo buildStatefulInliningInfo(
104104
// IdMappingMode::LOOP
105105
// Subgraph of the permissive graph. Maps only CA and their
106106
// dependent domains.
107-
class IdModel : public PolymorphicBase {
107+
class NVF_API IdModel : public PolymorphicBase {
108108
public:
109109
// Sometimes fusion inputs or outputs are disconnected from expressions, in
110110
// those cases we still may want to send in some additional tensor views from

csrc/val_graph.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ namespace nvfuser {
5353
// only tested with IterDomain. Some of the routines might need to be
5454
// extended for other Val types.
5555

56-
class ValGraph {
56+
class NVF_API ValGraph {
5757
public:
5858
ValGraph() = default;
5959

python/python_direct/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ void initNvFuserPythonBindings(PyObject* module) {
2424
bindOperations(nvfuser);
2525
bindScheduleOperators(nvfuser);
2626
bindMultiDevice(nvfuser);
27+
bindIdModel(nvfuser);
2728
nvfuser.def(
2829
"translate_fusion",
2930
&translateFusion,

python/python_direct/bindings.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ void bindScheduleOperators(py::module& nvfuser);
3939
// Add bindings for MultiDevice features
4040
void bindMultiDevice(py::module& nvfuser);
4141

42+
// Add bindings for IdModel and ValGraph
43+
void bindIdModel(py::module& nvfuser);
44+
4245
// Translate a CPP Fusion to a bindings python function
4346
std::string translateFusion(Fusion* f);
4447

python/python_direct/enum.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@ void bindEnums(py::module& nvfuser) {
6767
.value("transpose", SchedulerType::Transpose)
6868
.value("expr_eval", SchedulerType::ExprEval)
6969
.value("resize", SchedulerType::Resize);
70+
71+
py::enum_<IdMappingMode>(nvfuser, "IdMappingMode")
72+
.value("exact", IdMappingMode::EXACT)
73+
.value("almost_exact", IdMappingMode::ALMOSTEXACT)
74+
.value("broadcast", IdMappingMode::BROADCAST)
75+
.value("permissive", IdMappingMode::PERMISSIVE)
76+
.value("loop", IdMappingMode::LOOP);
7077
}
7178

7279
} // namespace nvfuser::python

python/python_direct/id_model.cpp

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
// clang-format off
2+
/*
3+
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
4+
* All rights reserved.
5+
* SPDX-License-Identifier: BSD-3-Clause
6+
*/
7+
// clang-format on
8+
#include <bindings.h>
9+
#include <id_model/id_model.h>
10+
#include <val_graph.h>
11+
12+
namespace nvfuser::python {
13+
14+
namespace {
15+
16+
void bindIdModelClass(py::module_& idm) {
17+
py::class_<IdModel, std::unique_ptr<IdModel>> id_model(idm, "IdModel");
18+
id_model.def(
19+
py::init([](Fusion* fusion,
20+
bool build_graphs,
21+
bool allow_self_mapping,
22+
bool validate) {
23+
return std::make_unique<IdModel>(
24+
fusion, build_graphs, allow_self_mapping, validate);
25+
}),
26+
py::arg("fusion"),
27+
py::arg("build_graphs") = false,
28+
py::arg("allow_self_mapping") = true,
29+
py::arg("validate") = false,
30+
R"(
31+
Create a new IdModel for the given fusion.
32+
33+
Parameters
34+
----------
35+
fusion : Fusion
36+
The fusion to create the IdModel for
37+
build_graphs : bool
38+
Whether to build graphs
39+
allow_self_mapping : bool
40+
Whether to allow self mapping
41+
validate : bool
42+
Whether to validate graphs
43+
44+
Returns
45+
-------
46+
IdModel
47+
The created IdModel
48+
)");
49+
id_model.def(
50+
"__str__",
51+
&IdModel::toString,
52+
R"(
53+
Returns the string representation of the IdModel.
54+
)");
55+
id_model.def(
56+
"maybe_build_graph",
57+
&IdModel::maybeBuildGraph,
58+
py::arg("mode"),
59+
py::return_value_policy::reference,
60+
R"(
61+
Build a graph if not already built.
62+
Dependent graphs are also built if not yet done.
63+
64+
Parameters
65+
----------
66+
mode : IdMappingMode
67+
The mode to build the graph for
68+
69+
Returns
70+
-------
71+
ValGraph
72+
The graph built
73+
)");
74+
}
75+
76+
void bindValGraph(py::module_& idm) {
77+
py::class_<ValGraph, std::unique_ptr<ValGraph>> val_graph(idm, "ValGraph");
78+
val_graph.def(
79+
"disjoint_val_sets",
80+
&ValGraph::disjointValSets,
81+
py::return_value_policy::reference,
82+
R"(
83+
Returns the disjoint val set.
84+
85+
Returns
86+
-------
87+
DisjointValSets
88+
The disjoint val set
89+
)");
90+
val_graph.def(
91+
"__str__",
92+
&ValGraph::toString,
93+
R"(
94+
Returns the string representation of the ValGraph.
95+
)");
96+
val_graph.def(
97+
"map_vals",
98+
&ValGraph::mapVals,
99+
py::arg("val0"),
100+
py::arg("val1"),
101+
R"(Maps the two values.
102+
103+
Parameters
104+
----------
105+
val0 : Val
106+
The first value to map
107+
val1 : Val
108+
The second value to map
109+
)");
110+
}
111+
112+
void bindDisjointSets(py::module_& id_model) {
113+
py::class_<DisjointSets<Val*>, std::unique_ptr<DisjointSets<Val*>>>
114+
disjoint_sets(id_model, "DisjointValSets");
115+
disjoint_sets.def(
116+
"__str__",
117+
&DisjointSets<Val*>::toString,
118+
R"(
119+
Returns the string representation of the DisjointSets.
120+
)");
121+
disjoint_sets.def(
122+
"strict_are_mapped",
123+
&DisjointSets<Val*>::strictAreMapped,
124+
py::arg("entry0"),
125+
py::arg("entry1"),
126+
R"(
127+
Returns if the two entries are strictly mapped.
128+
129+
Parameters
130+
----------
131+
entry0 : Val
132+
The first entry to check
133+
entry1 : Val
134+
The second entry to check
135+
136+
Returns
137+
-------
138+
bool
139+
True if the two entries are strictly mapped, False otherwise.
140+
)");
141+
}
142+
143+
} // namespace
144+
145+
void bindIdModel(py::module& nvfuser) {
146+
py::module_ idm = nvfuser.def_submodule(
147+
"idm", "This submodule contains all id model operators for NvFuser.");
148+
bindIdModelClass(idm);
149+
bindValGraph(idm);
150+
bindDisjointSets(idm);
151+
}
152+
153+
} // namespace nvfuser::python

python/python_direct/ir.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,17 @@ Returns
7070
-------
7171
Expr
7272
The definition of this expression.
73+
)")
74+
.def(
75+
"uses",
76+
&Val::uses,
77+
R"(
78+
Get the uses of this expression.
79+
80+
Returns
81+
-------
82+
Expr
83+
The uses of this expression.
7384
)");
7485

7586
// Expr

tests/python/direct/test_tutorial.py

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from nvfuser_direct import (
99
FusionDefinition,
10+
IdMappingMode,
1011
ParallelType,
1112
TensorView,
1213
Merge,
@@ -15,8 +16,9 @@
1516
SqueezeOp,
1617
ReshapeOp,
1718
)
19+
from nvfuser_direct import idm
1820

19-
verbose_ = False
21+
verbose_ = True
2022

2123

2224
def test_tutorial_memcpy():
@@ -508,3 +510,65 @@ def test_tutorial_reshape():
508510
# Note that all the transformations of squeeze_output are scheduling
509511
# transformations, thus it should not have a root domain
510512
assert not squeeze_output.has_root()
513+
514+
515+
def test_tutorial_id_model_reshape_analysis():
516+
"""
517+
Demonstration of using IdModel for analyzing equivalence of reshape ops
518+
"""
519+
with FusionDefinition() as fd:
520+
# Use the static reshape to avoid reshape concretization.
521+
tv0 = fd.define_tensor(shape=[10, 20])
522+
tv1 = fd.define_tensor(shape=[10, 20])
523+
524+
# While the reshape operations are equivalent, we do not know if the two
525+
# inputs are the same. There is not an operation allowing us to infer
526+
# equivalence. e.g., tv0 + tv1.
527+
tv2 = fd.ops.reshape(tv0, [20, 10])
528+
tv3 = fd.ops.reshape(tv1, [20, 10])
529+
fd.add_output(tv2)
530+
fd.add_output(tv3)
531+
532+
id_model = idm.IdModel(fd.fusion)
533+
exact_graph = id_model.maybe_build_graph(IdMappingMode.exact)
534+
535+
if verbose_:
536+
print(id_model)
537+
print(exact_graph)
538+
print(exact_graph.disjoint_val_sets())
539+
540+
# As mentioned above, we do not know any relationship between tv0 and tv1.
541+
# They should not be mapped in exact graph.
542+
assert len(tv0.get_logical_domain()) == len(tv1.get_logical_domain())
543+
for tv0_id, tv1_id in zip(tv0.get_logical_domain(), tv1.get_logical_domain()):
544+
assert not exact_graph.disjoint_val_sets().strict_are_mapped(tv0_id, tv1_id)
545+
546+
# Thus, the outputs of the reshape ops are not mapped either
547+
assert len(tv2.get_loop_domain()) == len(tv3.get_loop_domain())
548+
for tv2_id, tv3_id in zip(tv2.get_loop_domain(), tv3.get_loop_domain()):
549+
assert not exact_graph.disjoint_val_sets().strict_are_mapped(tv2_id, tv3_id)
550+
551+
# Now, suppose we can say the inputs are exactly mapped. We can manually
552+
# add mappings:
553+
for tv0_id, tv1_id in zip(tv0.get_logical_domain(), tv1.get_logical_domain()):
554+
exact_graph.map_vals(tv0_id, tv1_id)
555+
556+
# Now, tv2 and tv3 should be fully mapped, including their root,
557+
# intermediate and loop domains.
558+
559+
# Check the root domains.
560+
assert len(tv2.get_root_domain()) == len(tv3.get_root_domain())
561+
for tv2_id, tv3_id in zip(tv2.get_root_domain(), tv3.get_root_domain()):
562+
assert exact_graph.disjoint_val_sets().strict_are_mapped(tv2_id, tv3_id)
563+
564+
# The reshape consists of a merge and split. The output of the merge should
565+
# be mapped as well
566+
assert exact_graph.disjoint_val_sets().strict_are_mapped(
567+
tv2.get_root_domain()[0].uses()[0].output(0),
568+
tv3.get_root_domain()[0].uses()[0].output(0),
569+
)
570+
571+
# The next operation is split. Its outputs, which are the loop domains,
572+
# should be mapped too.
573+
for tv2_id, tv3_id in zip(tv2.get_loop_domain(), tv3.get_loop_domain()):
574+
assert exact_graph.disjoint_val_sets().strict_are_mapped(tv2_id, tv3_id)

0 commit comments

Comments
 (0)