Skip to content

Commit 21904f8

Browse files
committed
Initial commit
0 parents  commit 21904f8

File tree

2 files changed

+496
-0
lines changed

2 files changed

+496
-0
lines changed

optimize-nn.lua

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
require 'nn'
2+
3+
4+
function usedMemory(net, input, func)
5+
local func = func or 'updateOutput'
6+
net[func](net, input)
7+
local tensors = {}
8+
local function keepTrack(t)
9+
if torch.isTensor(t) and t:storage() then
10+
local ptr = torch.pointer(t:storage())
11+
if not tensors[ptr] then
12+
tensors[ptr] = t
13+
end
14+
return
15+
end
16+
if torch.type(t) == 'table' then
17+
for k, v in ipairs(t) do
18+
keepTrack(v)
19+
end
20+
end
21+
end
22+
local function new_func(m)
23+
local basefunc = m[func]
24+
m[func] = function(self, input)
25+
keepTrack(input)
26+
keepTrack(self.output)
27+
return basefunc(self, input)
28+
end
29+
end
30+
net:apply(new_func)
31+
net[func](net, input)
32+
-- clean up the modified function
33+
net:apply(function(x)
34+
x[func] = nil
35+
end)
36+
local total_size = 0
37+
for k,v in pairs(tensors) do
38+
local size = v:storage():size()*v:elementSize()
39+
total_size = total_size + size
40+
end
41+
return total_size--/(1024*1024) -- MB
42+
end
43+
44+
local kNotUsed = 10000---1
45+
local kNotDefined = 0
46+
local kMinimumForSharing = 2
47+
local kAlwaysLive = 10000
48+
49+
local function analyse(net, input, func)
50+
local analysis = {}
51+
local analysis2 = {}
52+
local func = func or 'updateOutput'
53+
net[func](net, input)
54+
local c = 1
55+
local function keepTrack(t, var, c, name, f, notUsed)
56+
if torch.isTensor(t) and t:storage() then
57+
local ptr = torch.pointer(t:storage())
58+
if not analysis[ptr] then
59+
--analysis[ptr] = {[var]=c, name=name, ptr=ptr, tensor=t}
60+
analysis[ptr] = {used=kNotUsed,defined=kNotDefined, name=name, ptr=ptr, tensor=t}
61+
table.insert(analysis2,analysis[ptr])
62+
end
63+
local val = analysis[ptr][var]
64+
if val == notUsed then
65+
analysis[ptr][var] = c
66+
else
67+
analysis[ptr][var] = f(c,val)
68+
end
69+
return
70+
end
71+
if torch.type(t) == 'table' then
72+
for k, v in ipairs(t) do
73+
keepTrack(v, var, c, name, f, notUsed)
74+
end
75+
end
76+
end
77+
local function new_func(m)
78+
local basefunc = m[func]
79+
m[func] = function(self, input)
80+
--if torch.typename(m) ~= 'nn.Sequential' then
81+
keepTrack(input, 'used', c, tostring(m), math.max, kNotUsed)
82+
keepTrack(self.output, 'defined', c, tostring(m), math.min, kNotDefined)
83+
c = c + 1
84+
--end
85+
return basefunc(self,input)
86+
end
87+
end
88+
net:apply(new_func)
89+
net[func](net, input)
90+
local function trackInputs(t)
91+
if torch.isTensor(t) then
92+
local f = function(a,b) return a end
93+
keepTrack(t, 'used', kAlwaysLive, 'input', f, 0)
94+
keepTrack(t, 'defined', -kAlwaysLive, 'input', f, 0)
95+
else
96+
for k,v in ipairs(t) do
97+
trackInputs(v)
98+
end
99+
end
100+
end
101+
trackInputs(input)
102+
-- clean up the modified function
103+
net:apply(function(x)
104+
x[func] = nil
105+
end)
106+
return analysis2
107+
end
108+
109+
local function isCompatible(candidate, assignment)
110+
if candidate.used == kNotUsed then
111+
return false
112+
end
113+
if candidate.tensor:numel() < kMinimumForSharing then
114+
return false
115+
end
116+
local a_used = assignment[#assignment].used-- or -1
117+
return candidate.defined > a_used
118+
end
119+
120+
local function assign(net, analysis)
121+
table.sort(analysis, function(a,b)
122+
local x = a.used-- or -1
123+
local y = b.used-- or -1
124+
return x < y
125+
end)
126+
local assignments = {}
127+
for _,candidate in ipairs(analysis) do
128+
local assigned = false
129+
for _, assignment in ipairs(assignments) do
130+
if isCompatible(candidate, assignment) then
131+
table.insert(assignment,candidate)
132+
assigned = true
133+
break
134+
end
135+
end
136+
if not assigned then
137+
table.insert(assignments, {candidate})
138+
end
139+
end
140+
return assignments
141+
end
142+
143+
local function applyAssignments(net, assignments)
144+
for _, assignment in ipairs(assignments) do
145+
local storage
146+
for k, v in ipairs(assignment) do
147+
if v.used == kAlwaysLive and v.defined == -kAlwaysLive then
148+
break
149+
end
150+
storage = storage or v.tensor.new(1):storage()
151+
v.tensor:set(storage)
152+
end
153+
end
154+
end
155+
156+
function optimizeMemory(net, input)
157+
local analysis = analyse(net, input)
158+
-- print('Analysis')
159+
-- print(analysis)
160+
local assignments = assign(net,analysis)
161+
-- print('Assignments')
162+
-- print(assignments)
163+
applyAssignments(net, assignments)
164+
end
165+
166+
function removeOptimization(net)
167+
local function rem(m)
168+
if torch.isTensor(m) then
169+
m:set()
170+
end
171+
if torch.type(m) == 'table' then
172+
for k, v in ipairs(m) do
173+
rem(v)
174+
end
175+
end
176+
end
177+
178+
net:apply(function(m)
179+
rem(m.output)
180+
end)
181+
end
182+
183+

0 commit comments

Comments
 (0)