Skip to content

Commit 82dd91c

Browse files
committed
Move functions to utils.lua
1 parent 9e28df0 commit 82dd91c

File tree

2 files changed

+67
-64
lines changed

2 files changed

+67
-64
lines changed

optimize-nn.lua

+16-64
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,8 @@
11
require 'nn'
22

3-
local function keepTrack(t, track, entry_fun, fun, ...)
4-
if torch.isTensor(t) and t:storage() then
5-
local ptr = torch.pointer(t:storage())
6-
if not track[ptr] then
7-
track[ptr] = entry_fun(t, ...)
8-
end
9-
if fun then
10-
fun(t,track,...)
11-
end
12-
return
13-
end
14-
if torch.type(t) == 'table' then
15-
for k, v in ipairs(t) do
16-
keepTrack(v, track, entry_fun, fun, ...)
17-
end
18-
end
19-
end
20-
21-
function usedMemory(net, input, func)
22-
local func = func or 'updateOutput'
23-
net[func](net, input)
24-
local tensors = {}
25-
local function entry_fun(t)
26-
return t
27-
end
28-
local function new_func(m)
29-
local basefunc = m[func]
30-
m[func] = function(self, input)
31-
keepTrack(input, tensors, entry_fun)
32-
keepTrack(self.output, tensors, entry_fun)
33-
return basefunc(self, input)
34-
end
35-
end
36-
net:apply(new_func)
37-
net[func](net, input)
38-
-- clean up the modified function
39-
net:apply(function(x)
40-
x[func] = nil
41-
end)
42-
local total_size = 0
43-
for k,v in pairs(tensors) do
44-
local size = v:storage():size()*v:elementSize()
45-
total_size = total_size + size
46-
end
47-
return total_size--/(1024*1024) -- MB
48-
end
49-
3+
--local utils = require 'optimize-nn.utils'
4+
local utils = dofile 'utils.lua'
5+
usedMemory = utils.usedMemory
506

517
local kNotUsed = 10000---1
528
local kNotDefined = 0
@@ -87,22 +43,18 @@ local function analyse(net, input, func)
8743
local c = 1
8844
local function apply_func(m)
8945
local basefunc = m[func]
90-
local base_opts = {
91-
analysis=analysis, c=c, name=tostring(m),
92-
kNotUsed=kNotUsed, kNotDefined=kNotDefined
93-
}
9446
m[func] = function(self, input)
95-
--local opts = {}; for k, v in pairs(base_opts) do opts[k] = v; end
96-
--opts.var = 'used'; opts.f = math.max; opts.notUsed = kNotUsed
97-
keepTrack(input, track, entry_fun, fun,-- opts)--[[
98-
{var='used', c=c, f=math.max,
99-
notUsed=kNotUsed, name=tostring(m)})--]]
100-
101-
--opts = {}; for k, v in pairs(base_opts) do opts[k] = v; end
102-
--opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
103-
keepTrack(self.output, track, entry_fun, fun,-- opts)--[[
104-
{var='defined',c=c, f=math.min,
105-
notUsed=kNotDefined, name=tostring(m)})--]]
47+
local opts = {
48+
analysis=analysis, c=c, name=tostring(m),
49+
kNotUsed=kNotUsed, kNotDefined=kNotDefined
50+
}
51+
52+
opts.var = 'used'; opts.f = math.max; opts.notUsed = kNotUsed
53+
utils.keepTrack(input, track, entry_fun, fun, opts)
54+
55+
opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
56+
utils.keepTrack(self.output, track, entry_fun, fun, opts)
57+
10658
c = c + 1
10759
return basefunc(self,input)
10860
end
@@ -112,10 +64,10 @@ local function analyse(net, input, func)
11264
local function trackInputs(t)
11365
if torch.isTensor(t) then
11466
local f = function(a,b) return a end
115-
keepTrack(t, track, entry_fun, fun,
67+
utils.keepTrack(t, track, entry_fun, fun,
11668
{var='used', c=kAlwaysLive,
11769
f=f, notUsed=0, name='input'})
118-
keepTrack(t, track, entry_fun, fun,
70+
utils.keepTrack(t, track, entry_fun, fun,
11971
{var='defined', c=-kAlwaysLive,
12072
f=f, notUsed=0, name='input'})
12173
else

utils.lua

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
local utils = {}
2+
3+
local function keepTrack(t, track, entry_fun, fun, ...)
4+
if torch.isTensor(t) and t:storage() then
5+
local ptr = torch.pointer(t:storage())
6+
if not track[ptr] then
7+
track[ptr] = entry_fun(t, ...)
8+
end
9+
if fun then
10+
fun(t,track,...)
11+
end
12+
return
13+
end
14+
if torch.type(t) == 'table' then
15+
for k, v in ipairs(t) do
16+
keepTrack(v, track, entry_fun, fun, ...)
17+
end
18+
end
19+
end
20+
utils.keepTrack = keepTrack
21+
22+
function utils.usedMemory(net, input, func)
23+
local func = func or 'updateOutput'
24+
net[func](net, input)
25+
local tensors = {}
26+
local function entry_fun(t)
27+
return t
28+
end
29+
local function new_func(m)
30+
local basefunc = m[func]
31+
m[func] = function(self, input)
32+
keepTrack(input, tensors, entry_fun)
33+
keepTrack(self.output, tensors, entry_fun)
34+
return basefunc(self, input)
35+
end
36+
end
37+
net:apply(new_func)
38+
net[func](net, input)
39+
-- clean up the modified function
40+
net:apply(function(x)
41+
x[func] = nil
42+
end)
43+
local total_size = 0
44+
for k,v in pairs(tensors) do
45+
local size = v:storage():size()*v:elementSize()
46+
total_size = total_size + size
47+
end
48+
return total_size--/(1024*1024) -- MB
49+
end
50+
51+
return utils

0 commit comments

Comments
 (0)