@@ -33,6 +33,27 @@ local colorNames = {
3333 " goldenrod" ," goldenrod1" ," goldenrod2" ," goldenrod3" ," goldenrod4"
3434}
3535
36+ -- some modules exist only for constructing
37+ -- the flow of information, and should not
38+ -- have their place in the computation graph
39+ -- as separate entities
40+ local function isSingleOperationModule (m )
41+ if m .modules then
42+ return false
43+ end
44+ local constructorModules = {
45+ ' nn.Identity' ,
46+ ' nn.SelectTable' ,
47+ ' nn.NarrowTable'
48+ }
49+ local mType = torch .typename (m )
50+ for _ , v in ipairs (constructorModules ) do
51+ if mType == v then
52+ return false
53+ end
54+ end
55+ return true
56+ end
3657
3758local function generateGraph (net , input , opts )
3859 opts = opts or {}
@@ -90,7 +111,7 @@ local function generateGraph(net, input, opts)
90111 nodes [ptr ] = createNode (name ,input )
91112 else
92113 for k ,v in ipairs (input ) do
93- createBoundaryNode (nodes , v , name .. ' ' .. k )
114+ createBoundaryNode (v , name .. ' ' .. k )
94115 end
95116 end
96117 end
@@ -104,12 +125,11 @@ local function generateGraph(net, input, opts)
104125 local toPtr = torch .pointer (to )
105126
106127 nodes [toPtr ] = nodes [toPtr ] or createNode (name ,to )
107-
128+
108129 assert (nodes [fromPtr ], ' Parent node inexistant for module ' .. name )
109130
110131 -- insert edge
111132 g :add (graph .Edge (nodes [fromPtr ],nodes [toPtr ]))
112-
113133 elseif torch .isTensor (from ) then
114134 for k ,v in ipairs (to ) do
115135 addEdge (from , v , name )
@@ -126,7 +146,7 @@ local function generateGraph(net, input, opts)
126146 local function apply_func (m )
127147 local basefunc = m .updateOutput
128148 m .updateOutput = function (self , input )
129- if not m . modules then
149+ if isSingleOperationModule ( m ) then
130150 local name = tostring (m )
131151 if m .inplace then -- handle it differently ?
132152 addEdge (input ,self .output ,name )
0 commit comments