Skip to content

Commit 11c194c

Browse files
SherlockNoMadSherlock Huang
and
Sherlock Huang
authored
Minor fix for ComputeBroadcastBackwardAxesDynamic; Fix for GradientGraphBuilder logging (microsoft#5313)
Co-authored-by: Sherlock Huang <bahuang@OrtTrainingDev3.af05slrtruoetgaxwwjv5nsq5e.px.internal.cloudapp.net>
1 parent 24d8b1b commit 11c194c

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

orttraining/orttraining/core/framework/gradient_graph_builder.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
6363
reachable_nodes_ = ReverseBFS(y_nodes_);
6464

6565
std::string unreachable_nodes;
66-
66+
6767
// building x_nodes_
6868
for (const auto& name : x_node_arg_names) {
6969
const NodeArg* node_arg = graph->GetNodeArg(name);
@@ -89,7 +89,9 @@ GradientGraphBuilder::GradientGraphBuilder(Graph* graph,
8989
}
9090
}
9191
}
92-
LOGS(logger_, WARNING) << "Following nodes are unreachable for gradient back propagation: " << unreachable_nodes;
92+
if (!unreachable_nodes.empty()) {
93+
LOGS(logger_, WARNING) << "Following nodes are unreachable for gradient back propagation: " << unreachable_nodes;
94+
}
9395
}
9496

9597
NodeSet GradientGraphBuilder::ReverseBFS(const NodeSet& nodes) const {

orttraining/orttraining/core/graph/gradient_builder_base.cc

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,15 +105,12 @@ void ComputeBroadcastBackwardAxesDynamic(const ArgDef& a,
105105
const ArgDef* a_axes,
106106
const ArgDef* b_axes,
107107
std::vector<NodeDef>& output) {
108+
// Populate the node names explicitly in case a and b are the same tensor and
109+
// resulting in duplicated node name for Shape node. For example, y = x^2 is sometimes represented as Mul(x,x)
108110
output.push_back(
109-
NodeDef("Shape",
110-
{a},
111-
{a_shape}));
112-
111+
NodeDef("Shape", {a}, {a_shape}, NodeAttributes(), a_shape.name + "_lhs"));
113112
output.push_back(
114-
NodeDef("Shape",
115-
{b},
116-
{b_shape}));
113+
NodeDef("Shape", {b}, {b_shape}, NodeAttributes(), b_shape.name + "_rhs"));
117114

118115
ArgDef a_op = ArgDef(""), b_op = ArgDef("");
119116
if (a_axes)

0 commit comments

Comments
 (0)