Skip to content

Commit c7a247f

Browse files
duc0facebook-github-bot
authored andcommitted
nomnigraph - support subgraph visualization (pytorch#13795)
Summary: Pull Request resolved: pytorch#13795 Add ability for dot string generation for a single subgraph and python bindings (which is pretty useful for model exploration in Python) Restructure DotGenerator class a bit to make it easy to implement this feature Reviewed By: bwasti Differential Revision: D13010512 fbshipit-source-id: 825665438394b7e6968ab6da167b477af82a7b62
1 parent d7b95dd commit c7a247f

File tree

7 files changed

+178
-77
lines changed

7 files changed

+178
-77
lines changed

caffe2/core/nomnigraph/include/nomnigraph/Converters/Dot.h

Lines changed: 103 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef NOM_CONVERTERS_DOT_H
22
#define NOM_CONVERTERS_DOT_H
33

4+
#include "nomnigraph/Graph/Algorithms.h"
45
#include "nomnigraph/Graph/Graph.h"
56
#include "nomnigraph/Support/Casting.h"
67

@@ -10,57 +11,40 @@
1011
#include <queue>
1112
#include <sstream>
1213
#include <unordered_map>
14+
#include <vector>
1315

1416
namespace nom {
1517
namespace converters {
1618

17-
template <typename T, typename... U>
19+
template <typename GraphT>
1820
class DotGenerator {
1921
public:
2022
using NodePrinter = std::function<std::map<std::string, std::string>(
21-
typename nom::Graph<T, U...>::NodeRef)>;
23+
typename GraphT::NodeRef)>;
2224
using EdgePrinter = std::function<std::map<std::string, std::string>(
23-
typename nom::Graph<T, U...>::EdgeRef)>;
24-
using NodeRef = typename nom::Graph<T, U...>::NodeRef;
25+
typename GraphT::EdgeRef)>;
2526

2627
static std::map<std::string, std::string> defaultEdgePrinter(
27-
typename nom::Graph<T, U...>::EdgeRef e) {
28+
typename GraphT::EdgeRef) {
2829
std::map<std::string, std::string> labelMap;
2930
return labelMap;
3031
}
3132

32-
DotGenerator(typename nom::Graph<T, U...>* g) : g_(g) {}
33+
DotGenerator(NodePrinter nodePrinter, EdgePrinter edgePrinter)
34+
: nodePrinter_(nodePrinter), edgePrinter_(edgePrinter) {}
3335

34-
~DotGenerator() {}
35-
36-
/**
37-
* Converts given graph into DOT string
38-
* @param nodePrinter node attribute extractor
39-
* @param edgePrinter edge attribute extractor
40-
* @return DOT string representation of graph
41-
*/
42-
std::string convert(NodePrinter nodePrinter, EdgePrinter edgePrinter) {
36+
// Convert a graph (with optional subgraphs cluster) to dot.
37+
std::string convert(
38+
const typename GraphT::SubgraphType& sg,
39+
const std::vector<typename GraphT::SubgraphType*>& subgraphs) const {
4340
std::ostringstream output;
4441
output << "digraph G {\n\
4542
";
46-
for (const auto& node : g_->getMutableNodes()) {
47-
output << (uint64_t)node; // dot doesn't like hex
48-
output << "[";
49-
for (const auto& attrib : nodePrinter(node)) {
50-
output << attrib.first << "=\"" << attrib.second << "\",";
51-
}
52-
output << "];\n";
53-
for (const auto& edge : node->getOutEdges()) {
54-
output << (uint64_t)edge->tail() << " -> " << (uint64_t)edge->head();
55-
output << "[";
56-
for (const auto& attrib : edgePrinter(edge)) {
57-
output << attrib.first << "=\"" << attrib.second << "\",";
58-
}
59-
output << "];\n";
60-
}
43+
for (const auto& node : sg.getNodes()) {
44+
generateNode(node, sg, output);
6145
}
62-
for (auto i = 0; i < subgraphs_.size(); ++i) {
63-
const auto& subgraph = subgraphs_[i];
46+
for (auto i = 0; i < subgraphs.size(); ++i) {
47+
const auto& subgraph = subgraphs[i];
6448
output << "subgraph cluster" << i << " {\n";
6549
output << "style=dotted;\n";
6650
for (const auto& node : subgraph->getNodes()) {
@@ -73,6 +57,18 @@ class DotGenerator {
7357
return output.str();
7458
}
7559

60+
// Convert a subgraph to dot.
61+
std::string convert(const typename GraphT::SubgraphType& sg) const {
62+
std::ostringstream output;
63+
output << "digraph G {\n\
64+
";
65+
for (const auto& node : sg.getNodes()) {
66+
generateNode(node, sg, output);
67+
}
68+
output << "}";
69+
return output.str();
70+
}
71+
7672
/**
7773
* NOTE No subgraph support
7874
* Converts given graph into DOT string w/operator input-order preserved
@@ -84,19 +80,20 @@ class DotGenerator {
8480
* - Node: op_ptr[shape=record, label="{{<i0>*|<i1>*|...}|{op}|{<o0>*}"]
8581
* - Edge: <parent_node_ptr>:<ref>:s -> <this_node_ptr>:<ref>:n
8682
*/
87-
std::string convertStruct(NodePrinter nodePrinter, EdgePrinter edgePrinter) {
83+
std::string convertStruct(const typename GraphT::SubgraphType& sg) const {
8884
std::ostringstream output;
8985
output << "digraph G {\n";
9086

9187
// Get input nodes (nodes w/o parents)
92-
std::unordered_map<NodeRef, int> nodeDepthMap; // Touched nodes for BFS
93-
std::queue<NodeRef> workList; // Init w/parentless nodes
94-
for (const auto& node : g_->getMutableNodes()) {
88+
std::unordered_map<typename GraphT::NodeRef, int>
89+
nodeDepthMap; // Touched nodes for BFS
90+
std::queue<typename GraphT::NodeRef> workList; // Init w/parentless nodes
91+
for (const auto& node : sg.getNodes()) {
9592
if (node->getInEdges().size() == 0 && node->getOutEdges().size() > 0) {
9693
// Add input node to dot string
9794
output << (uint64_t)node << "[shape=record, label=\"{{Data In}|{<"
9895
<< (uint64_t)node << ">";
99-
for (const auto& attr : nodePrinter(node)) {
96+
for (const auto& attr : nodePrinter_(node)) {
10097
output << attr.second;
10198
}
10299
output << "}}\"]\n";
@@ -108,7 +105,7 @@ class DotGenerator {
108105
}
109106

110107
// BFS to get operator nodes
111-
std::vector<NodeRef> ops;
108+
std::vector<typename GraphT::NodeRef> ops;
112109
while (workList.size() > 0) {
113110
const auto& node = workList.front();
114111
for (const auto& edge : node->getOutEdges()) {
@@ -127,25 +124,21 @@ class DotGenerator {
127124
}
128125

129126
// Finalize output
130-
output << getOperatorSubtreeDotString(ops, nodePrinter) << "}\n";
127+
output << getOperatorSubtreeDotString(ops) << "}\n";
131128
return output.str();
132129
}
133130

134-
void addSubgraph(const nom::Subgraph<T, U...>* s) {
135-
subgraphs_.emplace_back(s);
136-
}
137-
138131
private:
139-
typename nom::Graph<T, U...>* g_;
140-
typename std::vector<const nom::Subgraph<T, U...>*> subgraphs_;
132+
NodePrinter nodePrinter_;
133+
EdgePrinter edgePrinter_;
141134

142135
/**
143136
* Get DOT string record of given operator and DOT string of its input edges
144137
* @param op operator to parse
145138
* @param nodePrinter node attribute extractor
146139
* @return '\n' sep string of operator & input edges
147140
*/
148-
std::string getOperatorDotString(NodeRef op, NodePrinter nodePrinter) {
141+
std::string getOperatorDotString(typename GraphT::NodeRef op) const {
149142
std::ostringstream output;
150143
std::ostringstream record; // Operator node record
151144
record << (uint64_t)op << "[shape=record, label=\"{{";
@@ -167,15 +160,15 @@ class DotGenerator {
167160

168161
// Add input to operator record
169162
record << sep << "<" << (uint64_t)input << ">";
170-
for (const auto& attr : nodePrinter(input)) {
163+
for (const auto& attr : nodePrinter_(input)) {
171164
record << attr.second;
172165
}
173166
sep = "|";
174167
}
175168

176169
// Extract operator name
177170
record << "}|{";
178-
for (const auto& attr : nodePrinter(op)) {
171+
for (const auto& attr : nodePrinter_(op)) {
179172
record << attr.second;
180173
}
181174
record << "}|{";
@@ -185,7 +178,7 @@ class DotGenerator {
185178
for (const auto& edge : op->getOutEdges()) {
186179
const auto& child = edge->head();
187180
record << sep << "<" << (uint64_t)child << ">";
188-
for (const auto& attr : nodePrinter(child)) {
181+
for (const auto& attr : nodePrinter_(child)) {
189182
record << attr.second;
190183
}
191184
sep = "|";
@@ -203,48 +196,81 @@ class DotGenerator {
203196
* @return DOT string that renders operators subgraph
204197
*/
205198
std::string getOperatorSubtreeDotString(
206-
std::vector<NodeRef> ops,
207-
NodePrinter nodePrinter) {
199+
std::vector<typename GraphT::NodeRef> ops) const {
208200
std::ostringstream output;
209201
for (const auto& op : ops) {
210-
output << getOperatorDotString(op, nodePrinter);
202+
output << getOperatorDotString(op);
211203
}
212204
return output.str();
213205
}
206+
207+
// Generate dot string for a node.
208+
void generateNode(
209+
typename GraphT::NodeRef node,
210+
const typename GraphT::SubgraphType& sg,
211+
std::ostringstream& output) const {
212+
output << (uint64_t)node; // dot doesn't like hex
213+
output << "[";
214+
for (const auto& attrib : nodePrinter_(node)) {
215+
output << attrib.first << "=\"" << attrib.second << "\",";
216+
}
217+
output << "];\n";
218+
for (const auto& edge : node->getOutEdges()) {
219+
if (!sg.hasEdge(edge)) {
220+
continue;
221+
}
222+
output << (uint64_t)edge->tail() << " -> " << (uint64_t)edge->head();
223+
output << "[";
224+
for (const auto& attrib : edgePrinter_(edge)) {
225+
output << attrib.first << "=\"" << attrib.second << "\",";
226+
}
227+
output << "];\n";
228+
}
229+
}
214230
};
215231

216-
template <typename T, typename... U>
232+
// Convert a graph to dot string.
233+
template <typename GraphT>
217234
std::string convertToDotString(
218-
nom::Graph<T, U...>* g,
219-
typename DotGenerator<T, U...>::NodePrinter nodePrinter,
220-
typename DotGenerator<T, U...>::EdgePrinter edgePrinter =
221-
DotGenerator<T, U...>::defaultEdgePrinter) {
222-
auto d = DotGenerator<T, U...>(g);
223-
return d.convert(nodePrinter, edgePrinter);
235+
GraphT* g,
236+
typename DotGenerator<GraphT>::NodePrinter nodePrinter,
237+
typename DotGenerator<GraphT>::EdgePrinter edgePrinter =
238+
DotGenerator<GraphT>::defaultEdgePrinter) {
239+
auto d = DotGenerator<GraphT>(nodePrinter, edgePrinter);
240+
return d.convert(algorithm::createSubgraph(g), {});
224241
}
225242

226-
template <typename T, typename... U>
243+
// Convert a graph to dot string and annotate subgraph clusters.
244+
template <typename GraphT>
227245
std::string convertToDotString(
228-
nom::Graph<T, U...>* g,
229-
const std::vector<nom::Subgraph<T, U...>>& subgraphs,
230-
typename DotGenerator<T, U...>::NodePrinter nodePrinter,
231-
typename DotGenerator<T, U...>::EdgePrinter edgePrinter =
232-
DotGenerator<T, U...>::defaultEdgePrinter) {
233-
auto d = DotGenerator<T, U...>(g);
234-
for (const auto& subgraph : subgraphs) {
235-
d.addSubgraph(&subgraph);
236-
}
237-
return d.convert(nodePrinter, edgePrinter);
246+
GraphT* g,
247+
const std::vector<typename GraphT::SubgraphType*>& subgraphs,
248+
typename DotGenerator<GraphT>::NodePrinter nodePrinter,
249+
typename DotGenerator<GraphT>::EdgePrinter edgePrinter =
250+
DotGenerator<GraphT>::defaultEdgePrinter) {
251+
auto d = DotGenerator<GraphT>(nodePrinter, edgePrinter);
252+
return d.convert(algorithm::createSubgraph(g), subgraphs);
253+
}
254+
255+
// Convert a subgraph to dot string.
256+
template <typename GraphT>
257+
std::string convertToDotString(
258+
const typename GraphT::SubgraphType& sg,
259+
typename DotGenerator<GraphT>::NodePrinter nodePrinter,
260+
typename DotGenerator<GraphT>::EdgePrinter edgePrinter =
261+
DotGenerator<GraphT>::defaultEdgePrinter) {
262+
auto d = DotGenerator<GraphT>(nodePrinter, edgePrinter);
263+
return d.convert(sg);
238264
}
239265

240-
template <typename T, typename... U>
266+
template <typename GraphT>
241267
std::string convertToDotRecordString(
242-
nom::Graph<T, U...>* g,
243-
typename DotGenerator<T, U...>::NodePrinter nodePrinter,
244-
typename DotGenerator<T, U...>::EdgePrinter edgePrinter =
245-
DotGenerator<T, U...>::defaultEdgePrinter) {
246-
auto d = DotGenerator<T, U...>(g);
247-
return d.convertStruct(nodePrinter, edgePrinter);
268+
GraphT* g,
269+
typename DotGenerator<GraphT>::NodePrinter nodePrinter,
270+
typename DotGenerator<GraphT>::EdgePrinter edgePrinter =
271+
DotGenerator<GraphT>::defaultEdgePrinter) {
272+
auto d = DotGenerator<GraphT>(nodePrinter, edgePrinter);
273+
return d.convertStruct(algorithm::createSubgraph(g));
248274
}
249275

250276
} // namespace converters

caffe2/core/nomnigraph/include/nomnigraph/Graph/Algorithms.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,17 @@ void induceEdges(SubgraphType* sg) {
196196
}
197197
}
198198

199+
/// \brief Create subgraph object from graph.
200+
template <typename GraphType>
201+
typename GraphType::SubgraphType createSubgraph(GraphType* g) {
202+
typename GraphType::SubgraphType subgraph;
203+
for (auto& node : g->getMutableNodes()) {
204+
subgraph.addNode(node);
205+
}
206+
induceEdges(&subgraph);
207+
return subgraph;
208+
}
209+
199210
} // namespace algorithm
200211
} // namespace nom
201212

caffe2/core/nomnigraph/tests/GraphTest.cc

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,44 @@ TEST(Basic, MoveSubgraph) {
176176
EXPECT_EQ(g.getMutableEdges().size(), 0);
177177
EXPECT_EQ(g2.getMutableEdges().size(), 1);
178178
}
179+
180+
TEST(Basic, DotGenerator) {
181+
TestGraph g;
182+
auto n1 = createTestNode(g);
183+
auto n2 = createTestNode(g);
184+
auto n3 = createTestNode(g);
185+
auto e12 = g.createEdge(n1, n2);
186+
g.createEdge(n1, n3);
187+
188+
std::string dot = nom::converters::convertToDotString(&g, TestNodePrinter);
189+
190+
// sanity check
191+
std::string prefix = "digraph G";
192+
// Full string comparison of the output is not stable because the dot
193+
// string includes node pointer address as node id. We should switch to
194+
// comparing full output once dot generator no longer uses addresses.
195+
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
196+
197+
TestGraph::SubgraphType sg;
198+
sg.addNode(n1);
199+
sg.addNode(n2);
200+
sg.addEdge(e12);
201+
202+
// Convert to dot with subgraph clusters.
203+
dot = nom::converters::convertToDotString<TestGraph>(
204+
&g, {&sg}, TestNodePrinter);
205+
206+
// sanity check
207+
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
208+
209+
// Convert a single subgraph to dot.
210+
dot = nom::converters::convertToDotString<TestGraph>(sg, TestNodePrinter);
211+
212+
// sanity check
213+
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
214+
215+
dot =
216+
nom::converters::convertToDotRecordString<TestGraph>(&g, TestNodePrinter);
217+
// sanity check
218+
EXPECT_TRUE(dot.compare(0, prefix.length(), prefix) == 0);
219+
}

caffe2/core/nomnigraph/tests/test_util.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,10 @@ std::map<std::string, std::string> NNPrinter(typename nom::repr::NNGraph::NodeRe
120120
nom::Graph<TestClass>::NodeRef createTestNode(nom::Graph<TestClass>& g) {
121121
return g.createNode(TestClass());
122122
}
123+
124+
std::map<std::string, std::string> TestNodePrinter(
125+
nom::Graph<TestClass>::NodeRef /* unused */) {
126+
std::map<std::string, std::string> labelMap;
127+
labelMap["label"] = "Node";
128+
return labelMap;
129+
}

caffe2/core/nomnigraph/tests/test_util.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,4 +114,7 @@ std::map<std::string, std::string> NNPrinter(typename nom::repr::NNGraph::NodeRe
114114

115115
CAFFE2_API nom::Graph<TestClass>::NodeRef createTestNode(
116116
nom::Graph<TestClass>& g);
117+
118+
CAFFE2_API std::map<std::string, std::string> TestNodePrinter(
119+
nom::Graph<TestClass>::NodeRef node);
117120
#endif // NOM_TESTS_TEST_UTIL_H

0 commit comments

Comments
 (0)