Skip to content

Commit ec091f1

Browse files
committed
Hotfix for DDP tests: decreased lr for stability
1 parent 5d7196c commit ec091f1

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

tests/test_runs/test_ddp_cases/run_retrieval_experiment_ddp.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def check_and_save_ids(self, outputs: List[Any], mode: str) -> None:
153153
torch.save(ids_per_step_synced, pattern.format(experiment=self.exp_num, epoch=self.trainer.current_epoch))
154154

155155
def configure_optimizers(self) -> Any:
156-
return Adam(params=self.parameters(), lr=0.5)
156+
return Adam(params=self.parameters(), lr=1e-3)
157157

158158
def on_train_end(self) -> None:
159159
torch.save(self.model, self.save_path_ckpt_pattern.format(experiment=self.exp_num))

tests/test_runs/test_ddp_cases/test_train_with_metrics.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@
3636

3737
@pytest.mark.long
3838
@pytest.mark.parametrize("batch_size", [12])
39-
@pytest.mark.parametrize("max_epochs", [2])
40-
@pytest.mark.parametrize("num_labels,atol", [(120, 1e-2), (1200, 2e-2)])
39+
@pytest.mark.parametrize("max_epochs", [4])
40+
@pytest.mark.parametrize("num_labels,atol", [(120, 1e-2), (360, 2e-2)])
4141
def test_metrics_is_similar_in_ddp(num_labels: int, atol: float, batch_size: int, max_epochs: int) -> None:
4242
devices = (1, 2, 3)
4343
# We will compare metrics from the same experiment run with different numbers of devices. For this we aggregate
@@ -48,18 +48,19 @@ def test_metrics_is_similar_in_ddp(num_labels: int, atol: float, batch_size: int
4848
metric_topk2values = defaultdict(list)
4949

5050
for num_devices in devices:
51-
batch_size //= num_devices
51+
batch_size_eff = batch_size // num_devices
52+
5253
params = (
5354
f"--devices {num_devices} "
5455
f"--max_epochs {max_epochs} "
5556
f"--num_labels {num_labels} "
56-
f"--batch_size {batch_size}"
57+
f"--batch_size {batch_size_eff}"
5758
)
5859
cmd = f"python {exp_file} " + params
5960
subprocess.run(cmd, check=True, shell=True)
6061

6162
metrics_path = MetricValCallbackWithSaving.save_path_pattern.format(
62-
devices=num_devices, batch_size=batch_size, num_labels=num_labels
63+
devices=num_devices, batch_size=batch_size_eff, num_labels=num_labels
6364
)
6465
metrics = torch.load(metrics_path)[OVERALL_CATEGORIES_KEY]
6566
Path(metrics_path).unlink(missing_ok=True)

0 commit comments

Comments
 (0)