feat: optional gradient clipping
sehoffmann committed Apr 23, 2024
1 parent 59cca3c commit f7f3d5b
Showing 1 changed file with 23 additions and 8 deletions.

dmlcloud/stage.py
@@ -253,6 +253,9 @@ def train_metric_prefix(self):
     def val_metric_prefix(self):
         return 'val'
 
+    def gradient_clip(self):
+        return 0.0
+
     def run_epoch(self):
         self.train_epoch()
         self.val_epoch()
@@ -266,6 +269,24 @@ def train_step(self, batch):
     def val_step(self, batch):
         return self.step(batch)
 
+    def zero_grad(self):
+        for optimizer in self.optimizers():
+            optimizer.zero_grad()
+
+    def clip_gradients(self):
+        for optimizer in self.optimizers():
+            for group in optimizer.param_groups:
+                torch.nn.utils.clip_grad_norm_(group['params'], self.gradient_clip())
+
+    def optimize(self, loss):
+        loss.backward()
+
+        if self.gradient_clip():
+            self.clip_gradients()
+
+        for optimizer in self.optimizers():
+            optimizer.step()
+
     def train_epoch(self):
         self.is_train = True
         self.metric_prefix = self.train_metric_prefix()
@@ -275,15 +296,9 @@ def train_epoch(self):
             train_ds.sampler.set_epoch(self.current_epoch)
 
         for batch in train_ds:
-            for optimizer in self.optimizers():
-                optimizer.zero_grad()
-
+            self.zero_grad()
             loss = self.train_step(batch)
-            loss.backward()
-
-            for optimizer in self.optimizers():
-                optimizer.step()
-
+            self.optimize(loss)
             self.track_reduce(self.loss_metric_name(), loss)
 
         for scheduler in self.pipeline.schedulers.values():
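For illustration, a minimal usage sketch (not part of this commit): a stage subclass opts into the new optional clipping by overriding gradient_clip() to return a nonzero max norm, while the base default of 0.0 leaves clipping disabled. The names MyStage and TrainValStage, the import path, and the model/loss wiring are assumptions for this sketch; only gradient_clip(), step(), and optimizers() appear in the diff above.

import torch
import torch.nn.functional as F

from dmlcloud.stage import TrainValStage  # assumed module/class path


class MyStage(TrainValStage):
    def gradient_clip(self):
        # Max L2 norm for gradients; the base class returns 0.0, which
        # makes optimize() skip clip_gradients() entirely.
        return 1.0

    def step(self, batch):
        # Hypothetical training step: assumes the stage exposes a `model` attribute.
        inputs, targets = batch
        predictions = self.model(inputs)
        return F.mse_loss(predictions, targets)

Note that clip_gradients() calls torch.nn.utils.clip_grad_norm_ once per param group of each optimizer, so the norm is computed within each group rather than globally over all parameters.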
