@@ -30,34 +30,9 @@ local function resizeAndConvert(input, type)
   return res
 end
 
--- what a pain... will finish later
 local function cudnnSetDeterministic(net)
-  local conv_data = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1'
-  local conv_weight = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_1'
-
   net:apply(function(m)
-    if torch.typename(m) == 'cudnn.SpatialConvolution' then
-      local algWorkspaceLimit = (m.nInputPlane * m.kH * m.kW * 4)
-
-      local algType_filter = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
-      local algSearchMode_data = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
-      cudnn.errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
-             cudnn.getHandle(),
-             m.iDesc[0], m.oDesc[0],
-             m.convDesc[0], m.weightDesc[0],
-             algSearchMode_data, algWorkspaceLimit, algType_filter)
-
-      local algType_data = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
-      local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
-
-      cudnn.errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
-             cudnn.getHandle(),
-             self.weightDesc[0], self.oDesc[0],
-             self.convDesc[0], self.iDesc[0],
-             algSearchMode, algWorkspaceLimit, algType_data)
-
-      m:setMode(nil, algType_filter[0], algType_data[0])
-    end
+    if m.setMode then m:setMode(1, 1, 1) end
   end)
 end
 
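The deterministic helper becomes a one-liner. The removed block tried to query cuDNN for workspace-free algorithms by hand and even indexed `self` inside the `apply` closure, so it could never have run; the replacement simply calls `setMode(1, 1, 1)` on every module that exposes it. In the cudnn.torch bindings, `setMode` takes the forward, backward-data and backward-filter algorithm in that order, and index 1 selects the deterministic `ALGO_1` variants of the two backward passes, which is what the gradient comparisons below need. A minimal usage sketch, assuming the cudnn/cutorch packages are installed; the toy convolution is illustrative only:

    -- minimal sketch (assumes cudnn/cutorch; the layer itself is made up)
    require 'cudnn'
    local conv = cudnn.SpatialConvolution(3, 16, 3, 3):cuda()
    -- forward IMPLICIT_PRECOMP_GEMM, backward-data ALGO_1, backward-filter ALGO_1,
    -- i.e. the deterministic choices the tests rely on
    conv:setMode(1, 1, 1)
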
@@ -74,12 +49,12 @@ local function genericTestForward(model,opts)
 
   local out_orig = net:forward(input):clone()
 
-  local mems1 = optnet.countUsedMemory(net, input)
+  local mems1 = optnet.countUsedMemory(net)
 
   optnet.optimizeMemory(net, input)
 
   local out = net:forward(input):clone()
-  local mems2 = countUsedMemory(net, input)
+  local mems2 = countUsedMemory(net)
   tester:eq(out_orig, out, 'Outputs differ after optimization of ' .. model)
 
   local mem1 = mems1.total_size
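`countUsedMemory` now takes only the network; the extra `input` argument is dropped here and again in the backward test further down. The pattern the test exercises is: run a forward pass so the buffers exist, snapshot memory, optimize, run again, snapshot again. A self-contained sketch of that flow, assuming the optnet and nn packages; the toy model and the MB formatting are illustrative, while `countUsedMemory`, `optimizeMemory` and the `total_size` field are the names used in this file:

    require 'nn'
    local optnet = require 'optnet'

    local net = nn.Sequential():add(nn.Linear(10, 10)):add(nn.ReLU())
    local input = torch.rand(1, 10)

    net:forward(input)                          -- populate the buffers first
    local before = optnet.countUsedMemory(net)

    optnet.optimizeMemory(net, input)           -- share buffers for inference
    net:forward(input)
    local after = optnet.countUsedMemory(net)

    print(('total: %.2f MB -> %.2f MB'):format(
          before.total_size / 2^20, after.total_size / 2^20))
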
@@ -102,7 +77,7 @@ local function genericTestForward(model,opts)
   print('Buffers', bmem1/1024/1024, bmem2/1024/1024, 1 - bmem2/bmem1)
   print('Params', pmem1/1024/1024, pmem2/1024/1024, 1 - pmem2/pmem1)
 end
---[[
+
 function optest.basic()
   genericTestForward('basic1')
 end
@@ -146,7 +121,7 @@ function optest.resnet110()
   local opts = {dataset = 'cifar10', depth = 110}
   genericTestForward('resnet', opts)
 end
---]]
+
 
 -------------------------------------------------
 -- Backward
@@ -171,10 +146,8 @@ local function genericTestBackward(model,opts)
   net:training()
 
   if use_cudnn then
-    -- for the moment disable cudnn checks, because its backward pass is
-    -- by default non-deterministic, breaking the tests
-    -- cudnn.convert(net,cudnn);
-    -- cudnnSetDeterministic(net)
+    cudnn.convert(net,cudnn);
+    cudnnSetDeterministic(net)
     net:cuda();
 
     input = resizeAndConvert(input, 'torch.CudaTensor')
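With the helper fixed, the cudnn branch of the backward test is switched back on: the reason it was commented out (the default cudnn backward algorithms are non-deterministic) no longer applies once the algorithms are pinned. A quick sanity check along the same lines, assuming `net` and a CudaTensor `input` exist as inside `genericTestBackward`; everything else is standard torch/nn:

    cudnn.convert(net, cudnn)
    cudnnSetDeterministic(net)     -- pin the deterministic backward algorithms
    net:cuda()

    local gradOutput = net:forward(input):clone():fill(1)
    net:zeroGradParameters()
    local g1 = net:backward(input, gradOutput):clone()
    net:zeroGradParameters()
    local g2 = net:backward(input, gradOutput):clone()
    -- without setMode(1, 1, 1) this can fail because of atomics in the default kernels
    assert(g1:equal(g2), 'cudnn backward should now be deterministic')
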
@@ -187,7 +160,7 @@ local function genericTestBackward(model,opts)
   local _, gradParams_orig = net:getParameters()
   gradParams_orig = gradParams_orig:clone()
 
-  local mems1 = optnet.countUsedMemory(net, input)
+  local mems1 = optnet.countUsedMemory(net)
 
   optnet.optimizeMemory(net, input, {mode = 'training'})
 
@@ -198,7 +171,7 @@ local function genericTestBackward(model,opts)
   local _, gradParams = net:getParameters()
   gradParams = gradParams:clone()
 
-  local mems2 = countUsedMemory(net, input)
+  local mems2 = countUsedMemory(net)
   tester:eq(out_orig, out, 'Outputs differ after optimization of ' .. model)
   tester:eq(gradInput_orig, gradInput, backward_tol, 'GradInputs differ after optimization of ' .. model)
   tester:eq(gradParams_orig, gradParams, backward_tol, 'GradParams differ after optimization of ' .. model)
@@ -271,6 +244,10 @@ function optest.resnet56_backward()
   genericTestBackward('resnet', opts)
 end
 
+function optest.resnet110_backward()
+  local opts = {dataset = 'cifar10', depth = 110}
+  genericTestBackward('resnet', opts)
+end
 
 tester:add(optest)
 