 local optnet = require 'optnet.env'
 local models = require 'optnet.models'
+local utils = require 'optnet.utils'
 local countUsedMemory = optnet.countUsedMemory
 
 local optest = torch.TestSuite()
@@ -23,20 +24,6 @@ local function resizeAndConvert(input, type)
   return res
 end
 
--- reuse this function
-local function recursiveClone(out)
-  if torch.isTensor(out) then
-    return out:clone()
-  else
-    local res = {}
-    for k, v in ipairs(out) do
-      res[k] = recursiveClone(v)
-    end
-    return res
-  end
-end
-
-
 local function cudnnSetDeterministic(net)
   net:apply(function(m)
     if m.setMode then m:setMode(1, 1, 1) end
@@ -54,13 +41,13 @@ local function genericTestForward(model,opts)
     input = resizeAndConvert(input,'torch.CudaTensor')
   end
 
-  local out_orig = recursiveClone(net:forward(input))
+  local out_orig = utils.recursiveClone(net:forward(input))
 
   local mems1 = optnet.countUsedMemory(net)
 
   optnet.optimizeMemory(net, input)
 
-  local out = recursiveClone(net:forward(input))
+  local out = utils.recursiveClone(net:forward(input))
   local mems2 = countUsedMemory(net)
   tester:eq(out_orig, out, 'Outputs differ after optimization of ' .. model)
 
@@ -101,21 +88,21 @@ local function genericTestBackward(model,opts)
     input = resizeAndConvert(input,'torch.CudaTensor')
   end
 
-  local out_orig = recursiveClone(net:forward(input))
-  local grad_orig = recursiveClone(out_orig)
+  local out_orig = utils.recursiveClone(net:forward(input))
+  local grad_orig = utils.recursiveClone(out_orig)
   net:zeroGradParameters()
-  local gradInput_orig = recursiveClone(net:backward(input, grad_orig))
+  local gradInput_orig = utils.recursiveClone(net:backward(input, grad_orig))
   local _, gradParams_orig = net:getParameters()
   gradParams_orig = gradParams_orig:clone()
 
   local mems1 = optnet.countUsedMemory(net)
 
   optnet.optimizeMemory(net, input, {mode = 'training'})
 
-  local out = recursiveClone(net:forward(input))
-  local grad = recursiveClone(out)
+  local out = utils.recursiveClone(net:forward(input))
+  local grad = utils.recursiveClone(out)
   net:zeroGradParameters()
-  local gradInput = recursiveClone(net:backward(input, grad))
+  local gradInput = utils.recursiveClone(net:backward(input, grad))
   local _, gradParams = net:getParameters()
   gradParams = gradParams:clone()
 
@@ -165,20 +152,20 @@ local function genericTestRemoveOptim(model,opts)
     input = resizeAndConvert(input,'torch.CudaTensor')
   end
 
-  local out_orig = recursiveClone(net:forward(input))
-  local grad_orig = recursiveClone(out_orig)
+  local out_orig = utils.recursiveClone(net:forward(input))
+  local grad_orig = utils.recursiveClone(out_orig)
   net:zeroGradParameters()
-  local gradInput_orig = recursiveClone(net:backward(input, grad_orig))
+  local gradInput_orig = utils.recursiveClone(net:backward(input, grad_orig))
   local _, gradParams_orig = net:getParameters()
   gradParams_orig = gradParams_orig:clone()
 
   optnet.optimizeMemory(net, input)
   optnet.removeOptimization(net)
 
-  local out = recursiveClone(net:forward(input))
-  local grad = recursiveClone(out)
+  local out = utils.recursiveClone(net:forward(input))
+  local grad = utils.recursiveClone(out)
   net:zeroGradParameters()
-  local gradInput = recursiveClone(net:backward(input, grad))
+  local gradInput = utils.recursiveClone(net:backward(input, grad))
   local _, gradParams = net:getParameters()
   gradParams = gradParams:clone()
 
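These hunks rely on `optnet.utils` exporting a `recursiveClone` helper equivalent to the local function removed at the top of the file. A minimal sketch of that shared helper, assuming it lives in `optnet/utils.lua` and is returned as part of the module table (the actual contents of that module are not shown in this diff):

-- optnet/utils.lua (assumed location): shared helpers used by the tests
local utils = {}

-- Deep-copy a tensor or a (possibly nested) table of tensors, so the saved
-- reference outputs/gradients cannot be overwritten when buffers are later
-- shared in place.
function utils.recursiveClone(out)
  if torch.isTensor(out) then
    return out:clone()
  else
    local res = {}
    for k, v in ipairs(out) do
      res[k] = utils.recursiveClone(v)
    end
    return res
  end
end

return utils

Cloning the results before and after `optnet.optimizeMemory` is what makes the `tester:eq` comparisons meaningful, since the optimized network may reuse the very storages that `net:forward` returned.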