Skip to content

Commit 7accc94

Browse files
committed
Move recursiveClone to utils
1 parent ced0209 commit 7accc94

File tree

3 files changed

+30
-44
lines changed

3 files changed

+30
-44
lines changed

init.lua

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,6 @@ require 'optnet.tests'
66

77
local utils = require 'optnet.utils'
88

9-
local function recursiveClone(out)
10-
if torch.isTensor(out) then
11-
return out:clone()
12-
else
13-
local res = {}
14-
for k, v in ipairs(out) do
15-
res[k] = recursiveClone(v)
16-
end
17-
return res
18-
end
19-
end
20-
21-
22-
239
local kNotUsed = 10000---1
2410
local kNotDefined = 0
2511
local kMinimumForSharing = 2
@@ -147,7 +133,7 @@ local function analyse(net, input, opts)
147133
local out = net['forward'](net, input)
148134
local grad
149135
if mode == 'training' then
150-
grad = recursiveClone(out)
136+
grad = utils.recursiveClone(out)
151137
net['backward'](net, input, grad)
152138
end
153139
local function trackInputs(t, name)
@@ -362,7 +348,7 @@ function optnet.optimizeMemory(net, input, opts)
362348
local out = net['forward'](net, input)
363349
local grad
364350
if mode == 'training' then
365-
grad = recursiveClone(out)
351+
grad = utils.recursiveClone(out)
366352
net['backward'](net, input, grad)
367353
end
368354

tests.lua

Lines changed: 15 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
local optnet = require 'optnet.env'
22
local models = require 'optnet.models'
3+
local utils = require 'optnet.utils'
34
local countUsedMemory = optnet.countUsedMemory
45

56
local optest = torch.TestSuite()
@@ -23,20 +24,6 @@ local function resizeAndConvert(input, type)
2324
return res
2425
end
2526

26-
-- reuse this function
27-
local function recursiveClone(out)
28-
if torch.isTensor(out) then
29-
return out:clone()
30-
else
31-
local res = {}
32-
for k, v in ipairs(out) do
33-
res[k] = recursiveClone(v)
34-
end
35-
return res
36-
end
37-
end
38-
39-
4027
local function cudnnSetDeterministic(net)
4128
net:apply(function(m)
4229
if m.setMode then m:setMode(1, 1, 1) end
@@ -54,13 +41,13 @@ local function genericTestForward(model,opts)
5441
input = resizeAndConvert(input,'torch.CudaTensor')
5542
end
5643

57-
local out_orig = recursiveClone(net:forward(input))
44+
local out_orig = utils.recursiveClone(net:forward(input))
5845

5946
local mems1 = optnet.countUsedMemory(net)
6047

6148
optnet.optimizeMemory(net, input)
6249

63-
local out = recursiveClone(net:forward(input))
50+
local out = utils.recursiveClone(net:forward(input))
6451
local mems2 = countUsedMemory(net)
6552
tester:eq(out_orig, out, 'Outputs differ after optimization of '..model)
6653

@@ -101,21 +88,21 @@ local function genericTestBackward(model,opts)
10188
input = resizeAndConvert(input,'torch.CudaTensor')
10289
end
10390

104-
local out_orig = recursiveClone(net:forward(input))
105-
local grad_orig = recursiveClone(out_orig)
91+
local out_orig = utils.recursiveClone(net:forward(input))
92+
local grad_orig = utils.recursiveClone(out_orig)
10693
net:zeroGradParameters()
107-
local gradInput_orig = recursiveClone(net:backward(input, grad_orig))
94+
local gradInput_orig = utils.recursiveClone(net:backward(input, grad_orig))
10895
local _, gradParams_orig = net:getParameters()
10996
gradParams_orig = gradParams_orig:clone()
11097

11198
local mems1 = optnet.countUsedMemory(net)
11299

113100
optnet.optimizeMemory(net, input, {mode='training'})
114101

115-
local out = recursiveClone(net:forward(input))
116-
local grad = recursiveClone(out)
102+
local out = utils.recursiveClone(net:forward(input))
103+
local grad = utils.recursiveClone(out)
117104
net:zeroGradParameters()
118-
local gradInput = recursiveClone(net:backward(input, grad))
105+
local gradInput = utils.recursiveClone(net:backward(input, grad))
119106
local _, gradParams = net:getParameters()
120107
gradParams = gradParams:clone()
121108

@@ -165,20 +152,20 @@ local function genericTestRemoveOptim(model,opts)
165152
input = resizeAndConvert(input,'torch.CudaTensor')
166153
end
167154

168-
local out_orig = recursiveClone(net:forward(input))
169-
local grad_orig = recursiveClone(out_orig)
155+
local out_orig = utils.recursiveClone(net:forward(input))
156+
local grad_orig = utils.recursiveClone(out_orig)
170157
net:zeroGradParameters()
171-
local gradInput_orig = recursiveClone(net:backward(input, grad_orig))
158+
local gradInput_orig = utils.recursiveClone(net:backward(input, grad_orig))
172159
local _, gradParams_orig = net:getParameters()
173160
gradParams_orig = gradParams_orig:clone()
174161

175162
optnet.optimizeMemory(net, input)
176163
optnet.removeOptimization(net)
177164

178-
local out = recursiveClone(net:forward(input))
179-
local grad = recursiveClone(out)
165+
local out = utils.recursiveClone(net:forward(input))
166+
local grad = utils.recursiveClone(out)
180167
net:zeroGradParameters()
181-
local gradInput = recursiveClone(net:backward(input, grad))
168+
local gradInput = utils.recursiveClone(net:backward(input, grad))
182169
local _, gradParams = net:getParameters()
183170
gradParams = gradParams:clone()
184171

utils.lua

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,17 @@ local function keepTrack(t, track, entry_fun, fun, ...)
1919
end
2020
utils.keepTrack = keepTrack
2121

22+
local function recursiveClone(out)
23+
if torch.isTensor(out) then
24+
return out:clone()
25+
else
26+
local res = {}
27+
for k, v in ipairs(out) do
28+
res[k] = recursiveClone(v)
29+
end
30+
return res
31+
end
32+
end
33+
utils.recursiveClone = recursiveClone
34+
2235
return utils

0 commit comments

Comments
 (0)