@@ -67,7 +67,13 @@ local function analyse(net, input, opts)
       utils.keepTrack(self.output, track, entry_fun, fun, opts)
 
       for i,branch in ipairs(m.modules) do
-        local last_module = branch:get(branch:size())
+        local last_module
+        -- if branch is a container, get its last element; if not, take the branch itself
+        if branch.modules then
+          last_module = branch:get(branch:size())
+        else
+          last_module = branch
+        end
         local out = last_module.output
         opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
         utils.keepTrack(out, track, entry_fun, fun, opts)
@@ -343,6 +349,13 @@
 
 function optnet.optimizeMemory(net, input, opts)
   opts = opts or {}
+
+  if net.__memoryOptimized then
+    print('Skipping memory optimization. ' ..
+          'Network is already optimized for ' .. net.__memoryOptimized .. ' mode.')
+    return
+  end
+
   local mode = defaultValue(opts.mode, 'inference')
 
   local out = net['forward'](net, input)
@@ -363,9 +376,18 @@ function optnet.optimizeMemory(net, input, opts)
   -- print(assignments)
   applyAssignments(net, assignments)
   resetInputDescriptors(net)
+
+  -- set a flag recording which mode the network was optimized for
+  net.__memoryOptimized = mode
 end
 
 function optnet.removeOptimization(net)
+
+  if not net.__memoryOptimized then
+    print('Skipping memory optimization removal, as the network was not optimized.')
+    return
+  end
+
   local function rem(m)
     if torch.isTensor(m) then
       m:set()
@@ -394,6 +416,8 @@ function optnet.removeOptimization(net)
   end)
   resetInputDescriptors(net)
   addGradParams(net)
+
+  net.__memoryOptimized = nil
 end
 
 return optnet
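
Taken together, these hunks make the optimization idempotent: optnet.optimizeMemory tags the network with net.__memoryOptimized and skips a second pass, while optnet.removeOptimization skips networks that were never optimized and clears the tag when it does run. A minimal usage sketch follows; the nn.Sequential model and tensor sizes are illustrative assumptions, only the optnet calls come from this file.

-- Minimal usage sketch; the model and input below are assumptions for illustration.
require 'nn'
local optnet = require 'optnet'

local net = nn.Sequential()
  :add(nn.Linear(10, 20))
  :add(nn.ReLU())
  :add(nn.Linear(20, 2))
local input = torch.rand(1, 10)

optnet.optimizeMemory(net, input, {mode = 'inference'})
optnet.optimizeMemory(net, input)   -- skipped: net.__memoryOptimized is already 'inference'

optnet.removeOptimization(net)      -- undoes the optimization and clears net.__memoryOptimized
optnet.removeOptimization(net)      -- skipped: the network is no longer optimized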