1919from  pytorch_lightning .callbacks .progress  import  TQDMProgressBar 
2020from  torch  import  nn 
2121from  torch .nn  import  functional  as  F 
22- from  torch .utils .data  import  DataLoader , random_split 
22+ from  torch .utils .data  import  DataLoader , random_split ,  RandomSampler 
2323from  torchmetrics  import  Accuracy 
2424from  torchvision  import  transforms 
2525from  torchvision .datasets  import  MNIST 
@@ -127,7 +127,11 @@ def setup(self, stage=None):
127127            )
128128
129129    def  train_dataloader (self ):
130-         return  DataLoader (self .mnist_train , batch_size = BATCH_SIZE )
130+         return  DataLoader (
131+             self .mnist_train ,
132+             batch_size = BATCH_SIZE ,
133+             sampler = RandomSampler (self .mnist_train , num_samples = 1000 ),
134+         )
131135
132136    def  val_dataloader (self ):
133137        return  DataLoader (self .mnist_val , batch_size = BATCH_SIZE )
@@ -147,10 +151,11 @@ def test_dataloader(self):
147151trainer  =  Trainer (
148152    accelerator = "auto" ,
149153    # devices=1 if torch.cuda.is_available() else None,  # limiting got iPython runs 
150-     max_epochs = 5 ,
154+     max_epochs = 3 ,
151155    callbacks = [TQDMProgressBar (refresh_rate = 20 )],
152156    num_nodes = int (os .environ .get ("GROUP_WORLD_SIZE" , 1 )),
153157    devices = int (os .environ .get ("LOCAL_WORLD_SIZE" , 1 )),
158+     replace_sampler_ddp = False ,
154159    strategy = "ddp" ,
155160)
156161
0 commit comments