
Commit bdc69c0

Simplify memory count, and add deterministic cudnn

1 parent b587c78

3 files changed: +31, -67 lines

README.md

Lines changed: 3 additions & 3 deletions

@@ -105,15 +105,15 @@ models = require 'optnet.models'
 modelname = 'googlenet'
 net, input = models[modelname]()
 
-mem1 = optnet.countUsedMemory(net, input)
+mem1 = optnet.countUsedMemory(net)
 
 optnet.optimizeMemory(net, input)
 
-mem2 = optnet.countUsedMemory(net, input)
+mem2 = optnet.countUsedMemory(net)
 
 optnet.removeOptimization(net)
 
-mem3 = optnet.countUsedMemory(net, input)
+mem3 = optnet.countUsedMemory(net)
 
 print('Before optimization : '.. mem1.total_size/1024/1024 .. ' MBytes')
 print('After optimization : '.. mem2.total_size/1024/1024 .. ' MBytes')
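The signature change above is the point of the commit's first half: countUsedMemory(net) no longer receives the input and, per the countUsedMemory.lua diff below, no longer runs a forward pass itself, so it only sees tensors the network has already allocated. A minimal usage sketch under that assumption (model setup as in the README):

local optnet = require 'optnet'
local models = require 'optnet.models'

local net, input = models['googlenet']()
net:forward(input)  -- allocate outputs/buffers first; counting no longer does this for you

local mem1 = optnet.countUsedMemory(net)  -- note: no `input` argument any more
optnet.optimizeMemory(net, input)
net:forward(input)
local mem2 = optnet.countUsedMemory(net)

print('Before optimization : ' .. mem1.total_size/1024/1024 .. ' MBytes')
print('After optimization  : ' .. mem2.total_size/1024/1024 .. ' MBytes')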

countUsedMemory.lua

Lines changed: 15 additions & 28 deletions

@@ -2,42 +2,29 @@ local optnet = require 'optnet.env'
 local utils = require 'optnet.utils'
 local keepTrack = utils.keepTrack
 
-function optnet.countUsedMemory(net, input, opts)
-  opts = opts or {}
-  local func = opts.func or 'updateOutput'
-  net[func](net, input)
+function optnet.countUsedMemory(net)
   local tensors = {outputs={},buffers={},params={},gradInputs={}}
   local function entry_fun(t)
     return t
   end
-  local function new_func(m)
-    local basefunc = m[func]
-    m[func] = function(self, input)
-      --keepTrack(input, tensors, entry_fun)
-      keepTrack(self.output, tensors.outputs, entry_fun)
-      keepTrack(self.gradInput, tensors.gradInputs, entry_fun)
-      for k, v in pairs(self) do
-        if torch.isTensor(v) and
-           k ~= 'weight' and k ~= 'bias' and
-           k ~= 'gradWeight' and k ~= 'gradBias' and
-           k ~= 'output' and k ~= 'gradInput' then
-          keepTrack(v, tensors.buffers, entry_fun)
-        end
+  local function count_func(self)
+    keepTrack(self.output, tensors.outputs, entry_fun)
+    keepTrack(self.gradInput, tensors.gradInputs, entry_fun)
+    for k, v in pairs(self) do
+      if torch.isTensor(v) and
+         k ~= 'weight' and k ~= 'bias' and
+         k ~= 'gradWeight' and k ~= 'gradBias' and
+         k ~= 'output' and k ~= 'gradInput' then
+        keepTrack(v, tensors.buffers, entry_fun)
       end
-      for _, k in ipairs({'weight', 'bias', 'gradWeight','gradBias'}) do
-        if self[k] then
-          keepTrack(self[k], tensors.params, entry_fun)
-        end
+    end
+    for _, k in ipairs({'weight', 'bias', 'gradWeight','gradBias'}) do
+      if self[k] then
+        keepTrack(self[k], tensors.params, entry_fun)
       end
-      return basefunc(self, input)
     end
   end
-  net:apply(new_func)
-  net[func](net, input)
-  -- clean up the modified function
-  net:apply(function(x)
-    x[func] = nil
-  end)
+  net:apply(count_func)
   local total_size = 0
   local sizes = {}
   for typeTensor, subTensors in pairs(tensors) do
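The simplification drops the monkey-patching: the old code temporarily replaced each module's updateOutput to record tensors during an extra forward pass, then restored it; the new code walks the already-evaluated network once with net:apply. The counting idea, as a self-contained sketch (the seen table stands in for keepTrack/entry_fun from optnet.utils, whose exact bookkeeping is not shown in this diff):

-- Illustrative only, not the library code: visit every module and sum
-- the bytes of each distinct storage it references.
local function countBytes(net)
  local seen, total = {}, 0
  net:apply(function(m)
    for _, v in pairs(m) do
      if torch.isTensor(v) and v:storage() then
        local ptr = torch.pointer(v:storage())
        if not seen[ptr] then  -- count shared storages only once
          seen[ptr] = true
          total = total + v:storage():size() * v:elementSize()
        end
      end
    end
  end)
  return total
end

Deduplicating by storage matters here: optimizeMemory saves memory precisely by making modules share storages, so a naive per-tensor sum would hide the savings.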

tests.lua

Lines changed: 13 additions & 36 deletions

@@ -30,34 +30,9 @@ local function resizeAndConvert(input, type)
   return res
 end
 
--- what a pain... will finish later
 local function cudnnSetDeterministic(net)
-  local conv_data = 'CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1'
-  local conv_weight = 'CUDNN_CONVOLUTION_BWD_DATA_ALGO_1'
-
   net:apply(function(m)
-    if torch.typename(m) == 'cudnn.SpatialConvolution' then
-      local algWorkspaceLimit = (m.nInputPlane * m.kH * m.kW * 4)
-
-      local algType_filter = ffi.new("cudnnConvolutionBwdFilterAlgo_t[?]", 1)
-      local algSearchMode_data = 'CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE'
-      cudnn.errcheck('cudnnGetConvolutionBackwardFilterAlgorithm',
-        cudnn.getHandle(),
-        m.iDesc[0], m.oDesc[0],
-        m.convDesc[0], m.weightDesc[0],
-        algSearchMode_data, algWorkspaceLimit, algType_filter)
-
-      local algType_data = ffi.new("cudnnConvolutionBwdDataAlgo_t[?]", 1)
-      local algSearchMode = 'CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE'
-
-      cudnn.errcheck('cudnnGetConvolutionBackwardDataAlgorithm',
-        cudnn.getHandle(),
-        self.weightDesc[0], self.oDesc[0],
-        self.convDesc[0], self.iDesc[0],
-        algSearchMode, algWorkspaceLimit, algType_data)
-
-      m:setMode(nil, algType_filter[0], algType_data[0])
-    end
+    if m.setMode then m:setMode(1, 1, 1) end
   end)
 end
 
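A note on the magic numbers: in cuDNN's algorithm enums the value 1 selects the ALGO_1 variants (CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 and CUDNN_CONVOLUTION_BWD_DATA_ALGO_1), which are deterministic, whereas the library's default choices may use atomic adds and reorder floating-point accumulation between runs. The if m.setMode guard skips modules (ReLU, pooling, ...) that expose no algorithm selection. A sketch of the property the tests rely on (illustrative, not part of the commit):

-- After cudnnSetDeterministic(net), two identical backward passes should
-- produce bitwise-equal gradInputs, enabling exact comparisons below.
-- `net`, `input` and `gradOutput` are assumed to be CUDA-ready, as
-- prepared in genericTestBackward.
local function isBackwardDeterministic(net, input, gradOutput)
  net:forward(input)
  local g1 = net:backward(input, gradOutput):clone()
  net:zeroGradParameters()
  net:forward(input)
  local g2 = net:backward(input, gradOutput):clone()
  return g1:equal(g2)
end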
@@ -74,12 +49,12 @@ local function genericTestForward(model,opts)
 
   local out_orig = net:forward(input):clone()
 
-  local mems1 = optnet.countUsedMemory(net, input)
+  local mems1 = optnet.countUsedMemory(net)
 
   optnet.optimizeMemory(net, input)
 
   local out = net:forward(input):clone()
-  local mems2 = countUsedMemory(net, input)
+  local mems2 = countUsedMemory(net)
   tester:eq(out_orig, out, 'Outputs differ after optimization of '..model)
 
   local mem1 = mems1.total_size
@@ -102,7 +77,7 @@ local function genericTestForward(model,opts)
   print('Buffers',bmem1/1024/1024,bmem2/1024/1024, 1-bmem2/bmem1)
   print('Params', pmem1/1024/1024,pmem2/1024/1024, 1-pmem2/pmem1)
 end
--- [[
+
 function optest.basic()
   genericTestForward('basic1')
 end
@@ -146,7 +121,7 @@ function optest.resnet110()
   local opts = {dataset='cifar10',depth=110}
   genericTestForward('resnet', opts)
 end
---]]
+
 
 -------------------------------------------------
 -- Backward
@@ -171,10 +146,8 @@ local function genericTestBackward(model,opts)
   net:training()
 
   if use_cudnn then
-    -- for the moment disable cudnn checks, because its backward pass is
-    -- by default non-deterministic, breaking the tests
-    --cudnn.convert(net,cudnn);
-    --cudnnSetDeterministic(net)
+    cudnn.convert(net,cudnn);
+    cudnnSetDeterministic(net)
     net:cuda();
 
     input = resizeAndConvert(input,'torch.CudaTensor')
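With determinism enforced, the cudnn path that was previously commented out (the backward pass used to be non-deterministic by default, breaking exact checks) can be re-enabled. For context, a sketch of how the optional dependency might be wired up; the actual definition of use_cudnn in tests.lua is not shown in this diff:

-- Hypothetical setup: fall back to CPU-only tests when cudnn is absent.
local ok, cudnn = pcall(require, 'cudnn')
local use_cudnn = ok
if use_cudnn then
  cudnn.convert(net, cudnn)
  cudnnSetDeterministic(net)
  net:cuda()
  input = resizeAndConvert(input, 'torch.CudaTensor')
end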
@@ -187,7 +160,7 @@ local function genericTestBackward(model,opts)
   local _, gradParams_orig = net:getParameters()
   gradParams_orig = gradParams_orig:clone()
 
-  local mems1 = optnet.countUsedMemory(net, input)
+  local mems1 = optnet.countUsedMemory(net)
 
   optnet.optimizeMemory(net, input, {mode='training'})
 
@@ -198,7 +171,7 @@ local function genericTestBackward(model,opts)
   local _, gradParams = net:getParameters()
   gradParams = gradParams:clone()
 
-  local mems2 = countUsedMemory(net, input)
+  local mems2 = countUsedMemory(net)
   tester:eq(out_orig, out, 'Outputs differ after optimization of '..model)
   tester:eq(gradInput_orig, gradInput, backward_tol, 'GradInputs differ after optimization of '..model)
   tester:eq(gradParams_orig, gradParams, backward_tol, 'GradParams differ after optimization of '..model)
@@ -271,6 +244,10 @@ function optest.resnet56_backward()
   genericTestBackward('resnet', opts)
 end
 
+function optest.resnet110_backward()
+  local opts = {dataset='cifar10',depth=110}
+  genericTestBackward('resnet', opts)
+end
 
 tester:add(optest)
 
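The commit also registers a resnet110 backward test alongside the existing resnet56 one. For readers unfamiliar with the harness, this is the standard torch.Tester pattern; a standalone sketch (optest here is a stand-in table, not the suite from the diff):

-- Minimal torch.Tester example mirroring how tests.lua registers cases.
require 'torch'
local tester = torch.Tester()
local optest = {}
function optest.smoke()
  tester:assert(1 + 1 == 2, 'sanity check')
end
tester:add(optest)
tester:run()  -- or tester:run('smoke') to run a single named case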