Skip to content

Commit 47381fb

Browse files
RHOAIENG-3771 - Reduce execution time of E2E tests
1 parent 0beece1 commit 47381fb

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

Diff for: tests/e2e/mnist.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pytorch_lightning.callbacks.progress import TQDMProgressBar
2020
from torch import nn
2121
from torch.nn import functional as F
22-
from torch.utils.data import DataLoader, random_split
22+
from torch.utils.data import DataLoader, random_split, RandomSampler
2323
from torchmetrics import Accuracy
2424
from torchvision import transforms
2525
from torchvision.datasets import MNIST
@@ -127,7 +127,11 @@ def setup(self, stage=None):
127127
)
128128

129129
def train_dataloader(self):
130-
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)
130+
return DataLoader(
131+
self.mnist_train,
132+
batch_size=BATCH_SIZE,
133+
sampler=RandomSampler(self.mnist_train, num_samples=1000),
134+
)
131135

132136
def val_dataloader(self):
133137
return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
@@ -147,10 +151,11 @@ def test_dataloader(self):
147151
trainer = Trainer(
148152
accelerator="auto",
149153
# devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
150-
max_epochs=5,
154+
max_epochs=3,
151155
callbacks=[TQDMProgressBar(refresh_rate=20)],
152156
num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
153157
devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
158+
replace_sampler_ddp=False,
154159
strategy="ddp",
155160
)
156161

0 commit comments

Comments
 (0)