Skip to content

Commit 541427d

Browse files
committed
Fix for inn.SpatialPyramidPooling
1 parent 7e698ef commit 541427d

File tree

1 file changed

+29
-3
lines changed

1 file changed

+29
-3
lines changed

graphgen.lua

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,34 @@ local function isSingleOperationModule(m)
5757
return true
5858
end
5959

60+
local function isOperativeContainer(m)
61+
local mType = torch.typename(m)
62+
63+
local opContainers = {
64+
'nn.Concat',
65+
'nn.Parallel',
66+
'nn.DepthConcat'
67+
}
68+
for _, v in ipairs(opContainers) do
69+
if mType == v then
70+
return true
71+
end
72+
end
73+
74+
-- those modules heritate from an
75+
-- operative container like nn.Concat
76+
local fakeContainers = {
77+
'inn.SpatialPyramidPooling',
78+
}
79+
for _, v in ipairs(fakeContainers) do
80+
if mType == v then
81+
return true
82+
end
83+
end
84+
85+
return false
86+
end
87+
6088
-- generates a graph from a nn network
6189
-- Arguments:
6290
-- net: nn network
@@ -167,9 +195,7 @@ local function generateGraph(net, input, opts)
167195
else
168196
addEdge(input,self.output,name)
169197
end
170-
elseif torch.typename(m) == 'nn.Concat' or
171-
torch.typename(m) == 'nn.Parallel' or
172-
torch.typename(m) == 'nn.DepthConcat' then
198+
elseif isOperativeContainer(m) then
173199
-- those containers effectively do some computation, so they have their
174200
-- place in the graph
175201
for i,branch in ipairs(m.modules) do

0 commit comments

Comments
 (0)