Skip to content

Commit 3b7a913

Browse files
committed
Packaging
1 parent a38e4bf commit 3b7a913

File tree

8 files changed

+279
-163
lines changed

8 files changed

+279
-163
lines changed

CMakeLists.txt

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Build/install rules for the optnet Torch package.
CMAKE_MINIMUM_REQUIRED(VERSION 2.6 FATAL_ERROR)
CMAKE_POLICY(VERSION 2.6)
# Locate the Torch distribution; provides ADD_TORCH_PACKAGE.
FIND_PACKAGE(Torch REQUIRED)

# Install every Lua source in this directory as part of the package.
FILE(GLOB luasrc *.lua)

# Register the package under the name "optnet" (no C sources).
ADD_TORCH_PACKAGE(optnet "" "${luasrc}" "Memory optimizations for nn")

README.md

+53
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# OptNet - reducing memory usage in Torch neural networks
2+
3+
Memory optimizations for torch neural networks.
4+
Heavily inspired by the `Optimizer` from https://github.com/facebook/fb-caffe-exts
5+
6+
## How does it work ?
7+
8+
It goes over the network and verifies which buffers can be reused (currently only
9+
the `output` of each module).
10+
11+
## Visualizing the memory reuse
12+
13+
We can analyse the sharing of the internal buffers by looking at the computation
14+
graph of the network before and after the sharing.
15+
16+
For that, we have the `createGraph(net, input, opts)` function, which creates the
17+
graph corresponding to the network `net`. The generated graph contains the storage
18+
id of each `output`; the same color means the same storage.
19+
20+
Let's have a look:
21+
22+
```lua
23+
models = require 'optnet.models'
24+
modelname = 'googlenet'
25+
net, input = models[modelname]()
26+
27+
generateGraph = require 'optnet.graphgen'
28+
29+
g = generateGraph(net, input)
30+
31+
graph.dot(g,modelname,modelname)
32+
33+
```
34+
35+
This generates the following graph:
36+
37+
Now what happens after we optimize the network ?
38+
39+
```lua
40+
models = require 'optnet.models'
41+
modelname = 'googlenet'
42+
net, input = models[modelname]()
43+
44+
generateGraph = require 'optnet.graphgen'
45+
46+
optnet = require 'optnet'
47+
48+
optnet.optimizeMemory(net, input)
49+
50+
g = generateGraph(net, input)
51+
52+
graph.dot(g,modelname..'_optimized',modelname..'_optimized')
53+
```

example.lua

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
-- Example: visualize the memory sharing of a network before and after
-- optnet's memory optimization, by rendering its computation graph twice.
-- FIX: require 'graph' explicitly — graph.dot is called below, and the
-- script previously only worked because optnet.graphgen required it
-- transitively.
require 'graph'

optnet = require 'optnet'
generateGraph = require 'optnet.graphgen'
models = require 'optnet.models'

modelname = 'googlenet'
net, input = models[modelname]()

-- graph of the unoptimized network (each color = one storage)
g = generateGraph(net, input)
graph.dot(g, modelname, modelname)

optnet.optimizeMemory(net, input)

-- same network after output-buffer sharing; fewer distinct colors expected
g = generateGraph(net, input)
graph.dot(g, modelname..'_optimized', modelname..'_optimized')

graphgen.lua

+160
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
require 'graph'

-- taken from http://www.graphviz.org/doc/info/colors.html
local colorNames = {
  "aliceblue","antiquewhite","antiquewhite1","antiquewhite2","antiquewhite3",
  "antiquewhite4","aquamarine","aquamarine1","aquamarine2","aquamarine3",
  "aquamarine4","azure","azure1","azure2","azure3",
  "azure4","beige","bisque","bisque1","bisque2",
  "bisque3","bisque4","black","blanchedalmond","blue",
  "blue1","blue2","blue3","blue4","blueviolet",
  "brown","brown1","brown2","brown3","brown4",
  "burlywood","burlywood1","burlywood2","burlywood3","burlywood4",
  "cadetblue","cadetblue1","cadetblue2","cadetblue3","cadetblue4",
  "chartreuse","chartreuse1","chartreuse2","chartreuse3","chartreuse4",
  "chocolate","chocolate1","chocolate2","chocolate3","chocolate4",
  "coral","coral1","coral2","coral3","coral4",
  "cornflowerblue","cornsilk","cornsilk1","cornsilk2","cornsilk3",
  "cornsilk4","crimson","cyan","cyan1","cyan2",
  "cyan3","cyan4","darkgoldenrod","darkgoldenrod1","darkgoldenrod2",
  "darkgoldenrod3","darkgoldenrod4","darkgreen","darkkhaki","darkolivegreen",
  "darkolivegreen1","darkolivegreen2","darkolivegreen3","darkolivegreen4","darkorange",
  "darkorange1","darkorange2","darkorange3","darkorange4","darkorchid",
  "darkorchid1","darkorchid2","darkorchid3","darkorchid4","darksalmon",
  "darkseagreen","darkseagreen1","darkseagreen2","darkseagreen3","darkseagreen4",
  "darkslateblue","darkslategray","darkslategray1","darkslategray2","darkslategray3",
  "darkslategray4","darkslategrey","darkturquoise","darkviolet","deeppink",
  "deeppink1","deeppink2","deeppink3","deeppink4","deepskyblue",
  "deepskyblue1","deepskyblue2","deepskyblue3","deepskyblue4","dimgray",
  "dimgrey","dodgerblue","dodgerblue1","dodgerblue2","dodgerblue3",
  "dodgerblue4","firebrick","firebrick1","firebrick2","firebrick3",
  "firebrick4","floralwhite","forestgreen","gainsboro","ghostwhite",
  "gold","gold1","gold2","gold3","gold4",
  "goldenrod","goldenrod1","goldenrod2","goldenrod3","goldenrod4"
}


--- Build a `graph.Graph` describing how `output` tensors flow through `net`
-- when evaluated on `input`.  Each node is labelled with a "storage id" for
-- the tensor's underlying storage, and nodes backed by the same storage get
-- the same color — which makes buffer sharing visible after optimization.
-- Note: two nodes sharing a color does not guarantee shared storage, since
-- colors are drawn randomly from a finite palette.
-- @param net    an nn module/container; it is forwarded twice and restored
-- @param input  a tensor, or a (possibly nested) table of tensors
-- @param opts   currently unused, reserved for future options
-- @return a graph.Graph suitable for graph.dot
local function generateGraph(net, input, opts)

  -- storageHash maps torch.pointer(storage) -> color index, and its array
  -- part lists the storages in discovery order (the array position serves
  -- as the human-readable "storage id" in the node label)
  local storageHash = {}
  -- nodes maps torch.pointer(tensor) -> graph.Node already created for it
  local nodes = {}

  local g = graph.Graph()

  -- basic function for creating an annotated nn.Node to suit our purposes
  -- gives the same color for the same storage.
  -- note that two colors being the same does not imply the same storage
  -- as we have a limited number of colors
  local function createNode(name, tensor)
    local data = torch.pointer(tensor:storage())
    local storageId
    if not storageHash[data] then
      storageHash[data] = torch.random(1, #colorNames)
      table.insert(storageHash, data)
    end
    -- the displayed id is this storage's position in discovery order
    for k, v in ipairs(storageHash) do
      if v == data then
        storageId = k
      end
    end
    local node = graph.Node("Storage id: "..storageId)
    function node:graphNodeName()
      return name
    end
    function node:graphNodeAttributes()
      return {color = colorNames[storageHash[data]]}
    end
    return node
  end

  -- generate input/output nodes
  local function createBoundaryNode(input, name)
    if torch.isTensor(input) then
      local ptr = torch.pointer(input)
      nodes[ptr] = createNode(name, input)
    else
      for k, v in ipairs(input) do
        -- BUGFIX: the recursive call used to be
        -- createBoundaryNode(nodes, v, name..' '..k), passing `nodes` into
        -- the `input` slot and shifting the real arguments — any network
        -- taking a table of tensors as input hit this path and failed
        createBoundaryNode(v, name..' '..k)
      end
    end
  end

  -- create edge "from" -> "to", creating "to" on the way with "name"
  -- the edges can be seen as linking modules, but in fact it links the output
  -- tensor of each module
  local function addEdge(from, to, name)
    if torch.isTensor(to) and torch.isTensor(from) then
      local fromPtr = torch.pointer(from)
      local toPtr = torch.pointer(to)

      nodes[toPtr] = nodes[toPtr] or createNode(name, to)

      assert(nodes[fromPtr], 'Parent node inexistant for module '.. name)

      -- insert edge
      g:add(graph.Edge(nodes[fromPtr], nodes[toPtr]))
    elseif torch.isTensor(from) then
      -- "to" is a table of tensors: fan out, one edge per element
      for k, v in ipairs(to) do
        addEdge(from, v, name)
      end
    else
      -- "from" is a table of tensors: fan in, one edge per element
      for k, v in ipairs(from) do
        addEdge(v, to, name)
      end
    end
  end

  -- go over the network keeping track of the input/output for each module
  -- we overwrite the updateOutput for that.
  local function apply_func(m)
    local basefunc = m.updateOutput
    m.updateOutput = function(self, input)
      if not m.modules then
        -- leaf module: link its input tensor(s) to its output.
        -- (the original code special-cased m.inplace but both branches
        -- were identical, so the distinction is dropped; in-place modules
        -- may deserve different rendering in the future)
        addEdge(input, self.output, tostring(m))
      elseif torch.typename(m) == 'nn.Concat' or
             torch.typename(m) == 'nn.Parallel' or
             torch.typename(m) == 'nn.DepthConcat' then
        -- those containers effectively do some computation, so they have their
        -- place in the graph
        for i, branch in ipairs(m.modules) do
          local last_module = branch:get(branch:size())
          local out = last_module.output
          local ptr = torch.pointer(out)

          local name = torch.typename(last_module)
          nodes[ptr] = nodes[ptr] or createNode(name, out)
          addEdge(out, self.output, torch.typename(m))
        end
      end
      return basefunc(self, input)
    end
  end

  createBoundaryNode(input, 'Input')

  -- fill the states from each tensor
  net:forward(input)

  -- overwriting the standard functions to generate our graph
  net:apply(apply_func)
  -- generate the graph
  net:forward(input)

  -- clean up: clearing the instance field re-exposes the original
  -- updateOutput from the module's metatable
  net:apply(function(x)
    x.updateOutput = nil
  end)

  return g
end

return generateGraph
160+

optimize-nn.lua renamed to init.lua

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
require 'nn'
22

3-
--local utils = require 'optimize-nn.utils'
4-
local utils = dofile 'utils.lua'
5-
usedMemory = utils.usedMemory
3+
local utils = require 'optnet.utils'
4+
5+
local optnet = {}
66

77
local kNotUsed = 10000---1
88
local kNotDefined = 0
@@ -131,7 +131,7 @@ local function applyAssignments(net, assignments)
131131
end
132132
end
133133

134-
function optimizeMemory(net, input, opts)
134+
function optnet.optimizeMemory(net, input, opts)
135135
local analysis = analyse(net, input)
136136
-- print('Analysis')
137137
-- print(analysis)
@@ -141,7 +141,7 @@ function optimizeMemory(net, input, opts)
141141
applyAssignments(net, assignments)
142142
end
143143

144-
function removeOptimization(net)
144+
function optnet.removeOptimization(net)
145145
local function rem(m)
146146
if torch.isTensor(m) then
147147
m:set()
@@ -159,4 +159,5 @@ function removeOptimization(net)
159159
end)
160160
end
161161

162+
return optnet
162163

0 commit comments

Comments
 (0)