Skip to content

Commit d7b95dd

Browse files
duc0facebook-github-bot
authored andcommitted
nomnigraph - easy - expose hasProduce(NodeRef) to python (pytorch#14075)
Summary: Pull Request resolved: pytorch#14075 Expose hasProduce(NodeRef) to python Reviewed By: bwasti Differential Revision: D13092930 fbshipit-source-id: f1ec06e73e0f5f6a16ad0cbb7d2e3e499a861d8e
1 parent e7f5fce commit d7b95dd

File tree

2 files changed

+3
-0
lines changed

2 files changed

+3
-0
lines changed

caffe2/python/nomnigraph_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,12 @@ def test_traversal(self):
128128
nn = ng.NNModule(net)
129129
fc = nn.controlFlow[0]
130130
relu = nn.controlFlow[1]
131+
assert not fc.inputs[0].hasProducer()
131132
assert fc.inputs[0].name == "X"
132133
assert fc.inputs[1].name == "W"
133134
assert relu.outputs[0].name == "Z"
134135
assert relu.inputs[0].name == "Y"
136+
assert relu.inputs[0].hasProducer()
135137
assert relu.inputs[0].producer.name == "FC"
136138
assert fc.outputs[0].consumers[0].name == "Relu"
137139

caffe2/python/pybind_state_nomni.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,7 @@ void addNomnigraphMethods(pybind11::module& m) {
328328
"tensor", getTensor, py::return_value_policy::reference)
329329
.def("getInputs", getInputs, py::return_value_policy::reference)
330330
.def("getOutputs", getOutputs, py::return_value_policy::reference)
331+
.def("hasProducer", [](NNGraph::NodeRef n) { return nn::hasProducer(n); })
331332
.def("getProducer", getProducer, py::return_value_policy::reference)
332333
.def("getConsumers", getConsumers, py::return_value_policy::reference)
333334
.def_property_readonly(

0 commit comments

Comments
 (0)