|
| 1 | +require 'nn' |
| 2 | + |
| 3 | + |
| 4 | +function usedMemory(net, input, func) |
| 5 | + local func = func or 'updateOutput' |
| 6 | + net[func](net, input) |
| 7 | + local tensors = {} |
| 8 | + local function keepTrack(t) |
| 9 | + if torch.isTensor(t) and t:storage() then |
| 10 | + local ptr = torch.pointer(t:storage()) |
| 11 | + if not tensors[ptr] then |
| 12 | + tensors[ptr] = t |
| 13 | + end |
| 14 | + return |
| 15 | + end |
| 16 | + if torch.type(t) == 'table' then |
| 17 | + for k, v in ipairs(t) do |
| 18 | + keepTrack(v) |
| 19 | + end |
| 20 | + end |
| 21 | + end |
| 22 | + local function new_func(m) |
| 23 | + local basefunc = m[func] |
| 24 | + m[func] = function(self, input) |
| 25 | + keepTrack(input) |
| 26 | + keepTrack(self.output) |
| 27 | + return basefunc(self, input) |
| 28 | + end |
| 29 | + end |
| 30 | + net:apply(new_func) |
| 31 | + net[func](net, input) |
| 32 | + -- clean up the modified function |
| 33 | + net:apply(function(x) |
| 34 | + x[func] = nil |
| 35 | + end) |
| 36 | + local total_size = 0 |
| 37 | + for k,v in pairs(tensors) do |
| 38 | + local size = v:storage():size()*v:elementSize() |
| 39 | + total_size = total_size + size |
| 40 | + end |
| 41 | + return total_size--/(1024*1024) -- MB |
| 42 | +end |
| 43 | + |
| 44 | +local kNotUsed = 10000---1 |
| 45 | +local kNotDefined = 0 |
| 46 | +local kMinimumForSharing = 2 |
| 47 | +local kAlwaysLive = 10000 |
| 48 | + |
| 49 | +local function analyse(net, input, func) |
| 50 | + local analysis = {} |
| 51 | + local analysis2 = {} |
| 52 | + local func = func or 'updateOutput' |
| 53 | + net[func](net, input) |
| 54 | + local c = 1 |
| 55 | + local function keepTrack(t, var, c, name, f, notUsed) |
| 56 | + if torch.isTensor(t) and t:storage() then |
| 57 | + local ptr = torch.pointer(t:storage()) |
| 58 | + if not analysis[ptr] then |
| 59 | + --analysis[ptr] = {[var]=c, name=name, ptr=ptr, tensor=t} |
| 60 | + analysis[ptr] = {used=kNotUsed,defined=kNotDefined, name=name, ptr=ptr, tensor=t} |
| 61 | + table.insert(analysis2,analysis[ptr]) |
| 62 | + end |
| 63 | + local val = analysis[ptr][var] |
| 64 | + if val == notUsed then |
| 65 | + analysis[ptr][var] = c |
| 66 | + else |
| 67 | + analysis[ptr][var] = f(c,val) |
| 68 | + end |
| 69 | + return |
| 70 | + end |
| 71 | + if torch.type(t) == 'table' then |
| 72 | + for k, v in ipairs(t) do |
| 73 | + keepTrack(v, var, c, name, f, notUsed) |
| 74 | + end |
| 75 | + end |
| 76 | + end |
| 77 | + local function new_func(m) |
| 78 | + local basefunc = m[func] |
| 79 | + m[func] = function(self, input) |
| 80 | + --if torch.typename(m) ~= 'nn.Sequential' then |
| 81 | + keepTrack(input, 'used', c, tostring(m), math.max, kNotUsed) |
| 82 | + keepTrack(self.output, 'defined', c, tostring(m), math.min, kNotDefined) |
| 83 | + c = c + 1 |
| 84 | + --end |
| 85 | + return basefunc(self,input) |
| 86 | + end |
| 87 | + end |
| 88 | + net:apply(new_func) |
| 89 | + net[func](net, input) |
| 90 | + local function trackInputs(t) |
| 91 | + if torch.isTensor(t) then |
| 92 | + local f = function(a,b) return a end |
| 93 | + keepTrack(t, 'used', kAlwaysLive, 'input', f, 0) |
| 94 | + keepTrack(t, 'defined', -kAlwaysLive, 'input', f, 0) |
| 95 | + else |
| 96 | + for k,v in ipairs(t) do |
| 97 | + trackInputs(v) |
| 98 | + end |
| 99 | + end |
| 100 | + end |
| 101 | + trackInputs(input) |
| 102 | + -- clean up the modified function |
| 103 | + net:apply(function(x) |
| 104 | + x[func] = nil |
| 105 | + end) |
| 106 | + return analysis2 |
| 107 | +end |
| 108 | + |
| 109 | +local function isCompatible(candidate, assignment) |
| 110 | + if candidate.used == kNotUsed then |
| 111 | + return false |
| 112 | + end |
| 113 | + if candidate.tensor:numel() < kMinimumForSharing then |
| 114 | + return false |
| 115 | + end |
| 116 | + local a_used = assignment[#assignment].used-- or -1 |
| 117 | + return candidate.defined > a_used |
| 118 | +end |
| 119 | + |
| 120 | +local function assign(net, analysis) |
| 121 | + table.sort(analysis, function(a,b) |
| 122 | + local x = a.used-- or -1 |
| 123 | + local y = b.used-- or -1 |
| 124 | + return x < y |
| 125 | + end) |
| 126 | + local assignments = {} |
| 127 | + for _,candidate in ipairs(analysis) do |
| 128 | + local assigned = false |
| 129 | + for _, assignment in ipairs(assignments) do |
| 130 | + if isCompatible(candidate, assignment) then |
| 131 | + table.insert(assignment,candidate) |
| 132 | + assigned = true |
| 133 | + break |
| 134 | + end |
| 135 | + end |
| 136 | + if not assigned then |
| 137 | + table.insert(assignments, {candidate}) |
| 138 | + end |
| 139 | + end |
| 140 | + return assignments |
| 141 | +end |
| 142 | + |
| 143 | +local function applyAssignments(net, assignments) |
| 144 | + for _, assignment in ipairs(assignments) do |
| 145 | + local storage |
| 146 | + for k, v in ipairs(assignment) do |
| 147 | + if v.used == kAlwaysLive and v.defined == -kAlwaysLive then |
| 148 | + break |
| 149 | + end |
| 150 | + storage = storage or v.tensor.new(1):storage() |
| 151 | + v.tensor:set(storage) |
| 152 | + end |
| 153 | + end |
| 154 | +end |
| 155 | + |
| 156 | +function optimizeMemory(net, input) |
| 157 | + local analysis = analyse(net, input) |
| 158 | +-- print('Analysis') |
| 159 | +-- print(analysis) |
| 160 | + local assignments = assign(net,analysis) |
| 161 | +-- print('Assignments') |
| 162 | +-- print(assignments) |
| 163 | + applyAssignments(net, assignments) |
| 164 | +end |
| 165 | + |
| 166 | +function removeOptimization(net) |
| 167 | + local function rem(m) |
| 168 | + if torch.isTensor(m) then |
| 169 | + m:set() |
| 170 | + end |
| 171 | + if torch.type(m) == 'table' then |
| 172 | + for k, v in ipairs(m) do |
| 173 | + rem(v) |
| 174 | + end |
| 175 | + end |
| 176 | + end |
| 177 | + |
| 178 | + net:apply(function(m) |
| 179 | + rem(m.output) |
| 180 | + end) |
| 181 | +end |
| 182 | + |
| 183 | + |
0 commit comments