Commit 49f2f04

A few more changes

1 parent d0ff38b

7 files changed: +306 -129 lines


README.md (+23 -9)
@@ -6,8 +6,19 @@ Heavily inspired from the `Optimizer` from https://github.com/facebook/fb-caffe-
 
 ## How does it work?
 
-It goes over the network and verify which buffers can be reused. Currently only
-the `output` of each module are reused.
+It goes over the network and verifies which buffers can be reused.
+Currently it only supports evaluation mode, but training mode will be added soon.
+
+Here is a list of currently tested networks (numbers are for the CPU version, with a batch size of 1, in the format (total memory used, memory used for the outputs)):
+
+| Network            | Before optimization | After optimization | Relative savings |
+| ------------------ | :-----------------: | :----------------: | :--------------: |
+| alexnet            | (972MB, 6MB)        | (933MB, 1.5MB)     | (4%, 75%)        |
+| vgg16              | (2311MB, 69MB)      | (2119MB, 30MB)     | (8%, 55%)        |
+| googlenet          | (505MB, 69MB)       | (337MB, 30MB)      | (33%, 57%)       |
+| resnet 110 (cifar) | (113MB, 16MB)       | (32MB, 4MB)        | (72%, 73%)       |
+
+Note that most of the used memory goes to the convolution buffers from `nn`.
 
 ## Visualizing the memory reuse
 
@@ -25,6 +36,8 @@ having to use `nngraph`.
 Let's have a look:
 
 ```lua
+-- some handy models are defined in optnet.models,
+-- like alexnet, googlenet and resnet
 models = require 'optnet.models'
 modelname = 'googlenet'
 net, input = models[modelname]()
@@ -34,7 +47,6 @@ generateGraph = require 'optnet.graphgen'
 g = generateGraph(net, input)
 
 graph.dot(g,modelname,modelname)
-
 ```
 
 This generates the following graph:
@@ -49,11 +61,13 @@ models = require 'optnet.models'
 modelname = 'googlenet'
 net, input = models[modelname]()
 
+opts = {inplace=true, reuseBuffers=true}
+
 generateGraph = require 'optnet.graphgen'
 
 optnet = require 'optnet'
 
-optnet.optimizeMemory(net, input)
+optnet.optimizeMemory(net, input, opts)
 
 g = generateGraph(net, input)
 
@@ -71,22 +85,22 @@ Here is an example
 
 ```lua
 optnet = require 'optnet'
-utils = require 'optnet.utils'
-usedMemory = utils.usedMemory
 
 models = require 'optnet.models'
 modelname = 'googlenet'
 net, input = models[modelname]()
 
-mem1 = usedMemory(net, input)
+opts = {countBuffers=true}
+
+mem1 = optnet.countUsedMemory(net, input, opts)
 
 optnet.optimizeMemory(net, input)
 
-mem2 = usedMemory(net, input)
+mem2 = optnet.countUsedMemory(net, input, opts)
 
 optnet.removeOptimization(net)
 
-mem3 = usedMemory(net, input)
+mem3 = optnet.countUsedMemory(net, input, opts)
 
 print('Before optimization : '.. mem1/1024/1024 .. ' MBytes')
 print('After optimization : '.. mem2/1024/1024 .. ' MBytes')
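Taken together, the README's measurement and optimization snippets compose end to end. Below is a minimal sketch assuming only what the diff above shows: the `optnet.models` helpers and the `inplace`, `reuseBuffers`, and `countBuffers` options.

```lua
local optnet = require 'optnet'
local models = require 'optnet.models'

local net, input = models['alexnet']()

-- countBuffers=true also counts module state buffers (e.g. conv workspaces)
local countOpts = {countBuffers = true}
local before = optnet.countUsedMemory(net, input, countOpts)

-- inplace and reuseBuffers default to true; spelled out here for clarity
optnet.optimizeMemory(net, input, {inplace = true, reuseBuffers = true})
local after = optnet.countUsedMemory(net, input, countOpts)

print(('before: %.1f MB, after: %.1f MB'):format(
  before / 1024 / 1024, after / 1024 / 1024))
```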

countUsedMemory.lua (new file, +41)
@@ -0,0 +1,41 @@
+local optnet = require 'optnet.env'
+local utils = require 'optnet.utils'
+local keepTrack = utils.keepTrack
+
+function optnet.countUsedMemory(net, input, opts)
+  opts = opts or {}
+  local countBuffers = opts.countBuffers or false
+  local func = opts.func or 'updateOutput'
+  net[func](net, input)
+  local tensors = {}
+  local function entry_fun(t)
+    return t
+  end
+  local function new_func(m)
+    local basefunc = m[func]
+    m[func] = function(self, input)
+      --keepTrack(input, tensors, entry_fun)
+      keepTrack(self.output, tensors, entry_fun)
+      if countBuffers then
+        for k, v in pairs(self) do
+          if torch.isTensor(v) then
+            keepTrack(v, tensors, entry_fun)
+          end
+        end
+      end
+      return basefunc(self, input)
+    end
+  end
+  net:apply(new_func)
+  net[func](net, input)
+  -- clean up the modified function
+  net:apply(function(x)
+    x[func] = nil
+  end)
+  local total_size = 0
+  for k,v in pairs(tensors) do
+    local size = v:storage():size()*v:elementSize()
+    total_size = total_size + size
+  end
+  return total_size --/(1024*1024) -- MB
+end
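For a quick smoke test of the new function on a model outside `optnet.models`, a minimal sketch (the toy network and sizes are illustrative, not from the repo):

```lua
local optnet = require 'optnet'
require 'nn'

-- a tiny two-module network; countUsedMemory wraps each module's
-- updateOutput to record every distinct output (and state) tensor
local net = nn.Sequential()
  :add(nn.SpatialConvolution(3, 16, 3, 3))
  :add(nn.ReLU())
local input = torch.randn(1, 3, 32, 32)

-- with countBuffers=true, state tensors such as finput are counted too
local bytes = optnet.countUsedMemory(net, input, {countBuffers = true})
print(bytes / 1024 .. ' kB')
```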

env.lua (new file, +2)
@@ -0,0 +1,2 @@
+local optnet = {}
+return optnet
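This two-line module gives `init.lua` and `countUsedMemory.lua` a shared namespace table without a circular `require`: Lua caches a module's return value in `package.loaded`, so every `require 'optnet.env'` yields the same table. A sketch of the pattern (the added field is hypothetical, for illustration only):

```lua
-- any optnet file can extend the shared table:
local optnet = require 'optnet.env'
optnet.someHelper = function() end  -- hypothetical field

-- elsewhere, the same cached table is returned:
assert(require('optnet.env').someHelper ~= nil)
```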

init.lua (+96 -13)
@@ -1,8 +1,9 @@
 require 'nn'
 
-local utils = require 'optnet.utils'
+local optnet = require 'optnet.env'
+require 'optnet.countUsedMemory'
 
-local optnet = {}
+local utils = require 'optnet.utils'
 
 local kNotUsed = 10000---1
 local kNotDefined = 0
@@ -11,13 +12,6 @@ local kAlwaysLive = 10000
 
 local function analyse(net, input, func)
   local func = func or 'updateOutput'
-  local grad
-  if func == 'backward' then
-    -- need to run forward before backward
-    grad = net['forward'](net, input)
-  end
-  -- do a pass over the network to initialize its fields
-  net[func](net, input, grad)
 
   local track = {}
   local analysis = {}
@@ -81,6 +75,17 @@ local function analyse(net, input, func)
   net:apply(function(x)
     x[func] = nil
   end)
+
+  -- disable backward pass if in evaluation mode
+  if func == 'updateOutput' then
+    net:apply(function(m)
+      m.updateGradInput = function(self, input, gradInput)
+        error([[Backward pass disabled!
+You are using inference optimization.
+Call optnet.removeOptimization(net) to enable backward again]])
+      end
+    end)
+  end
   return analysis
 end
 
@@ -131,13 +136,79 @@ local function applyAssignments(net, assignments)
   end
 end
 
+local function defaultValue(var, val)
+  if var == nil then
+    var = val
+  end
+  return var
+end
+
+-- set inplace on the modules that allow it
+local function setInplace(net, opts)
+  local inplace = defaultValue(opts.inplace, true)
+
+  if inplace then
+    net:apply(function(m)
+      if m.inplace ~= nil then
+        -- inplace is not always supported for threshold,
+        -- depending on the values. Disabling it for the moment
+        if torch.typename(m) ~= 'nn.Threshold' then
+          m.inplace = true
+        end
+      end
+    end)
+  end
+end
+
+local reusableBuffers = {
+  ['nn.SpatialConvolution'] = {{'finput','fgradInput'},{}},
+  ['nn.SpatialConvolutionMM'] = {{'finput','fgradInput'},{}},
+  ['nn.Normalize'] = {{'norm','buffer','normp','_indices'},{}},
+  ['nn.SpatialCrossMapLRN'] = {{'scale'},{}},
+  ['nn.SpatialMaxPooling'] = {{'indices'},{}},
+}
+-- basic reusing scheme: keeps a list of all possible buffers
+-- that can be reused in evaluation mode and also in training
+-- mode.
+local function reuseStateBuffers(net, opts)
+  local reuseBuffers = defaultValue(opts.reuseBuffers, true)
+  if reuseBuffers then
+    local reusedBuffers = {}
+    net:apply(function(m)
+      local name = torch.typename(m)
+      if reusableBuffers[name] then
+        local rb = reusableBuffers[name][1]
+        for k, v in ipairs(rb) do
+          if m[v] then
+            reusedBuffers[name..','..v] = reusedBuffers[name..','..v] or m[v]:storage()
+            if reusedBuffers[name..','..v] then
+              m[v]:set(reusedBuffers[name..','..v])
+            end
+          end
+        end
+      end
+    end)
+  end
+end
+
 function optnet.optimizeMemory(net, input, opts)
+  opts = opts or {}
+  local func = defaultValue(opts.func,'forward')
+
+  local grad
+  if func == 'backward' then
+    -- need to run forward before backward
+    grad = net['forward'](net, input)
+  end
+  -- do a pass over the network to initialize its fields
+  net[func](net, input, grad)
+
+  setInplace(net, opts)
+  reuseStateBuffers(net, opts)
+
+  -- share outputs
   local analysis = analyse(net, input)
-  -- print('Analysis')
-  -- print(analysis)
   local assignments = assign(net,analysis)
-  -- print('Assignments')
-  -- print(assignments)
   applyAssignments(net, assignments)
 end
 
@@ -156,6 +227,18 @@ function optnet.removeOptimization(net)
   net:apply(function(m)
     rem(m.output)
     rem(m.gradInput)
+    local name = torch.typename(m)
+    if reusableBuffers[name] then
+      local rb = reusableBuffers[name][1]
+      for k, v in ipairs(rb) do
+        if m[v] then
+          m[v]:set()
+        end
+      end
+    end
+
+    -- remove backward blocking
+    m.updateGradInput = nil
   end)
 end
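The interplay of the new eval-mode guard and `removeOptimization` can be seen in a short sketch (the toy model is illustrative; the error text comes from the diff above):

```lua
local optnet = require 'optnet'
require 'nn'

local net = nn.Sequential():add(nn.Linear(10, 10)):add(nn.ReLU())
local input = torch.randn(1, 10)

optnet.optimizeMemory(net, input)  -- opts.func defaults to 'forward'

local out = net:forward(input)
-- every module's updateGradInput now raises "Backward pass disabled!"
local ok, err = pcall(net.backward, net, input, out:clone())
print(ok, err)  -- false, plus the error message

optnet.removeOptimization(net)     -- unshares buffers, clears the guard
out = net:forward(input)
net:backward(input, out:clone())   -- backward works again
```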
