Skip to content

Commit d3d5671

Browse files
committed
Trying to fix graphgen with parallel containers
1 parent 66c55fa commit d3d5671

File tree

2 files changed

+76
-2
lines changed

2 files changed

+76
-2
lines changed

Diff for: graphgen.lua

+20-2
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ local function generateGraph(net, input, opts)
102102

103103
local storageHash = {}
104104
local nodes = {}
105+
local trickyNodes = {}
105106

106107
local g = graph.Graph()
107108

@@ -168,7 +169,17 @@ local function generateGraph(net, input, opts)
168169

169170
nodes[toPtr] = nodes[toPtr] or createNode(name,to)
170171

171-
assert(nodes[fromPtr], 'Parent node inexistant for module '.. name)
172+
--assert(nodes[fromPtr], 'Parent node inexistant for module '.. name)
173+
if not nodes[fromPtr] then
174+
--[[
175+
print('Printing debug')
176+
print(debug.getinfo(2))
177+
--]]
178+
179+
nodes[fromPtr] = createNode('oups',from)
180+
table.insert(trickyNodes, fromPtr)
181+
trickyNodes[fromPtr] = nodes[fromPtr]
182+
end
172183

173184
-- insert edge
174185
g:add(graph.Edge(nodes[fromPtr],nodes[toPtr]))
@@ -199,7 +210,14 @@ local function generateGraph(net, input, opts)
199210
-- those containers effectively do some computation, so they have their
200211
-- place in the graph
201212
for i,branch in ipairs(m.modules) do
202-
local last_module = branch:get(branch:size())
213+
--local last_module = branch:get(branch:size())
214+
local last_module
215+
if branch.modules then
216+
last_module = branch:get(#branch.modules)
217+
else
218+
last_module = branch
219+
end
220+
203221
local out = last_module.output
204222
local ptr = torch.pointer(out)
205223

Diff for: models.lua

+56
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,62 @@ models.siamese = function()
7272
return m, input
7373
end
7474

75+
-- Siamese network built with nn.Parallel: two weight-tied convolutional
-- branches (the Parallel splits the input along dim 2 and concatenates the
-- branch outputs along dim 2), followed by a small MLP comparison head.
-- Returns the model together with a random input tensor for graph tracing.
models.siamese_parallel = function()
  local fSize = {1, 32, 64}   -- feature-plane counts: input, conv1, conv2
  local featuresOut = 128     -- descriptor length produced by each branch

  -- One descriptor branch: 64x64 single-plane patch -> featuresOut-dim vector.
  local desc = nn.Sequential()
  desc:add(nn.Reshape(1,64,64))
  desc:add(nn.SpatialAveragePooling(2,2,2,2))
  desc:add(nn.SpatialConvolution(fSize[1], fSize[2], 7,7))
  desc:add(nn.ReLU())
  desc:add(nn.SpatialMaxPooling(2,2,2,2))
  desc:add(nn.SpatialConvolution(fSize[2], fSize[3], 6,6))
  desc:add(nn.ReLU())
  desc:add(nn.View(-1):setNumInputDims(3))
  -- 4096 = fSize[3] * 8 * 8: the spatial map left after the conv/pool stack
  -- (64 -> avgpool 32 -> conv7 26 -> maxpool 13 -> conv6 8).
  -- Use featuresOut instead of a second hard-coded 128 so the head below
  -- stays consistent if the descriptor size is ever changed.
  desc:add(nn.Linear(4096, featuresOut))
  desc:add(nn.Contiguous())

  -- Second branch is a clone whose parameters are tied to the first
  -- (weights, biases and their gradients all share storage).
  local siamese = nn.Parallel(2,2)
  local siam = desc:clone()
  desc:share(siam, 'weight', 'bias', 'gradWeight', 'gradBias')
  siamese:add(desc)
  siamese:add(siam)

  -- Comparison head on the concatenated pair of descriptors.
  local top = nn.Sequential()
  top:add(nn.Linear(featuresOut*2, featuresOut*2))
  top:add(nn.ReLU())
  top:add(nn.Linear(featuresOut*2, 1))

  local model = nn.Sequential():add(siamese):add(top)

  -- batch of 1, two 64x64 patches stacked along dim 2
  local input = torch.rand(1,2,64,64)

  return model, input
end
108+
109+
-- Test model with an nn.Parallel container in the middle of the network:
-- a Linear layer feeds a Parallel that splits its input along dim 2 and
-- joins the two branch outputs along dim 1.
-- Returns the model and a matching 2x2 random input.
models.basic_parallel_middle = function()
  local net = nn.Sequential()
  net:add(nn.Linear(2,2))

  local middle = nn.Parallel(2,1)
  for _ = 1, 2 do
    middle:add(nn.Linear(2,2))
  end
  net:add(middle)

  return net, torch.rand(2,2)
end
118+
119+
-- Test model exercising table-based parallelism: a Linear layer, SplitTable
-- along dim 2 producing a table of tensors, a ParallelTable applying
-- ReLU/Sigmoid to the two entries, then JoinTable along dim 1.
-- Returns the model and a 2x2 random input.
models.basic_splitTable = function()
  local branches = nn.ParallelTable()
  branches:add(nn.ReLU())
  branches:add(nn.Sigmoid())

  local net = nn.Sequential()
      :add(nn.Linear(2,2))
      :add(nn.SplitTable(2))
      :add(branches)
      :add(nn.JoinTable(1))

  return net, torch.rand(2,2)
end
130+
75131
models.basic_concat = function()
76132
local m = nn.Sequential()
77133
local cat = nn.ConcatTable()

0 commit comments

Comments
 (0)