From 10e8119110ddc3f74c4d8bab7c846dc342c3fb6f Mon Sep 17 00:00:00 2001
From: Aravinda Kumar <arvind7447@gmail.com>
Date: Thu, 18 Jul 2024 16:08:16 +0530
Subject: [PATCH] Add FSDP MNIST example

---
 distributed/{ => fsdp}/FSDP/.gitignore        |   0
 distributed/{ => fsdp}/FSDP/README.md         |   0
 distributed/{ => fsdp}/FSDP/T5_training.py    |   0
 .../{ => fsdp}/FSDP/configs/__init__.py       |   0
 distributed/{ => fsdp}/FSDP/configs/fsdp.py   |   0
 .../{ => fsdp}/FSDP/configs/training.py       |   0
 .../{ => fsdp}/FSDP/download_dataset.sh       |   0
 .../FSDP/model_checkpointing/__init__.py      |   0
 .../model_checkpointing/checkpoint_handler.py |   0
 .../{ => fsdp}/FSDP/policies/__init__.py      |   0
 .../activation_checkpointing_functions.py     |   0
 .../FSDP/policies/mixed_precision.py          |   0
 .../{ => fsdp}/FSDP/policies/wrapping.py      |   0
 distributed/{ => fsdp}/FSDP/requirements.txt  |   0
 .../{ => fsdp}/FSDP/summarization_dataset.py  |   0
 distributed/{ => fsdp}/FSDP/utils/__init__.py |   0
 .../{ => fsdp}/FSDP/utils/environment.py      |   0
 .../{ => fsdp}/FSDP/utils/train_utils.py      |   0
 distributed/fsdp/fsdp-mnist/README.md         |   7 +
 distributed/fsdp/fsdp-mnist/fsdp_mnist.py     | 198 ++++++++++++++++++
 20 files changed, 205 insertions(+)
 rename distributed/{ => fsdp}/FSDP/.gitignore (100%)
 rename distributed/{ => fsdp}/FSDP/README.md (100%)
 rename distributed/{ => fsdp}/FSDP/T5_training.py (100%)
 rename distributed/{ => fsdp}/FSDP/configs/__init__.py (100%)
 rename distributed/{ => fsdp}/FSDP/configs/fsdp.py (100%)
 rename distributed/{ => fsdp}/FSDP/configs/training.py (100%)
 rename distributed/{ => fsdp}/FSDP/download_dataset.sh (100%)
 rename distributed/{ => fsdp}/FSDP/model_checkpointing/__init__.py (100%)
 rename distributed/{ => fsdp}/FSDP/model_checkpointing/checkpoint_handler.py (100%)
 rename distributed/{ => fsdp}/FSDP/policies/__init__.py (100%)
 rename distributed/{ => fsdp}/FSDP/policies/activation_checkpointing_functions.py (100%)
 rename distributed/{ => fsdp}/FSDP/policies/mixed_precision.py (100%)
 rename distributed/{ => fsdp}/FSDP/policies/wrapping.py (100%)
 rename distributed/{ => fsdp}/FSDP/requirements.txt (100%)
 rename distributed/{ => fsdp}/FSDP/summarization_dataset.py (100%)
 rename distributed/{ => fsdp}/FSDP/utils/__init__.py (100%)
 rename distributed/{ => fsdp}/FSDP/utils/environment.py (100%)
 rename distributed/{ => fsdp}/FSDP/utils/train_utils.py (100%)
 create mode 100644 distributed/fsdp/fsdp-mnist/README.md
 create mode 100644 distributed/fsdp/fsdp-mnist/fsdp_mnist.py

diff --git a/distributed/FSDP/.gitignore b/distributed/fsdp/FSDP/.gitignore
similarity index 100%
rename from distributed/FSDP/.gitignore
rename to distributed/fsdp/FSDP/.gitignore
diff --git a/distributed/FSDP/README.md b/distributed/fsdp/FSDP/README.md
similarity index 100%
rename from distributed/FSDP/README.md
rename to distributed/fsdp/FSDP/README.md
diff --git a/distributed/FSDP/T5_training.py b/distributed/fsdp/FSDP/T5_training.py
similarity index 100%
rename from distributed/FSDP/T5_training.py
rename to distributed/fsdp/FSDP/T5_training.py
diff --git a/distributed/FSDP/configs/__init__.py b/distributed/fsdp/FSDP/configs/__init__.py
similarity index 100%
rename from distributed/FSDP/configs/__init__.py
rename to distributed/fsdp/FSDP/configs/__init__.py
diff --git a/distributed/FSDP/configs/fsdp.py b/distributed/fsdp/FSDP/configs/fsdp.py
similarity index 100%
rename from distributed/FSDP/configs/fsdp.py
rename to distributed/fsdp/FSDP/configs/fsdp.py
diff --git a/distributed/FSDP/configs/training.py b/distributed/fsdp/FSDP/configs/training.py
similarity index 100%
rename from distributed/FSDP/configs/training.py
rename to distributed/fsdp/FSDP/configs/training.py
diff --git a/distributed/FSDP/download_dataset.sh b/distributed/fsdp/FSDP/download_dataset.sh
similarity index 100%
rename from distributed/FSDP/download_dataset.sh
rename to distributed/fsdp/FSDP/download_dataset.sh
diff --git a/distributed/FSDP/model_checkpointing/__init__.py b/distributed/fsdp/FSDP/model_checkpointing/__init__.py
similarity index 100%
rename from distributed/FSDP/model_checkpointing/__init__.py
rename to distributed/fsdp/FSDP/model_checkpointing/__init__.py
diff --git a/distributed/FSDP/model_checkpointing/checkpoint_handler.py b/distributed/fsdp/FSDP/model_checkpointing/checkpoint_handler.py
similarity index 100%
rename from distributed/FSDP/model_checkpointing/checkpoint_handler.py
rename to distributed/fsdp/FSDP/model_checkpointing/checkpoint_handler.py
diff --git a/distributed/FSDP/policies/__init__.py b/distributed/fsdp/FSDP/policies/__init__.py
similarity index 100%
rename from distributed/FSDP/policies/__init__.py
rename to distributed/fsdp/FSDP/policies/__init__.py
diff --git a/distributed/FSDP/policies/activation_checkpointing_functions.py b/distributed/fsdp/FSDP/policies/activation_checkpointing_functions.py
similarity index 100%
rename from distributed/FSDP/policies/activation_checkpointing_functions.py
rename to distributed/fsdp/FSDP/policies/activation_checkpointing_functions.py
diff --git a/distributed/FSDP/policies/mixed_precision.py b/distributed/fsdp/FSDP/policies/mixed_precision.py
similarity index 100%
rename from distributed/FSDP/policies/mixed_precision.py
rename to distributed/fsdp/FSDP/policies/mixed_precision.py
diff --git a/distributed/FSDP/policies/wrapping.py b/distributed/fsdp/FSDP/policies/wrapping.py
similarity index 100%
rename from distributed/FSDP/policies/wrapping.py
rename to distributed/fsdp/FSDP/policies/wrapping.py
diff --git a/distributed/FSDP/requirements.txt b/distributed/fsdp/FSDP/requirements.txt
similarity index 100%
rename from distributed/FSDP/requirements.txt
rename to distributed/fsdp/FSDP/requirements.txt
diff --git a/distributed/FSDP/summarization_dataset.py b/distributed/fsdp/FSDP/summarization_dataset.py
similarity index 100%
rename from distributed/FSDP/summarization_dataset.py
rename to distributed/fsdp/FSDP/summarization_dataset.py
diff --git a/distributed/FSDP/utils/__init__.py b/distributed/fsdp/FSDP/utils/__init__.py
similarity index 100%
rename from distributed/FSDP/utils/__init__.py
rename to distributed/fsdp/FSDP/utils/__init__.py
diff --git a/distributed/FSDP/utils/environment.py b/distributed/fsdp/FSDP/utils/environment.py
similarity index 100%
rename from distributed/FSDP/utils/environment.py
rename to distributed/fsdp/FSDP/utils/environment.py
diff --git a/distributed/FSDP/utils/train_utils.py b/distributed/fsdp/FSDP/utils/train_utils.py
similarity index 100%
rename from distributed/FSDP/utils/train_utils.py
rename to distributed/fsdp/FSDP/utils/train_utils.py
diff --git a/distributed/fsdp/fsdp-mnist/README.md b/distributed/fsdp/fsdp-mnist/README.md
new file mode 100644
index 0000000000..73dab0a188
--- /dev/null
+++ b/distributed/fsdp/fsdp-mnist/README.md
@@ -0,0 +1,7 @@
+## FSDP MNIST
+
+To run a simple MNIST example with FSDP:
+
+```bash
+python fsdp_mnist.py
+```
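+
+The script spawns one process per visible GPU via `torch.multiprocessing.spawn` and initializes the NCCL backend, so at least one CUDA device is required.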
diff --git a/distributed/fsdp/fsdp-mnist/fsdp_mnist.py b/distributed/fsdp/fsdp-mnist/fsdp_mnist.py
new file mode 100644
index 0000000000..9ec17a1e85
--- /dev/null
+++ b/distributed/fsdp/fsdp-mnist/fsdp_mnist.py
@@ -0,0 +1,198 @@
+# Based on: https://github.com/pytorch/examples/blob/master/mnist/main.py
+import os
+import argparse
+import functools
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torchvision import datasets, transforms
+
+
+from torch.optim.lr_scheduler import StepLR
+
+import torch.distributed as dist
+import torch.multiprocessing as mp
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data.distributed import DistributedSampler
+from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+from torch.distributed.fsdp.fully_sharded_data_parallel import (
+    CPUOffload,
+    BackwardPrefetch,
+)
+from torch.distributed.fsdp.wrap import (
+    size_based_auto_wrap_policy,
+    enable_wrap,
+    wrap,
+)
+
+def setup(rank, world_size):
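+    # single-node rendezvous: every spawned process connects to this address/port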
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+
+    # initialize the process group
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+def cleanup():
+    dist.destroy_process_group()
+
+class Net(nn.Module):
+    def __init__(self):
+        super(Net, self).__init__()
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
+        x = self.fc2(x)
+        output = F.log_softmax(x, dim=1)
+        return output
+
+def train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=None):
+    model.train()
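+    # ddp_loss accumulates [summed loss, number of samples] on this rank; it is all-reduced after the epoch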
+    ddp_loss = torch.zeros(2).to(rank)
+    if sampler:
+        sampler.set_epoch(epoch)
+    for batch_idx, (data, target) in enumerate(train_loader):
+        data, target = data.to(rank), target.to(rank)
+        optimizer.zero_grad()
+        output = model(data)
+        loss = F.nll_loss(output, target, reduction='sum')
+        loss.backward()
+        optimizer.step()
+        ddp_loss[0] += loss.item()
+        ddp_loss[1] += len(data)
+
+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
+    if rank == 0:
+        print('Train Epoch: {} \tLoss: {:.6f}'.format(epoch, ddp_loss[0] / ddp_loss[1]))
+
+def test(model, rank, world_size, test_loader):
+    model.eval()
+    # ddp_loss holds [summed loss, correct predictions, number of samples] on this rank
+    ddp_loss = torch.zeros(3).to(rank)
+    with torch.no_grad():
+        for data, target in test_loader:
+            data, target = data.to(rank), target.to(rank)
+            output = model(data)
+            ddp_loss[0] += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
+            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
+            ddp_loss[1] += pred.eq(target.view_as(pred)).sum().item()
+            ddp_loss[2] += len(data)
+
+    dist.all_reduce(ddp_loss, op=dist.ReduceOp.SUM)
+
+    if rank == 0:
+        test_loss = ddp_loss[0] / ddp_loss[2]
+        print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
+            test_loss, int(ddp_loss[1]), int(ddp_loss[2]),
+            100. * ddp_loss[1] / ddp_loss[2]))
+
+def fsdp_main(rank, world_size, args):
+    setup(rank, world_size)
+
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    dataset1 = datasets.MNIST('../data', train=True, download=True,
+                        transform=transform)
+    dataset2 = datasets.MNIST('../data', train=False,
+                        transform=transform)
+
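+    # DistributedSampler hands each rank a distinct shard of the dataset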
+    sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
+    sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)
+
+    train_kwargs = {'batch_size': args.batch_size, 'sampler': sampler1}
+    test_kwargs = {'batch_size': args.test_batch_size, 'sampler': sampler2}
+    cuda_kwargs = {'num_workers': 2,
+                    'pin_memory': True,
+                    'shuffle': False}
+    train_kwargs.update(cuda_kwargs)
+    test_kwargs.update(cuda_kwargs)
+
+    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
+    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)
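+    # size-based auto-wrap policy: submodules holding at least min_num_params parameters become their own FSDP unit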
+    my_auto_wrap_policy = functools.partial(
+        size_based_auto_wrap_policy, min_num_params=20000
+    )
+    torch.cuda.set_device(rank)
+
+
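+    # CUDA events used to time the whole training loop (elapsed_time() reports milliseconds)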
+    init_start_event = torch.cuda.Event(enable_timing=True)
+    init_end_event = torch.cuda.Event(enable_timing=True)
+
+    model = Net().to(rank)
+
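+    # shard the model with FSDP; offload_params=True keeps sharded parameters on CPU except when they are needed for computation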
+    model = FSDP(model,
+        auto_wrap_policy=my_auto_wrap_policy,
+        cpu_offload=CPUOffload(offload_params=True)
+    )
+
+    optimizer = optim.Adadelta(model.parameters(), lr=args.lr)
+
+    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
+    init_start_event.record()
+    for epoch in range(1, args.epochs + 1):
+        train(args, model, rank, world_size, train_loader, optimizer, epoch, sampler=sampler1)
+        test(model, rank, world_size, test_loader)
+        scheduler.step()
+
+    init_end_event.record()
+
+    if rank == 0:
+        print(f"CUDA event elapsed time: {init_start_event.elapsed_time(init_end_event) / 1000}sec")
+        print(f"{model}")
+
+    if args.save_model:
+        # use a barrier to make sure training is done on all ranks
+        dist.barrier()
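+        # with FSDP's default FULL_STATE_DICT setting, state_dict() gathers the full unsharded weights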
+        states = model.state_dict()
+        if rank == 0:
+            torch.save(states, "mnist_cnn.pt")
+
+    cleanup()
+
+if __name__ == '__main__':
+    # Training settings
+    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
+    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
+                        help='input batch size for training (default: 64)')
+    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
+                        help='input batch size for testing (default: 1000)')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
+                        help='learning rate (default: 1.0)')
+    parser.add_argument('--gamma', type=float, default=0.7, metavar='M',
+                        help='Learning rate step gamma (default: 0.7)')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training')
+    parser.add_argument('--seed', type=int, default=1, metavar='S',
+                        help='random seed (default: 1)')
+    parser.add_argument('--save-model', action='store_true', default=False,
+                        help='For Saving the current Model')
+    args = parser.parse_args()
+
+    torch.manual_seed(args.seed)
+
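+    # launch one training process per available GPU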
+    WORLD_SIZE = torch.cuda.device_count()
+    mp.spawn(fsdp_main,
+        args=(WORLD_SIZE, args),
+        nprocs=WORLD_SIZE,
+        join=True)
\ No newline at end of file