@@ -3322,6 +3322,57 @@ def make_post_acc_grad_hook(id):
3322
3322
3323
3323
self .check_output_and_recompiles (fn )
3324
3324
3325
+ def test_sac (self ):
3326
+ # circular import
3327
+ from torch .utils .checkpoint import (
3328
+ checkpoint ,
3329
+ CheckpointPolicy ,
3330
+ create_selective_checkpoint_contexts ,
3331
+ )
3332
+
3333
+ def fn ():
3334
+ class mlp (nn .Module ):
3335
+ def __init__ (self ):
3336
+ super ().__init__ ()
3337
+ self .layer1 = nn .Linear (10 , 10 )
3338
+ self .layer2 = nn .Linear (10 , 10 )
3339
+ self .layer3 = nn .Linear (10 , 10 )
3340
+ self .layer4 = nn .Linear (10 , 10 )
3341
+
3342
+ def forward (self , x ):
3343
+ x = self .layer1 (x )
3344
+ x = self .layer2 (x )
3345
+ x = self .layer3 (x )
3346
+ x = self .layer4 (x )
3347
+ return x
3348
+
3349
+ recompute_list = [torch .ops .aten .addmm .default ]
3350
+
3351
+ def recompute_policy (ctx , op , * args , ** kwargs ):
3352
+ if op in recompute_list :
3353
+ return CheckpointPolicy .MUST_RECOMPUTE
3354
+ else :
3355
+ return CheckpointPolicy .PREFER_SAVE
3356
+
3357
+ def context_fn ():
3358
+ return create_selective_checkpoint_contexts (recompute_policy )
3359
+
3360
+ model = mlp ()
3361
+ input = torch .randn (1 , 10 )
3362
+
3363
+ out = checkpoint (model , input , use_reentrant = False , context_fn = context_fn )
3364
+ out .sum ().backward ()
3365
+ yield model .layer1 .weight .grad
3366
+ yield model .layer1 .bias .grad
3367
+ yield model .layer2 .weight .grad
3368
+ yield model .layer2 .bias .grad
3369
+ yield model .layer3 .weight .grad
3370
+ yield model .layer3 .bias .grad
3371
+ yield model .layer4 .weight .grad
3372
+ yield model .layer4 .bias .grad
3373
+
3374
+ self .check_output_and_recompiles (fn )
3375
+
3325
3376
3326
3377
def load_test_module (name ):
3327
3378
testdir = Path (__file__ ).absolute ().parent .parent
0 commit comments