Adding a LightningDataModule example for lit_mnist #31

Open
wants to merge 7 commits into master
Changes from all commits
2 changes: 2 additions & 0 deletions .gitignore
@@ -127,3 +127,5 @@ venv.bak/
lightning_logs/
MNIST
.DS_Store

logs/
27 changes: 3 additions & 24 deletions README.md
@@ -50,33 +50,12 @@ pip install -r requirements.txt
```
Next, navigate to any file and run it.
```bash
# module folder
cd project

# run module (example: mnist as your main contribution)
python lit_classifier_main.py
python project/train.py fit --config ../config/config.yaml
```

## Imports
This project is set up as a package, which means you can easily import any file into any other file like so:
```python
from project.datasets.mnist import mnist
from project.lit_classifier_main import LitClassifier
from pytorch_lightning import Trainer

# model
model = LitClassifier()

# data
train, val, test = mnist()

# train
trainer = Trainer()
trainer.fit(model, train, val)

# test using the best model!
trainer.test(test_dataloaders=test)
```
## Config
If you would like to learn more about the Lightning CLI, please head over to the [LightningCLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html) docs.
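
As a rough illustration (a sketch only, not something the template ships), running `python project/train.py fit --config ...` with the bundled config does approximately the following; this leaves out the logger and callbacks the YAML also configures, and assumes `project/` is on `PYTHONPATH` so `lit_mnist` is importable:
```python
import pytorch_lightning as pl
from lit_mnist import LitClassifier, MNISTDataModule

pl.seed_everything(1234)

# Values mirror config/config.yaml
model = LitClassifier(hidden_dim=128, learning_rate=0.001)
datamodule = MNISTDataModule(data_dir="", batch_size=16)

trainer = pl.Trainer(accelerator="auto", devices="auto", max_epochs=100)
trainer.fit(model, datamodule=datamodule)
```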

### Citation
```
31 changes: 31 additions & 0 deletions config/config.yaml
@@ -0,0 +1,31 @@
# pytorch_lightning==2.0.1
seed_everything: 1234
trainer:
  accelerator: auto
  devices: auto
  num_nodes: 1
  precision: 32-true
  logger:
    - class_path: pytorch_lightning.loggers.WandbLogger
      init_args:
        name: lit_mnist_logs
        save_dir: logs
        project: pl_template
  callbacks:
    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        verbose: True
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
  fast_dev_run: false
  max_epochs: 100
model:
  class_path: lit_mnist.LitClassifier
  init_args:
    hidden_dim: 128
    learning_rate: 0.001
data:
  class_path: lit_mnist.MNISTDataModule
  init_args:
    data_dir: ''
    batch_size: 16
ckpt_path: null
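
For reference, the `trainer:` block above maps onto an ordinary `Trainer` construction roughly like the sketch below (illustrative only; the CLI performs this instantiation itself, so this is not code the PR adds):
```python
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Rough manual equivalent of the trainer: section in config/config.yaml.
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    num_nodes=1,
    precision="32-true",
    fast_dev_run=False,
    max_epochs=100,
    logger=[WandbLogger(name="lit_mnist_logs", save_dir="logs", project="pl_template")],
    callbacks=[ModelCheckpoint(verbose=True), LearningRateMonitor()],
)
```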
189 changes: 110 additions & 79 deletions project/lit_mnist.py
@@ -1,96 +1,127 @@
from argparse import ArgumentParser

import torch
from torch import nn

import pytorch_lightning as pl
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from torchvision.datasets.mnist import MNIST
from torchvision import transforms

from typing import Optional

class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
        self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser

class MyModule(nn.Module):
    '''
    A simple two-layer MLP that maps flattened 28x28 MNIST images to 10 class logits.
    '''
    def __init__(self, hidden_dim) -> None:
        super().__init__()
        self.l1 = torch.nn.Linear(28 * 28, hidden_dim)
        self.l2 = torch.nn.Linear(hidden_dim, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        return x


class LitClassifier(pl.LightningModule):
    def __init__(self, hidden_dim=128, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        self.model = MyModule(self.hparams.hidden_dim)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('valid_loss', loss)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--hidden_dim', type=int, default=128)
        parser.add_argument('--learning_rate', type=float, default=0.0001)
        return parser

class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: Optional[str] = None):
        dataset = MNIST(self.data_dir, train=True, download=True, transform=transforms.ToTensor())
        self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=transforms.ToTensor())
        self.mnist_train, self.mnist_val = random_split(dataset, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

def cli_main():
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    args = parser.parse_args()

    # ------------
    # data
    # ------------
    dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
    mnist_test = MNIST('', train=False, download=True, transform=transforms.ToTensor())
    mnist_train, mnist_val = random_split(dataset, [55000, 5000])

    train_loader = DataLoader(mnist_train, batch_size=args.batch_size)
    val_loader = DataLoader(mnist_val, batch_size=args.batch_size)
    test_loader = DataLoader(mnist_test, batch_size=args.batch_size)

    # ------------
    # model
    # ------------
    model = LitClassifier(args.hidden_dim, args.learning_rate)

    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, train_loader, val_loader)

    # ------------
    # testing
    # ------------
    trainer.test(test_dataloaders=test_loader)
    pl.seed_everything(1234)

    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--num_workers', default=1, type=int)

    parser = pl.Trainer.add_argparse_args(parser)
    parser = LitClassifier.add_model_specific_args(parser)
    args = parser.parse_args()

    # ------------
    # data
    # ------------
    mnist = MNISTDataModule('')

    # ------------
    # model
    # ------------
    model = LitClassifier(args.hidden_dim, args.learning_rate)

    # ------------
    # training
    # ------------
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule=mnist)

    # ------------
    # testing
    # ------------
    trainer.test(datamodule=mnist)


if __name__ == '__main__':
    cli_main()
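
Since the point of this PR is the `LightningDataModule`, a quick standalone check of `MNISTDataModule` could look like the sketch below (not part of the diff; it assumes the script is run from `project/` and that torchvision can download MNIST into the working directory):
```python
from lit_mnist import MNISTDataModule

dm = MNISTDataModule(data_dir="", batch_size=32)
dm.setup("fit")  # downloads MNIST and builds the 55k/5k train/val split

x, y = next(iter(dm.train_dataloader()))
print(x.shape, y.shape)  # expected: torch.Size([32, 1, 28, 28]) torch.Size([32])
```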
17 changes: 17 additions & 0 deletions project/train.py
@@ -0,0 +1,17 @@
import logging
import torch
import pytorch_lightning as pl
from pytorch_lightning.cli import LightningCLI

pl_logger = logging.getLogger('pytorch_lightning')

if __name__ == '__main__':
    import datetime
    pl_logger.info(f"Starting at {datetime.datetime.now()}")

    torch.set_float32_matmul_precision('medium')

    cli = LightningCLI(
        trainer_class=pl.Trainer,
        save_config_callback=None,
    )
7 changes: 4 additions & 3 deletions requirements.txt
@@ -1,3 +1,4 @@
pytorch-lightning >= 1.0.0rc2
torch >= 1.3.0
torchvision >= 0.6.0
pytorch-lightning >= 1.7.2
torch
torchvision
torchaudio
25 changes: 25 additions & 0 deletions slurm.sh
@@ -0,0 +1,25 @@
#!/bin/bash
#SBATCH --partition=gpu
#SBATCH --job-name=multinode_pl_test
#SBATCH --nodes=2
#SBATCH --exclusive
#SBATCH --comment clap
#SBATCH --ntasks-per-node=8
#SBATCH --output=%x_%j.out

export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/nccl/build/lib:/opt/aws-ofi-nccl-install/lib
export NCCL_PROTO=simple
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/aws-ofi-nccl/lib
export PATH=$PATH:/opt/amazon/efa/bin:/opt/amazon/openmpi/bin
export FI_EFA_FORK_SAFE=1
export FI_LOG_LEVEL=1
export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn
export NCCL_DEBUG=info
export OMPI_MCA_mtl_base_verbose=1
export FI_EFA_ENABLE_SHM_TRANSFER=0
export FI_PROVIDER=efa
export FI_EFA_TX_MIN_CREDITS=64
export NCCL_TREE_THRESHOLD=0
export NCCL_SOCKET_IFNAME=^docker0,lo

srun --comment clap /home/knoriy/fsx/miniconda3/envs/clasp/bin/python /home/knoriy/deep-learning-project-template/project/lit_mnist.py --accelerator gpu --strategy ddp --num_nodes 2 --devices 8