
Commit 90df7c0

janeyx99 authored and pytorchmergebot committed
Migrate state_dict bc test to OptimizerInfo, increase coverage (pytorch#116500)
Pull Request resolved: pytorch#116500
Approved by: https://github.com/albanD
1 parent 19e93b8 commit 90df7c0

3 files changed: +68 −22 lines

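The behavior exercised by this commit is backward compatibility of optimizer state dicts: a checkpoint saved before per-param-group flags such as maximize, foreach, fused, differentiable, or capturable existed must still load and step cleanly. A minimal standalone sketch of that pattern, using SGD and a toy Linear module purely for illustration (this is not code from the commit):

from copy import deepcopy

import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Simulate a checkpoint from an older PyTorch release by dropping the newer
# per-param-group flags, if present.
old_style = deepcopy(opt.state_dict())
for group in old_style["param_groups"]:
    for flag in ("maximize", "foreach", "fused", "differentiable", "capturable"):
        group.pop(flag, None)

# Loading the stripped state_dict falls back to the optimizer's defaults,
# and stepping afterwards should still work.
opt.load_state_dict(old_style)
model(torch.randn(8, 4)).sum().backward()
opt.step()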

test/optim/test_optim.py (-22 lines)

@@ -247,28 +247,6 @@ def fn_base(optimizer, weight, bias):
         for _i in range(20):
             optimizer.step(fn)
 
-        # Make sure that optimizers that support maximize can load older models
-        old_state_dict = deepcopy(optimizer.state_dict())
-        state_dict_no_maximize = deepcopy(optimizer.state_dict())
-        if "maximize" in state_dict_no_maximize["param_groups"][0]:
-            for group in state_dict_no_maximize["param_groups"]:
-                del group["maximize"]
-            optimizer.load_state_dict(state_dict_no_maximize)
-            # Make sure we can still step
-            optimizer.step()
-            # Undo these changes before proceeding!
-            optimizer.load_state_dict(old_state_dict)
-        # Make sure that optimizers that support foreach can load older models
-        state_dict_no_foreach = deepcopy(optimizer.state_dict())
-        if "foreach" in state_dict_no_foreach["param_groups"][0]:
-            for group in state_dict_no_foreach["param_groups"]:
-                del group["foreach"]
-            optimizer.load_state_dict(state_dict_no_foreach)
-            # Make sure we can still step
-            optimizer.step()
-            # Undo these changes before proceeding!
-            optimizer.load_state_dict(old_state_dict)
-
         # Make sure that loading optimizers with step not wrapped in tensor can work
         state_dict = optimizer.state_dict()
         if "step" in state_dict["state"][0] and torch.is_tensor(

test/test_optim.py (+47 lines)

@@ -448,6 +448,53 @@ def fwd_bwd(optim, w, b, i):
             optimizer_c.state_dict()["param_groups"][-1]
         )
 
+    @optims(optim_db, dtypes=[torch.float32])
+    def test_can_load_older_state_dict(self, device, dtype, optim_info):
+        new_flags = ["maximize", "foreach", "fused", "differentiable", "capturable"]
+        optim_cls = optim_info.optim_cls
+
+        # Skip differentiable testing for now, see https://github.com/pytorch/pytorch/issues/116490
+        all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(device, dtype, optim_info, skip=("differentiable",))
+        for optim_input in all_optim_inputs:
+            torch.manual_seed(1)
+            model = torch.nn.Sequential(
+                torch.nn.Conv2d(4, 2, 1, stride=2),
+                torch.nn.BatchNorm2d(2, eps=1e-05, momentum=0.1),
+            )
+            model.to(dtype=dtype, device=device)
+            input = torch.rand(1, 4, 16, 16, device=device, dtype=dtype)
+            optimizer = optim_cls(model.parameters(), **optim_input.kwargs)
+
+            def fwd_bwd(optim, mod, i):
+                optim.zero_grad()
+                loss = mod(i).sum()
+                loss.backward()
+                return loss
+
+            for _ in range(3):
+                if optim_cls.__name__ == "LBFGS":
+                    optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
+                else:
+                    fwd_bwd(optimizer, model, input)
+                    optimizer.step()
+
+            # old_state_dict has all new flags del'd
+            old_state_dict = deepcopy(optimizer.state_dict())
+            old_state_dict_pg = old_state_dict["param_groups"]
+            for group in old_state_dict_pg:
+                for flag in new_flags:
+                    if flag in group:
+                        del group[flag]
+
+            optimizer.load_state_dict(old_state_dict)
+
+            # Make sure we can still step
+            if optim_cls.__name__ == "LBFGS":
+                optimizer.step(functools.partial(fwd_bwd, optimizer, model, input))
+            else:
+                fwd_bwd(optimizer, model, input)
+                optimizer.step()
+
 
 instantiate_device_type_tests(TestOptimRenewed, globals(), allow_mps=True)
 
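To run just this new test locally, an invocation along these lines should work (the instantiated test names carry device/dtype suffixes added by instantiate_device_type_tests, so this is a best-guess command rather than one taken from the commit):

python test/test_optim.py -k test_can_load_older_state_dict

Running the same selection through pytest with -k should behave equivalently.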

torch/testing/_internal/common_optimizers.py (+21 lines)

@@ -1110,6 +1110,12 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
         step_requires_closure=True,
         supports_param_groups=False,
         supports_multiple_devices=False,
+        skips=(
+            # Fails on MacOS 13.2.1 in CI https://github.com/pytorch/pytorch/issues/117094
+            DecorateInfo(
+                skipIfMps, "TestOptimRenewed", "test_can_load_older_state_dict"
+            ),
+        ),
     ),
     OptimizerInfo(
         NAdam,
@@ -1138,6 +1144,14 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
                 "TestOptimRenewed",
                 "test_state_dict_deterministic",
             ),
+            DecorateInfo(
+                skipIfTorchDynamo(
+                    "See https://github.com/pytorch/pytorch/issues/116499"
+                ),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+                device_type="cuda",
+            ),
         ),
     ),
     OptimizerInfo(
@@ -1310,6 +1324,13 @@ def _get_optim_inputs_including_global_cliquey_kwargs(
                 "TestOptimRenewed",
                 "test_state_dict_deterministic",
             ),
+            DecorateInfo(
+                unittest.skip(
+                    "SparseAdam does not support dense gradients, see #116507"
+                ),
+                "TestOptimRenewed",
+                "test_can_load_older_state_dict",
+            ),
         ),
     ),
 ]
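For readers less familiar with the OptimizerInfo machinery: each DecorateInfo above attaches its decorator to the named test, but only for the optimizer whose entry declares it. Informally (this equivalence is a sketch for intuition, not code from the commit), the SparseAdam entry behaves as though its instantiation of the test had been written as:

import unittest

@unittest.skip("SparseAdam does not support dense gradients, see #116507")
def test_can_load_older_state_dict(self, device, dtype, optim_info):
    ...  # the decorated test is reported as skipped, so its body never runs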
