Skip to content

Commit ced0209

Browse files
committed
Improve graph generation + bugfix
Now it properly handles modules that only exists for directing the flow of information in the graph, like Identity, SelectTable and NarrowTable
1 parent bf3e737 commit ced0209

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

graphgen.lua

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,27 @@ local colorNames = {
3333
"goldenrod","goldenrod1","goldenrod2","goldenrod3","goldenrod4"
3434
}
3535

36+
-- some modules exist only for constructing
37+
-- the flow of information, and should not
38+
-- have their place in the computation graph
39+
-- as separate entities
40+
local function isSingleOperationModule(m)
41+
if m.modules then
42+
return false
43+
end
44+
local constructorModules = {
45+
'nn.Identity',
46+
'nn.SelectTable',
47+
'nn.NarrowTable'
48+
}
49+
local mType = torch.typename(m)
50+
for _, v in ipairs(constructorModules) do
51+
if mType == v then
52+
return false
53+
end
54+
end
55+
return true
56+
end
3657

3758
local function generateGraph(net, input, opts)
3859
opts = opts or {}
@@ -90,7 +111,7 @@ local function generateGraph(net, input, opts)
90111
nodes[ptr] = createNode(name,input)
91112
else
92113
for k,v in ipairs(input) do
93-
createBoundaryNode(nodes, v, name..' '..k)
114+
createBoundaryNode(v, name..' '..k)
94115
end
95116
end
96117
end
@@ -104,12 +125,11 @@ local function generateGraph(net, input, opts)
104125
local toPtr = torch.pointer(to)
105126

106127
nodes[toPtr] = nodes[toPtr] or createNode(name,to)
107-
128+
108129
assert(nodes[fromPtr], 'Parent node inexistant for module '.. name)
109130

110131
-- insert edge
111132
g:add(graph.Edge(nodes[fromPtr],nodes[toPtr]))
112-
113133
elseif torch.isTensor(from) then
114134
for k,v in ipairs(to) do
115135
addEdge(from, v, name)
@@ -126,7 +146,7 @@ local function generateGraph(net, input, opts)
126146
local function apply_func(m)
127147
local basefunc = m.updateOutput
128148
m.updateOutput = function(self, input)
129-
if not m.modules then
149+
if isSingleOperationModule(m) then
130150
local name = tostring(m)
131151
if m.inplace then -- handle it differently ?
132152
addEdge(input,self.output,name)

0 commit comments

Comments
 (0)