From 4e34fa8e0eff6fa7c48e77c6cffda7fd11c35173 Mon Sep 17 00:00:00 2001
From: Joshua David
Date: Sun, 14 Jul 2024 21:54:29 -0700
Subject: [PATCH] Update the training notebook with the latest training updates

---
 notebooks/01_LongRoPE_training.ipynb | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/notebooks/01_LongRoPE_training.ipynb b/notebooks/01_LongRoPE_training.ipynb
index 233e6aa..4828461 100644
--- a/notebooks/01_LongRoPE_training.ipynb
+++ b/notebooks/01_LongRoPE_training.ipynb
@@ -275,6 +275,10 @@
     "        scaler.scale(loss).backward()\n",
     "\n",
     "        if (i + 1) % gradient_accumulation_steps == 0:\n",
+    "            # Gradient clipping\n",
+    "            scaler.unscale_(optimizer)\n",
+    "            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
+    "\n",
     "            # Update weights and reset gradients\n",
     "            scaler.step(optimizer)\n",
     "            scaler.update()\n",
@@ -348,6 +352,12 @@
     "        f\"Val Loss: {avg_val_loss:.4f}, Val Perplexity: {val_perplexity:.4f}\"\n",
     "    )\n",
     "\n",
+    "    # Log GPU memory usage\n",
+    "    for gpu in GPUtil.getGPUs():\n",
+    "        gpu_memory_used = gpu.memoryUsed\n",
+    "        logger.info(f\"GPU {gpu.id} memory use: {gpu_memory_used}MB\")\n",
+    "        wandb.log({f\"GPU_{gpu.id}_memory_used\": gpu_memory_used})\n",
+    "\n",
     "    # Save checkpoint\n",
     "    accelerator.save_state(\n",
     "        {\n",
@@ -388,7 +398,7 @@
     "            break\n",
     "\n",
     "    if max_steps and global_step >= max_steps:\n",
-    "        break"
+    "        break\n"
    ]
   },
   {
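
The first hunk follows PyTorch's documented pattern for gradient clipping under
automatic mixed precision: gradients must be unscaled via scaler.unscale_()
before clip_grad_norm_ is called, otherwise the norm is computed on
loss-scaled values and the max_norm threshold is effectively meaningless. A
minimal standalone sketch of that pattern follows, assuming a CUDA device and
placeholder model/optimizer/data standing in for the notebook's own objects:

import torch

# Placeholder objects; the notebook's real model, optimizer, dataloader,
# and gradient_accumulation_steps are assumed in its earlier cells.
model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
gradient_accumulation_steps = 4

for i, batch in enumerate(torch.randn(32, 8, 16).cuda()):
    with torch.cuda.amp.autocast():
        loss = model(batch).pow(2).mean()
    # Scale the loss so fp16 gradients do not underflow, then backprop
    scaler.scale(loss).backward()

    if (i + 1) % gradient_accumulation_steps == 0:
        # Unscale in place before clipping, so the norm is measured on the
        # true gradients rather than the loss-scaled ones
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # Update weights and reset gradients; scaler.step() skips the update
        # if unscaled gradients contain infs/NaNs
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()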
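
The second hunk samples per-GPU memory with GPUtil once per validation pass
and mirrors the reading to the Weights & Biases run. A minimal sketch of that
logging step, assuming GPUtil and wandb are installed; the logger setup and
the wandb project name here are placeholders for the notebook's existing
configuration:

import logging

import GPUtil
import wandb

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hypothetical project name; the notebook initializes its own run earlier.
wandb.init(project="longrope")

# GPUtil reports used memory in MB, as queried from nvidia-smi
for gpu in GPUtil.getGPUs():
    gpu_memory_used = gpu.memoryUsed
    logger.info(f"GPU {gpu.id} memory use: {gpu_memory_used}MB")
    wandb.log({f"GPU_{gpu.id}_memory_used": gpu_memory_used})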