diff --git a/graphgen.lua b/graphgen.lua
index 7f07a26..f75c069 100644
--- a/graphgen.lua
+++ b/graphgen.lua
@@ -102,6 +102,9 @@ local function generateGraph(net, input, opts)
 
   local storageHash = {}
   local nodes = {}
+  local trickyNodes = {}
+  local current_module = {__input=input}
+  local stack_visited_modules = {}
 
   local g = graph.Graph()
 
@@ -158,6 +161,45 @@ local function generateGraph(net, input, opts)
     end
   end
 
+  local origTorchFuncs = {DoubleTensor={},FloatTensor={}}
+  -- also hack the cuda counterparts if cutorch is loaded
+  if package.loaded.cutorch then
+    origTorchFuncs.CudaTensor = {}
+  end
+  -- list of functions to hack. it seems the list can't be extended further
+  -- without running into stack overflows
+  local hackableTorchFuncs = {'select','__index'}
+
+  -- we will temporarily overwrite torch functions to keep track
+  -- of all created tensors during the forward call. This will
+  -- allow us to handle some corner cases where the input tensor is
+  -- not part of the state of a module (i.e., it's not the output
+  -- of another module)
+  local function hackTorch()
+    for torchType, t in pairs(origTorchFuncs) do
+      for _, func in ipairs(hackableTorchFuncs) do
+        local oldFunc = torch[torchType][func]
+        t[func] = oldFunc
+        torch[torchType][func] = function(...)
+          local res = oldFunc(...)
+          if res then
+            -- heavy use of upvalues
+            trickyNodes[torch.pointer(res)] = {current_module, 'torch.'..func}
+          end
+          return res
+        end
+      end
+    end
+  end
+
+  local function unhackTorch()
+    for torchType, t in pairs(origTorchFuncs) do
+      for _, func in ipairs(hackableTorchFuncs) do
+        torch[torchType][func] = t[func]
+      end
+    end
+  end
+
   -- create edge "from" -> "to", creating "to" on the way with "name"
   -- the edges can be seen as linking modules, but in fact it links the output
   -- tensor of each module
@@ -168,8 +210,19 @@ local function generateGraph(net, input, opts)
 
       nodes[toPtr] = nodes[toPtr] or createNode(name,to)
 
-      assert(nodes[fromPtr], 'Parent node inexistant for module '.. name)
-
+      -- if the "from" tensor is not present in the "nodes" table, this means
+      -- that "from" is not the output of a module, and was created on the fly
+      -- during, for example, a slicing of a tensor. "trickyNodes" contains
+      -- all tensors that were generated on the fly
+      if not nodes[fromPtr] then
+        local trickyNode = trickyNodes[fromPtr]
+        assert(trickyNode, "Couldn't handle previous node to "..name)
+        local trickyNodeName = trickyNode[2]
+
+        local trickyParentFrom = trickyNode[1].__input
+        addEdge(trickyParentFrom,from,trickyNodeName)
+      end
+
       -- insert edge
       g:add(graph.Edge(nodes[fromPtr],nodes[toPtr]))
     elseif torch.isTensor(from) then
@@ -188,6 +241,14 @@ local function generateGraph(net, input, opts)
   local function apply_func(m)
     local basefunc = m.updateOutput
     m.updateOutput = function(self, input)
+      -- add the input to self to help keep track of it
+      self.__input = input
+      -- keeps a stack of visited modules
+      table.insert(stack_visited_modules, current_module)
+      current_module = self
+      local output = basefunc(self, input)
+      current_module = table.remove(stack_visited_modules)
+      -- add edges to the graph according to the node type
       if isSingleOperationModule(m) then
        local name = tostring(m)
        if m.inplace then -- handle it differently ?
@@ -199,7 +260,13 @@ local function generateGraph(net, input, opts)
         -- those containers effectively do some computation, so they have their
         -- place in the graph
         for i,branch in ipairs(m.modules) do
-          local last_module = branch:get(branch:size())
+          local last_module
+          if branch.modules then
+            last_module = branch:get(branch:size())
+          else
+            last_module = branch
+          end
+
           local out = last_module.output
           local ptr = torch.pointer(out)
 
@@ -208,20 +275,20 @@ local function generateGraph(net, input, opts)
           addEdge(out, self.output, torch.typename(m))
         end
       end
-      return basefunc(self, input)
+      return output
     end
   end
 
   createBoundaryNode(input, 'Input')
 
-  -- fill the states from each tensor
-  net:forward(input)
-
+  hackTorch()
   -- overwriting the standard functions to generate our graph
   net:apply(apply_func)
 
   -- generate the graph
   net:forward(input)
+  unhackTorch()
+
   if opts.addOutputNode then
     -- add dummy output node and link the last module to it
     local output = utils.recursiveClone(net.output)
@@ -245,6 +312,7 @@ local function generateGraph(net, input, opts)
   -- clean up the modified function
   net:apply(function(x)
     x.updateOutput = nil
+    x.__input = nil
   end)
 
   return g
diff --git a/models.lua b/models.lua
index 8269c71..fd05903 100644
--- a/models.lua
+++ b/models.lua
@@ -72,6 +72,62 @@ models.siamese = function()
   return m, input
 end
 
+models.siamese_parallel = function()
+  local fSize = {1, 32, 64}
+  local featuresOut = 128
+
+  local desc = nn.Sequential()
+  desc:add(nn.Reshape(1,64,64))
+  desc:add(nn.SpatialAveragePooling(2,2,2,2))
+  desc:add(nn.SpatialConvolution(fSize[1], fSize[2], 7,7))
+  desc:add(nn.ReLU())
+  desc:add(nn.SpatialMaxPooling(2,2,2,2))
+  desc:add(nn.SpatialConvolution(fSize[2], fSize[3], 6,6))
+  desc:add(nn.ReLU())
+  desc:add(nn.View(-1):setNumInputDims(3))
+  desc:add(nn.Linear(4096, 128))
+  desc:add(nn.Contiguous())
+
+  local siamese = nn.Parallel(2,2)
+  local siam = desc:clone()
+  desc:share(siam, 'weight', 'bias', 'gradWeight', 'gradBias')
+  siamese:add(desc)
+  siamese:add(siam)
+
+  local top = nn.Sequential()
+  top:add(nn.Linear(featuresOut*2, featuresOut*2))
+  top:add(nn.ReLU())
+  top:add(nn.Linear(featuresOut*2, 1))
+
+  local model = nn.Sequential():add(siamese):add(top)
+
+  local input = torch.rand(1,2,64,64)
+
+  return model, input
+end
+
+models.basic_parallel_middle = function()
+  local model = nn.Sequential():add(nn.Linear(2,2))
+  local prl = nn.Parallel(2,1)
+  prl:add(nn.Linear(2,2))
+  prl:add(nn.Linear(2,2))
+  model:add(prl)
+  local input = torch.rand(2,2)
+  return model, input
+end
+
+models.basic_splitTable = function()
+  local model = nn.Sequential():add(nn.Linear(2,2))
+  model:add(nn.SplitTable(2))
+  local prl = nn.ParallelTable()
+  prl:add(nn.ReLU())
+  prl:add(nn.Sigmoid())
+  model:add(prl)
+  model:add(nn.JoinTable(1))
+  local input = torch.rand(2,2)
+  return model, input
+end
+
 models.basic_concat = function()
   local m = nn.Sequential()
   local cat = nn.ConcatTable()
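Note (not part of the patch): a minimal, self-contained sketch of the temporary method-override trick that the new hackTorch/unhackTorch helpers rely on, shown here with a toy logger. The wrapped method names ('select', 'narrow') and the 'created' table are illustrative only; the patch itself wraps 'select' and '__index' on the tensor classes.

-- Toy illustration of the hackTorch/unhackTorch idea: temporarily wrap
-- selected torch.*Tensor methods, record every tensor they return, then
-- restore the originals.
require 'torch'

local created = {}                                   -- torch.pointer(tensor) -> how it was created
local saved = {DoubleTensor = {}, FloatTensor = {}}  -- original methods, per tensor type
local wrappedFuncs = {'select', 'narrow'}

local function hack()
  for tensorType, store in pairs(saved) do
    for _, name in ipairs(wrappedFuncs) do
      local original = torch[tensorType][name]
      store[name] = original
      torch[tensorType][name] = function(...)
        local res = original(...)
        if torch.isTensor(res) then
          -- remember which wrapped call produced this tensor
          created[torch.pointer(res)] = 'torch.' .. name
        end
        return res
      end
    end
  end
end

local function unhack()
  for tensorType, store in pairs(saved) do
    for _, name in ipairs(wrappedFuncs) do
      torch[tensorType][name] = store[name]
    end
  end
end

hack()
local t = torch.rand(4, 3)
local row  = t:select(1, 2)     -- logged: this view was created by torch.select
local cols = t:narrow(2, 1, 2)  -- logged: created by torch.narrow
unhack()

for ptr, how in pairs(created) do
  print(('tensor %s was created by %s'):format(tostring(ptr), how))
end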
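For completeness, a hedged usage sketch (also not part of the patch) showing how one of the new test models could be pushed through the generator. It assumes graphgen.lua returns generateGraph as its module value, that models.lua returns the models table, and that graph.Graph from the torch 'graph' package keeps its nodes and edges in plain tables; the optional rendering call additionally needs graphviz bindings.

-- Usage sketch only: the require paths and return values below are assumptions,
-- not confirmed by this patch.
require 'nn'
local graph = require 'graph'
local generateGraph = require 'graphgen'   -- assumes graphgen.lua returns generateGraph
local models = require 'models'            -- assumes models.lua returns the models table

-- one of the new nn.Parallel-based test models added in models.lua
local net, input = models.siamese_parallel()

-- generateGraph wraps updateOutput (and, via hackTorch, a few tensor methods)
-- around a single forward pass, then restores everything before returning
local g = generateGraph(net, input, {addOutputNode = true})

print(('graph has %d nodes and %d edges'):format(#g.nodes, #g.edges))
-- optional rendering, if graphviz bindings are installed:
-- graph.dot(g, 'siamese_parallel', 'siamese_parallel')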