
Commit 1c76746

nzmora authored and facebook-github-bot committed
SGD: remove unneeded multiply-add initialization operations (pytorch#18114)
Summary:
The momentum buffer is initialized to the value of d_p, but the current code takes the long way to do this:

1. Create a buffer of zeros
2. Multiply the buffer by the momentum coefficient
3. Add d_p to the buffer

All of these can be collapsed into a single step:

1. Create a clone of d_p

Pull Request resolved: pytorch#18114
Differential Revision: D14509122
Pulled By: ezyang
fbshipit-source-id: 4a79b896201d5ff20770b7ae790c244ba744edb8
1 parent a50ba7e commit 1c76746

File tree

1 file changed

+1
-2
lines changed


torch/optim/sgd.py

+1-2
@@ -94,8 +94,7 @@ def step(self, closure=None):
                 if momentum != 0:
                     param_state = self.state[p]
                     if 'momentum_buffer' not in param_state:
-                        buf = param_state['momentum_buffer'] = torch.zeros_like(p.data)
-                        buf.mul_(momentum).add_(d_p)
+                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
                     else:
                         buf = param_state['momentum_buffer']
                         buf.mul_(momentum).add_(1 - dampening, d_p)
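The change is safe because the old three-step initialization is arithmetically equivalent to a single clone: multiplying a zero buffer by the momentum coefficient is a no-op, so the buffer ends up holding exactly the values of d_p. A minimal sketch of that equivalence (the tensor values below are illustrative, not from the commit):

```python
import torch

momentum = 0.9
d_p = torch.tensor([0.5, -1.0, 2.0])  # stands in for a parameter gradient

# Old path: zeros, scale by momentum (a no-op on zeros), then add d_p.
buf_old = torch.zeros_like(d_p)
buf_old.mul_(momentum).add_(d_p)

# New path: one clone, detached so the buffer stays out of the autograd graph.
buf_new = torch.clone(d_p).detach()

print(torch.equal(buf_old, buf_new))  # True
```

The new path also skips two elementwise kernel launches per parameter on the first momentum step, which is the whole point of the commit.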
