Skip to content

Commit ea0b101

Browse files
committed
Checks to avoid double optimization
Also contains a fix for branches with modules instead of containers
1 parent b8d58d8 commit ea0b101

File tree

1 file changed

+25
-1
lines changed

1 file changed

+25
-1
lines changed

init.lua

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,13 @@ local function analyse(net, input, opts)
6767
utils.keepTrack(self.output, track, entry_fun, fun, opts)
6868

6969
for i,branch in ipairs(m.modules) do
70-
local last_module = branch:get(branch:size())
70+
local last_module
71+
-- if brach is a container, get its last element, if not, take it
72+
if branch.modules then
73+
last_module = branch:get(branch:size())
74+
else
75+
last_module = branch
76+
end
7177
local out = last_module.output
7278
opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
7379
utils.keepTrack(out, track, entry_fun, fun, opts)
@@ -343,6 +349,13 @@ end
343349

344350
function optnet.optimizeMemory(net, input, opts)
345351
opts = opts or {}
352+
353+
if net.__memoryOptimized then
354+
print('Skipping memory optimization. '..
355+
'Network is already optimized for '..net.__memoryOptimized..' mode.')
356+
return
357+
end
358+
346359
local mode = defaultValue(opts.mode,'inference')
347360

348361
local out = net['forward'](net, input)
@@ -363,9 +376,18 @@ function optnet.optimizeMemory(net, input, opts)
363376
--print(assignments)
364377
applyAssignments(net, assignments)
365378
resetInputDescriptors(net)
379+
380+
-- add flag to mention that it was optimized
381+
net.__memoryOptimized = mode
366382
end
367383

368384
function optnet.removeOptimization(net)
385+
386+
if not net.__memoryOptimized then
387+
print('Skipping memory optimization removal, as the network was not optimized.')
388+
return
389+
end
390+
369391
local function rem(m)
370392
if torch.isTensor(m) then
371393
m:set()
@@ -394,6 +416,8 @@ function optnet.removeOptimization(net)
394416
end)
395417
resetInputDescriptors(net)
396418
addGradParams(net)
419+
420+
net.__memoryOptimized = nil
397421
end
398422

399423
return optnet

0 commit comments

Comments
 (0)