
Commit 1e777f5

iden-kalemaj authored and facebook-github-bot committed
Replace register_backward_hook with register_full_backward_hook (#720)
Summary: register_backward_hook is deprecated in favor of register_full_backward_hook.

Differential Revision: D68562558
1 parent ea4cb95 commit 1e777f5
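
For context, register_full_backward_hook (added in PyTorch 1.8) accepts the same hook signature as the deprecated register_backward_hook, but is guaranteed to fire with gradients w.r.t. all of the module's inputs even when the module's forward spans multiple autograd ops. A minimal sketch of the two calls, not taken from this commit:

import torch
import torch.nn as nn

def hook(module, grad_input, grad_output):
    # grad_output holds gradients w.r.t. the module's outputs; this is
    # what Opacus-style per-sample gradient capture hooks consume.
    print(type(module).__name__, [g.shape for g in grad_output if g is not None])

module = nn.Linear(4, 2)
# Deprecated; may report incorrect grad_input on modules built from
# multiple autograd ops, and emits a deprecation warning:
#   handle = module.register_backward_hook(hook)
# Replacement used by this commit:
handle = module.register_full_backward_hook(hook)

module(torch.randn(3, 4)).sum().backward()
handle.remove()  # both variants return a removable handle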

File tree

.github/workflows/ci_gpu.yml
opacus/grad_sample/grad_sample_module.py
opacus/tests/multigpu_gradcheck.py

3 files changed: +3 additions, −2 deletions


.github/workflows/ci_gpu.yml

Lines changed: 1 addition & 1 deletion

@@ -83,7 +83,7 @@ jobs:
           mkdir -p runs/cifar10/logs
           mkdir -p runs/cifar10/test-reports
           pip install tensorboard
-          python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device cuda
+          python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device cuda --clip_per_layer
           python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
           python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device cuda --grad_sample_mode no_op
           python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
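
The inline gate run after each training command above, unrolled for readability (same logic as the python -c one-liner):

import sys
import torch

# Load the checkpoint written by examples/cifar10.py and fail the CI job
# unless top-1 accuracy lands in the expected band for this DP setup.
ckpt = torch.load("model_best.pth.tar")
sys.exit(0 if 0.4 < ckpt["best_acc1"] < 0.49 else 1)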

opacus/grad_sample/grad_sample_module.py

Lines changed: 1 addition & 1 deletion

@@ -207,7 +207,7 @@ def add_hooks(
             )
 
             self.autograd_grad_sample_hooks.append(
-                module.register_backward_hook(
+                module.register_full_backward_hook(
                     partial(
                         self.capture_backprops_hook,
                         loss_reduction=loss_reduction,
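
The hook body is untouched; the commit only swaps the registration call. A rough sketch of the registration pattern the diff uses, with capture_backprops_hook reduced to a hypothetical stand-in for the real GradSampleModule method:

from functools import partial

import torch
import torch.nn as nn

def capture_backprops_hook(module, grad_input, grad_output, loss_reduction):
    # Hypothetical stand-in: stash the backprops needed for per-sample
    # gradient computation on the module itself.
    backprops = grad_output[0].detach()
    if loss_reduction == "mean":
        # Undo the 1/batch_size scaling a mean-reduced loss applies.
        backprops = backprops * backprops.shape[0]
    module._captured_backprops = backprops

layer = nn.Linear(4, 2)
# As in the diff: partial() pins the extra keyword argument so the hook
# matches the (module, grad_input, grad_output) signature PyTorch expects.
handle = layer.register_full_backward_hook(
    partial(capture_backprops_hook, loss_reduction="mean")
)

layer(torch.randn(8, 4)).mean().backward()
print(layer._captured_backprops.shape)  # torch.Size([8, 2])
handle.remove()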

opacus/tests/multigpu_gradcheck.py

Lines changed: 1 addition & 0 deletions

@@ -206,6 +206,7 @@ def test_gradient_correct(self) -> None:
         )
         clipping_grad_sample_pairs.append(("ghost", "ghost"))
 
+        clipping_grad_sample_pairs = [("per_layer", "hooks")]
         for clipping, grad_sample_mode in clipping_grad_sample_pairs:
 
             weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)

0 commit comments