19
19
from pytorch_lightning .callbacks .progress import TQDMProgressBar
20
20
from torch import nn
21
21
from torch .nn import functional as F
22
- from torch .utils .data import DataLoader , random_split
22
+ from torch .utils .data import DataLoader , random_split , RandomSampler
23
23
from torchmetrics import Accuracy
24
24
from torchvision import transforms
25
25
from torchvision .datasets import MNIST
@@ -127,7 +127,11 @@ def setup(self, stage=None):
127
127
)
128
128
129
129
def train_dataloader (self ):
130
- return DataLoader (self .mnist_train , batch_size = BATCH_SIZE )
130
+ return DataLoader (
131
+ self .mnist_train ,
132
+ batch_size = BATCH_SIZE ,
133
+ sampler = RandomSampler (self .mnist_train , num_samples = 1000 ),
134
+ )
131
135
132
136
def val_dataloader (self ):
133
137
return DataLoader (self .mnist_val , batch_size = BATCH_SIZE )
@@ -147,10 +151,11 @@ def test_dataloader(self):
147
151
trainer = Trainer (
148
152
accelerator = "auto" ,
149
153
# devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
150
- max_epochs = 5 ,
154
+ max_epochs = 3 ,
151
155
callbacks = [TQDMProgressBar (refresh_rate = 20 )],
152
156
num_nodes = int (os .environ .get ("GROUP_WORLD_SIZE" , 1 )),
153
157
devices = int (os .environ .get ("LOCAL_WORLD_SIZE" , 1 )),
158
+ replace_sampler_ddp = False ,
154
159
strategy = "ddp" ,
155
160
)
156
161
0 commit comments