Skip to content

Commit a38e4bf

Browse files
committed
Initial version of network graph generation
1 parent 82dd91c commit a38e4bf

File tree

2 files changed

+242
-0
lines changed

2 files changed

+242
-0
lines changed

inspect.lua

+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
require 'graph'
2+
require 'trepl'
3+
4+
models = dofile 'models.lua'
5+
6+
--net,input = models.resnet({dataset='cifar10',depth=20})
7+
--net,input = models.alexnet()
8+
--net,input = models.basic1()
9+
net,input = models.googlenet()
10+
11+
-- taken from http://www.graphviz.org/doc/info/colors.html
12+
local colorNames = {
13+
"aliceblue","antiquewhite","antiquewhite1","antiquewhite2","antiquewhite3",
14+
"antiquewhite4","aquamarine","aquamarine1","aquamarine2","aquamarine3",
15+
"aquamarine4","azure","azure1","azure2","azure3",
16+
"azure4","beige","bisque","bisque1","bisque2",
17+
"bisque3","bisque4","black","blanchedalmond","blue",
18+
"blue1","blue2","blue3","blue4","blueviolet",
19+
"brown","brown1","brown2","brown3","brown4",
20+
"burlywood","burlywood1","burlywood2","burlywood3","burlywood4",
21+
"cadetblue","cadetblue1","cadetblue2","cadetblue3","cadetblue4",
22+
"chartreuse","chartreuse1","chartreuse2","chartreuse3","chartreuse4",
23+
"chocolate","chocolate1","chocolate2","chocolate3","chocolate4",
24+
"coral","coral1","coral2","coral3","coral4",
25+
"cornflowerblue","cornsilk","cornsilk1","cornsilk2","cornsilk3",
26+
"cornsilk4","crimson","cyan","cyan1","cyan2",
27+
"cyan3","cyan4","darkgoldenrod","darkgoldenrod1","darkgoldenrod2",
28+
"darkgoldenrod3","darkgoldenrod4","darkgreen","darkkhaki","darkolivegreen",
29+
"darkolivegreen1","darkolivegreen2","darkolivegreen3","darkolivegreen4","darkorange",
30+
"darkorange1","darkorange2","darkorange3","darkorange4","darkorchid",
31+
"darkorchid1","darkorchid2","darkorchid3","darkorchid4","darksalmon",
32+
"darkseagreen","darkseagreen1","darkseagreen2","darkseagreen3","darkseagreen4",
33+
"darkslateblue","darkslategray","darkslategray1","darkslategray2","darkslategray3",
34+
"darkslategray4","darkslategrey","darkturquoise","darkviolet","deeppink",
35+
"deeppink1","deeppink2","deeppink3","deeppink4","deepskyblue",
36+
"deepskyblue1","deepskyblue2","deepskyblue3","deepskyblue4","dimgray",
37+
"dimgrey","dodgerblue","dodgerblue1","dodgerblue2","dodgerblue3",
38+
"dodgerblue4","firebrick","firebrick1","firebrick2","firebrick3",
39+
"firebrick4","floralwhite","forestgreen","gainsboro","ghostwhite",
40+
"gold","gold1","gold2","gold3","gold4",
41+
"goldenrod","goldenrod1","goldenrod2","goldenrod3","goldenrod4"
42+
}
43+
44+
45+
local storageHash = {}
46+
47+
local nodes = {}
48+
49+
local function createNode(name, data)
50+
local storageId
51+
for k, v in ipairs(storageHash) do
52+
if v == data then
53+
storageId = k
54+
end
55+
end
56+
local node = graph.Node("Storage id: "..storageId)
57+
function node:graphNodeName()
58+
return name
59+
end
60+
function node:graphNodeAttributes()
61+
return {color=colorNames[storageHash[data]]}
62+
end
63+
return node
64+
end
65+
66+
local function createInputNode(input, name)
67+
name = name or 'Input'
68+
if torch.isTensor(input) then
69+
local ptr = torch.pointer(input)
70+
local storagePtr = torch.pointer(input:storage())
71+
if not storageHash[storagePtr] then
72+
storageHash[storagePtr] = torch.random(1,#colorNames)
73+
table.insert(storageHash, storagePtr)
74+
end
75+
nodes[ptr] = createNode(name,storagePtr)
76+
else
77+
for k,v in ipairs(input) do
78+
createInputNode(nodes, v, name..' '..k)
79+
end
80+
end
81+
end
82+
83+
createInputNode(input)
84+
85+
local g = graph.Graph()
86+
87+
local function addEdge(p,c,name)
88+
if torch.isTensor(c) and torch.isTensor(p) then
89+
local cc = torch.pointer(c)
90+
local childStoragePtr = torch.pointer(c:storage())
91+
if not storageHash[childStoragePtr] then
92+
storageHash[childStoragePtr] = torch.random(1,#colorNames)
93+
table.insert(storageHash, childStoragePtr)
94+
end
95+
local pp = torch.pointer(p)
96+
nodes[cc] = nodes[cc] or createNode(name,childStoragePtr)
97+
local parent = nodes[pp]
98+
assert(parent, 'Parent node inexistant for module '.. name)
99+
g:add(graph.Edge(parent,nodes[cc]))
100+
elseif torch.isTensor(p) then
101+
for k,v in ipairs(c) do
102+
addEdge(p,v,name)
103+
end
104+
else
105+
for k,v in ipairs(p) do
106+
addEdge(v,c,name)
107+
end
108+
end
109+
end
110+
111+
local function apply_func(m)
112+
local oldf = m.updateOutput
113+
m.updateOutput = function(self, input)
114+
if not m.modules then
115+
local name = tostring(m)
116+
if m.inplace then -- handle it differently to avoid loops ?
117+
addEdge(input,self.output,name)
118+
else
119+
addEdge(input,self.output,name)
120+
end
121+
elseif torch.typename(m) == 'nn.Concat' or
122+
torch.typename(m) == 'nn.Parallel' or
123+
torch.typename(m) == 'nn.DepthConcat' then
124+
-- those containers effectively do some computation, so they have their
125+
-- place in the graph
126+
for i,branch in ipairs(m.modules) do
127+
local last_module = branch:get(branch:size())
128+
local out = last_module.output
129+
local ptr = torch.pointer(out)
130+
local storagePtr = torch.pointer(out:storage())
131+
if not storageHash[storagePtr] then
132+
storageHash[storagePtr] = torch.random(1,#colorNames)
133+
table.insert(storageHash, storagePtr)
134+
end
135+
local name = torch.typename(last_module)
136+
nodes[ptr] = nodes[ptr] or createNode(name,storagePtr)
137+
addEdge(out, self.output, torch.typename(m))
138+
end
139+
end
140+
return oldf(self, input)
141+
end
142+
end
143+
144+
145+
146+
147+
net:forward(input)
148+
--createInputNode(nodes, net.output, 'Output')
149+
net:apply(apply_func)
150+
net:forward(input)
151+
152+
153+
graph.dot(g)
154+

models.lua

+88
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,94 @@ models.alexnet = function()
7676
return model, input
7777
end
7878

79+
models.googlenet = function()
80+
local function inception(input_size, config)
81+
local concat = nn.Concat(2)
82+
if config[1][1] ~= 0 then
83+
local conv1 = nn.Sequential()
84+
conv1:add(nn.SpatialConvolution(input_size, config[1][1],1,1,1,1)):add(nn.ReLU(true))
85+
concat:add(conv1)
86+
end
87+
88+
local conv3 = nn.Sequential()
89+
conv3:add(nn.SpatialConvolution( input_size, config[2][1],1,1,1,1)):add(nn.ReLU(true))
90+
conv3:add(nn.SpatialConvolution(config[2][1], config[2][2],3,3,1,1,1,1)):add(nn.ReLU(true))
91+
concat:add(conv3)
92+
93+
local conv3xx = nn.Sequential()
94+
conv3xx:add(nn.SpatialConvolution( input_size, config[3][1],1,1,1,1)):add(nn.ReLU(true))
95+
conv3xx:add(nn.SpatialConvolution(config[3][1], config[3][2],3,3,1,1,1,1)):add(nn.ReLU(true))
96+
conv3xx:add(nn.SpatialConvolution(config[3][2], config[3][2],3,3,1,1,1,1)):add(nn.ReLU(true))
97+
concat:add(conv3xx)
98+
99+
local pool = nn.Sequential()
100+
pool:add(nn.SpatialZeroPadding(1,1,1,1)) -- remove after getting nn R2 into fbcode
101+
if config[4][1] == 'max' then
102+
pool:add(nn.SpatialMaxPooling(3,3,1,1):ceil())
103+
elseif config[4][1] == 'avg' then
104+
pool:add(nn.SpatialAveragePooling(3,3,1,1):ceil())
105+
else
106+
error('Unknown pooling')
107+
end
108+
if config[4][2] ~= 0 then
109+
pool:add(nn.SpatialConvolution(input_size, config[4][2],1,1,1,1)):add(nn.ReLU(true))
110+
end
111+
concat:add(pool)
112+
113+
return concat
114+
end
115+
116+
local nClasses = 1000
117+
118+
local features = nn.Sequential()
119+
features:add(nn.SpatialConvolution(3,64,7,7,2,2,3,3)):add(nn.ReLU(true))
120+
features:add(nn.SpatialMaxPooling(3,3,2,2):ceil())
121+
features:add(nn.SpatialConvolution(64,64,1,1)):add(nn.ReLU(true))
122+
features:add(nn.SpatialConvolution(64,192,3,3,1,1,1,1)):add(nn.ReLU(true))
123+
features:add(nn.SpatialMaxPooling(3,3,2,2):ceil())
124+
features:add(inception( 192, {{ 64},{ 64, 64},{ 64, 96},{'avg', 32}})) -- 3(a)
125+
features:add(inception( 256, {{ 64},{ 64, 96},{ 64, 96},{'avg', 64}})) -- 3(b)
126+
features:add(inception( 320, {{ 0},{128,160},{ 64, 96},{'max', 0}})) -- 3(c)
127+
features:add(nn.SpatialConvolution(576,576,2,2,2,2))
128+
features:add(inception( 576, {{224},{ 64, 96},{ 96,128},{'avg',128}})) -- 4(a)
129+
features:add(inception( 576, {{192},{ 96,128},{ 96,128},{'avg',128}})) -- 4(b)
130+
features:add(inception( 576, {{160},{128,160},{128,160},{'avg', 96}})) -- 4(c)
131+
features:add(inception( 576, {{ 96},{128,192},{160,192},{'avg', 96}})) -- 4(d)
132+
133+
local main_branch = nn.Sequential()
134+
main_branch:add(inception( 576, {{ 0},{128,192},{192,256},{'max', 0}})) -- 4(e)
135+
main_branch:add(nn.SpatialConvolution(1024,1024,2,2,2,2))
136+
main_branch:add(inception(1024, {{352},{192,320},{160,224},{'avg',128}})) -- 5(a)
137+
main_branch:add(inception(1024, {{352},{192,320},{192,224},{'max',128}})) -- 5(b)
138+
main_branch:add(nn.SpatialAveragePooling(7,7,1,1))
139+
main_branch:add(nn.View(1024):setNumInputDims(3))
140+
main_branch:add(nn.Linear(1024,nClasses))
141+
main_branch:add(nn.LogSoftMax())
142+
143+
-- add auxillary classifier here (thanks to Christian Szegedy for the details)
144+
local aux_classifier = nn.Sequential()
145+
aux_classifier:add(nn.SpatialAveragePooling(5,5,3,3):ceil())
146+
aux_classifier:add(nn.SpatialConvolution(576,128,1,1,1,1))
147+
aux_classifier:add(nn.View(128*4*4):setNumInputDims(3))
148+
aux_classifier:add(nn.Linear(128*4*4,768))
149+
aux_classifier:add(nn.ReLU())
150+
aux_classifier:add(nn.Linear(768,nClasses))
151+
aux_classifier:add(nn.LogSoftMax())
152+
153+
local splitter = nn.Concat(2)
154+
splitter:add(main_branch):add(aux_classifier)
155+
local model = nn.Sequential():add(features):add(splitter)
156+
157+
model.imageSize = 256
158+
model.imageCrop = 224
159+
160+
local input = torch.rand(1,3,model.imageCrop,model.imageCrop)
161+
162+
return model, input
163+
164+
165+
end
166+
79167
models.resnet = function(opt)
80168

81169
local Convolution = nn.SpatialConvolution

0 commit comments

Comments
 (0)