1
1
require ' nn'
2
2
3
- local function keepTrack (t , track , entry_fun , fun , ...)
4
- if torch .isTensor (t ) and t :storage () then
5
- local ptr = torch .pointer (t :storage ())
6
- if not track [ptr ] then
7
- track [ptr ] = entry_fun (t , ... )
8
- end
9
- if fun then
10
- fun (t ,track ,... )
11
- end
12
- return
13
- end
14
- if torch .type (t ) == ' table' then
15
- for k , v in ipairs (t ) do
16
- keepTrack (v , track , entry_fun , fun , ... )
17
- end
18
- end
19
- end
20
-
21
- function usedMemory (net , input , func )
22
- local func = func or ' updateOutput'
23
- net [func ](net , input )
24
- local tensors = {}
25
- local function entry_fun (t )
26
- return t
27
- end
28
- local function new_func (m )
29
- local basefunc = m [func ]
30
- m [func ] = function (self , input )
31
- keepTrack (input , tensors , entry_fun )
32
- keepTrack (self .output , tensors , entry_fun )
33
- return basefunc (self , input )
34
- end
35
- end
36
- net :apply (new_func )
37
- net [func ](net , input )
38
- -- clean up the modified function
39
- net :apply (function (x )
40
- x [func ] = nil
41
- end )
42
- local total_size = 0
43
- for k ,v in pairs (tensors ) do
44
- local size = v :storage ():size ()* v :elementSize ()
45
- total_size = total_size + size
46
- end
47
- return total_size -- /(1024*1024) -- MB
48
- end
49
-
3
+ -- local utils = require 'optimize-nn.utils'
4
+ local utils = dofile ' utils.lua'
5
+ usedMemory = utils .usedMemory
50
6
51
7
local kNotUsed = 10000 --- 1
52
8
local kNotDefined = 0
@@ -87,22 +43,18 @@ local function analyse(net, input, func)
87
43
local c = 1
88
44
local function apply_func (m )
89
45
local basefunc = m [func ]
90
- local base_opts = {
91
- analysis = analysis , c = c , name = tostring (m ),
92
- kNotUsed = kNotUsed , kNotDefined = kNotDefined
93
- }
94
46
m [func ] = function (self , input )
95
- -- local opts = {}; for k, v in pairs(base_opts) do opts[k] = v; end
96
- -- opts.var = 'used'; opts.f = math.max; opts.notUsed = kNotUsed
97
- keepTrack ( input , track , entry_fun , fun , -- opts)--[[
98
- { var = ' used ' , c = c , f = math.max ,
99
- notUsed = kNotUsed , name = tostring ( m )}) -- ]]
100
-
101
- -- opts = {}; for k, v in pairs(base_opts) do opts[k] = v; end
102
- -- opts.var = 'defined'; opts.f = math.min; opts.notUsed = kNotDefined
103
- keepTrack ( self . output , track , entry_fun , fun , -- opts)--[[
104
- { var = ' defined ' , c = c , f = math.min ,
105
- notUsed = kNotDefined , name = tostring ( m )}) -- ]]
47
+ local opts = {
48
+ analysis = analysis , c = c , name = tostring ( m ),
49
+ kNotUsed = kNotUsed , kNotDefined = kNotDefined
50
+ }
51
+
52
+ opts . var = ' used ' ; opts . f = math.max ; opts . notUsed = kNotUsed
53
+ utils . keepTrack ( input , track , entry_fun , fun , opts )
54
+
55
+ opts . var = ' defined ' ; opts . f = math.min ; opts . notUsed = kNotDefined
56
+ utils . keepTrack ( self . output , track , entry_fun , fun , opts )
57
+
106
58
c = c + 1
107
59
return basefunc (self ,input )
108
60
end
@@ -112,10 +64,10 @@ local function analyse(net, input, func)
112
64
local function trackInputs (t )
113
65
if torch .isTensor (t ) then
114
66
local f = function (a ,b ) return a end
115
- keepTrack (t , track , entry_fun , fun ,
67
+ utils . keepTrack (t , track , entry_fun , fun ,
116
68
{var = ' used' , c = kAlwaysLive ,
117
69
f = f , notUsed = 0 , name = ' input' })
118
- keepTrack (t , track , entry_fun , fun ,
70
+ utils . keepTrack (t , track , entry_fun , fun ,
119
71
{var = ' defined' , c =- kAlwaysLive ,
120
72
f = f , notUsed = 0 , name = ' input' })
121
73
else
0 commit comments