Skip to content

Commit 7a9b5bb

Browse files
RHOAIENG-3771 - Reduce execution time of E2E tests
By reducing number of epochs and number of training samples in each epoch it was possible to reduce test execution time from more than 10 minutes to less than 2 minutes.
1 parent 8aeaefc commit 7a9b5bb

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

test/e2e/mnist.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytorch_lightning.callbacks.progress import TQDMProgressBar
2121
from torch import nn
2222
from torch.nn import functional as F
23-
from torch.utils.data import DataLoader, random_split
23+
from torch.utils.data import DataLoader, random_split, RandomSampler
2424
from torchmetrics import Accuracy
2525
from torchvision import transforms
2626
from torchvision.datasets import MNIST
@@ -158,7 +158,7 @@ def setup(self, stage=None):
158158
)
159159

160160
def train_dataloader(self):
161-
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)
161+
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000))
162162

163163
def val_dataloader(self):
164164
return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
@@ -178,10 +178,11 @@ def test_dataloader(self):
178178
trainer = Trainer(
179179
accelerator="auto",
180180
# devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
181-
max_epochs=5,
181+
max_epochs=3,
182182
callbacks=[TQDMProgressBar(refresh_rate=20)],
183183
num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
184184
devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
185+
replace_sampler_ddp=False,
185186
strategy="ddp",
186187
)
187188

test/odh/resources/mnist.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from pytorch_lightning.callbacks.progress import TQDMProgressBar
2121
from torch import nn
2222
from torch.nn import functional as F
23-
from torch.utils.data import DataLoader, random_split
23+
from torch.utils.data import DataLoader, random_split, RandomSampler
2424
from torchmetrics import Accuracy
2525
from torchvision import transforms
2626
from torchvision.datasets import MNIST
@@ -158,7 +158,7 @@ def setup(self, stage=None):
158158
)
159159

160160
def train_dataloader(self):
161-
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)
161+
return DataLoader(self.mnist_train, batch_size=BATCH_SIZE, sampler=RandomSampler(self.mnist_train, num_samples=1000))
162162

163163
def val_dataloader(self):
164164
return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)
@@ -178,10 +178,11 @@ def test_dataloader(self):
178178
trainer = Trainer(
179179
accelerator="auto",
180180
# devices=1 if torch.cuda.is_available() else None, # limiting got iPython runs
181-
max_epochs=2,
181+
max_epochs=3,
182182
callbacks=[TQDMProgressBar(refresh_rate=20)],
183183
num_nodes=int(os.environ.get("GROUP_WORLD_SIZE", 1)),
184184
devices=int(os.environ.get("LOCAL_WORLD_SIZE", 1)),
185+
replace_sampler_ddp=False,
185186
strategy="ddp",
186187
)
187188

0 commit comments

Comments
 (0)