From 96ad654f48c3d1f1f1471498f6e6f224570074ba Mon Sep 17 00:00:00 2001
From: Joshua David
Date: Mon, 8 Jul 2024 23:03:17 -0700
Subject: [PATCH] Update the remaining logs and checks for the global_step

---
 train.py | 28 +++++++++++++++++++++++-----
 1 file changed, 23 insertions(+), 5 deletions(-)

diff --git a/train.py b/train.py
index e6c1f88..ed13870 100644
--- a/train.py
+++ b/train.py
@@ -274,19 +274,29 @@ def train(
 
         # Log epoch results
         logger.info(
-            f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexity:.4f}, "
+            f"Epoch {epoch+1}, Global Step {global_step}, "
+            f"Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexity:.4f}, "
             f"Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}"
         )
 
         # Save checkpoint
         accelerator.save_state(
-            {"epoch": epoch, "best_val_loss": best_val_loss},
-            f"checkpoint_epoch_{epoch}.pt",
+            {
+                "epoch": epoch,
+                "global_step": global_step,
+                "best_val_loss": best_val_loss,
+            },
+            f"checkpoint_epoch_{epoch}_step_{global_step}.pt",
         )
 
         # Save latest checkpoint
         accelerator.save_state(
-            {"epoch": epoch, "best_val_loss": best_val_loss}, "checkpoint_latest.pt"
+            {
+                "epoch": epoch,
+                "global_step": global_step,
+                "best_val_loss": best_val_loss,
+            },
+            "checkpoint_latest.pt",
         )
 
         # Early stopping
@@ -295,7 +305,12 @@ def train(
             patience = 0
             # Save best model
             accelerator.save_state(
-                {"epoch": epoch, "best_val_loss": best_val_loss}, "best_model.pt"
+                {
+                    "epoch": epoch,
+                    "global_step": global_step,
+                    "best_val_loss": best_val_loss,
+                },
+                "best_model.pt",
             )
         else:
             patience += 1
@@ -303,6 +318,9 @@ def train(
                 logger.info("Early stopping triggered")
                 break
 
+        if max_steps and global_step >= max_steps:
+            break
+
 
 # %%
 def main():
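
For context, a minimal, self-contained sketch of the pattern this patch applies: a global_step counter threaded through the epoch log line, the checkpoint payload and filename, and a cap on total training steps via max_steps. The loop structure, model, data, and patience_limit below are stand-ins invented for illustration, and plain torch.save is used in place of the accelerator.save_state wrapper from train.py.

# Sketch of the global_step / max_steps checkpointing pattern (illustrative only).
import torch
import torch.nn as nn

def train(num_epochs=2, max_steps=None, patience_limit=2):
    model = nn.Linear(4, 1)                      # stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loss_fn = nn.MSELoss()
    data = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(5)]  # dummy batches

    global_step = 0
    best_val_loss = float("inf")
    patience = 0

    for epoch in range(num_epochs):
        for x, y in data:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
            global_step += 1                     # one step per optimizer update
            if max_steps and global_step >= max_steps:
                break

        val_loss = loss.item()                   # placeholder for a real validation pass
        print(f"Epoch {epoch+1}, Global Step {global_step}, Val Loss: {val_loss:.4f}")

        # Checkpoint payload and filename both carry global_step, so a resumed
        # run can restore the step counter exactly.
        state = {
            "epoch": epoch,
            "global_step": global_step,
            "best_val_loss": best_val_loss,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        torch.save(state, f"checkpoint_epoch_{epoch}_step_{global_step}.pt")
        torch.save(state, "checkpoint_latest.pt")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = 0
            torch.save(state, "best_model.pt")
        else:
            patience += 1
            if patience >= patience_limit:
                print("Early stopping triggered")
                break

        # Mirror of the inner-loop check: leave the epoch loop once max_steps is hit.
        if max_steps and global_step >= max_steps:
            break

if __name__ == "__main__":
    train(max_steps=8)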