Skip to content

Commit

Permalink
Update the remaining logs and checks for the global_step
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 9, 2024
1 parent 947bd63 commit 96ad654
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,19 +274,29 @@ def train(

# Log epoch results
logger.info(
f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexity:.4f}, "
f"Epoch {epoch+1}, Global Step {global_step}, "
f"Train Loss: {avg_train_loss:.4f}, Train Perplexity: {train_perplexity:.4f}, "
f"Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}"
)

# Save checkpoint
accelerator.save_state(
{"epoch": epoch, "best_val_loss": best_val_loss},
f"checkpoint_epoch_{epoch}.pt",
{
"epoch": epoch,
"global_step": global_step,
"best_val_loss": best_val_loss,
},
f"checkpoint_epoch_{epoch}_step_{global_step}.pt",
)

# Save latest checkpoint
accelerator.save_state(
{"epoch": epoch, "best_val_loss": best_val_loss}, "checkpoint_latest.pt"
{
"epoch": epoch,
"global_step": global_step,
"best_val_loss": best_val_loss,
},
"checkpoint_latest.pt",
)

# Early stopping
Expand All @@ -295,14 +305,22 @@ def train(
patience = 0
# Save best model
accelerator.save_state(
{"epoch": epoch, "best_val_loss": best_val_loss}, "best_model.pt"
{
"epoch": epoch,
"global_step": global_step,
"best_val_loss": best_val_loss,
},
"best_model.pt",
)
else:
patience += 1
if patience >= max_patience:
logger.info("Early stopping triggered")
break

if max_steps and global_step >= max_steps:
break


# %%
def main():
Expand Down

0 comments on commit 96ad654

Please sign in to comment.