
Commit 0db3f66 (1 parent: c212bd8)

refactor: add disconnected mnist training script

File tree: 1 file changed (+162, -0 lines)


Diff for: demo-notebooks/guided-demos/mnist_disconnected.py

@@ -0,0 +1,162 @@
# Copyright 2022 IBM, Red Hat
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# In[]
import os

import torch
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64
# %%

local_mnist_path = os.path.join(PATH_DATASETS, "mnist")

print("prior to running the trainer")
print("MASTER_ADDR: is ", os.getenv("MASTER_ADDR"))
print("MASTER_PORT: is ", os.getenv("MASTER_PORT"))


class LitMNIST(LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset-specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        # torchmetrics < 0.11 API; newer versions require
        # Accuracy(task="multiclass", num_classes=10)
        self.val_accuracy = Accuracy()
        self.test_accuracy = Accuracy()

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.val_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_accuracy, prog_bar=True)

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.test_accuracy.update(preds, y)

        # Calling self.log will surface up scalars for you in TensorBoard
        self.log("test_loss", loss, prog_bar=True)
        self.log("test_acc", self.test_accuracy, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # Disconnected variant: nothing is downloaded (download=False);
        # the dataset must already be staged under self.data_dir.
        print("Verifying local MNIST dataset...")
        MNIST(self.data_dir, train=True, download=False)
        MNIST(self.data_dir, train=False, download=False)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(
                self.data_dir, train=True, transform=self.transform, download=False
            )
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(
                self.data_dir, train=False, transform=self.transform, download=False
            )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)


# Init the model, pointing it at the local MNIST dataset path
model = LitMNIST(data_dir=local_mnist_path)

print("GROUP: ", int(os.environ.get("GROUP_WORLD_SIZE", 1)))
print("LOCAL: ", int(os.environ.get("LOCAL_WORLD_SIZE", 1)))

# Initialize a trainer
trainer = Trainer(
    accelerator="auto",
    # devices=1 if torch.cuda.is_available() else None,  # limiting for iPython runs
    max_epochs=5,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
    num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
    devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
    strategy="ddp",
)

# Train the model ⚡
trainer.fit(model)
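
Because every MNIST(...) call in the script uses download=False, the dataset has to be staged into PATH_DATASETS/mnist before training starts. A minimal staging sketch, assuming a connected machine and the same directory layout as local_mnist_path above (the "mnist" root and the copy step are illustrative, not part of this commit):

    # stage_mnist.py -- run on a connected machine, then copy the resulting
    # "mnist" directory into PATH_DATASETS on the disconnected cluster.
    from torchvision.datasets import MNIST

    MNIST("mnist", train=True, download=True)   # writes mnist/MNIST/raw (train split)
    MNIST("mnist", train=False, download=True)  # writes the test split alongside it

The num_nodes and devices arguments are read from GROUP_WORLD_SIZE and LOCAL_WORLD_SIZE, which an elastic launcher such as torchrun normally injects into each worker. For a hypothetical single-process smoke test run outside any launcher, the expected variables could be defaulted by hand before the script executes:

    # Hypothetical single-node defaults; a real launcher sets these for you.
    import os

    os.environ.setdefault("MASTER_ADDR", "localhost")  # rendezvous host
    os.environ.setdefault("MASTER_PORT", "29500")      # rendezvous port
    os.environ.setdefault("GROUP_WORLD_SIZE", "1")     # number of nodes
    os.environ.setdefault("LOCAL_WORLD_SIZE", "1")     # processes per node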
