
Commit 79ac44e

Promote Supermask out of prototype (#1729)
This PR promotes Supermask and block sparsity from prototype into `torchao.sparsity`, replacing the `apply_supermask` function, which was previously tightly coupled with SAM. It adds a new public API, `SupermaskLinear`, which users can apply to add Supermask to their models for training:

```
sparsify_(model, lambda x: SupermaskLinear.from_linear(x, blocksize=64, sparsity_level=0.9))
```

To accelerate inference, we convert the `SupermaskLinear` modules back into `nn.Linear`, which simplifies the Supermask logic:

```
sparsify_(model, SupermaskLinear.to_linear)
```

**bc-breaking** The previous prototype APIs, `torchao.sparsity.prototype.superblock.supermask` and `torchao.prototype.sparsity.superblock.supermask`, have been deprecated. Use `torchao.sparsity.supermask` instead.
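For reference, a minimal end-to-end sketch of the intended workflow, assuming a toy single-layer model (the shapes, `sparsity_level`, and `blocksize` values here are illustrative placeholders; the BSR conversion at the end mirrors the test added in this PR):

```
import torch
from torch import nn

from torchao.sparsity import SupermaskLinear, sparsify_

# Toy model; any model containing nn.Linear layers works. Requires CUDA.
model = nn.Sequential(nn.Linear(64, 64, bias=False)).half().cuda()

# Training: swap nn.Linear -> SupermaskLinear so the block mask is learned.
sparsify_(
    model,
    lambda x: SupermaskLinear.from_linear(x, sparsity_level=0.9, blocksize=4),
)
# ... run the usual training loop here ...

# Inference: fold the learned mask back into a plain nn.Linear ...
model.eval()
sparsify_(model, SupermaskLinear.to_linear)

# ... and convert the now-sparse weight to BSR for accelerated inference.
weight_bsr = model[0].weight.to_sparse_bsr(blocksize=4)
```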
1 parent 988c5c9 commit 79ac44e

File tree

5 files changed (+212, -370 lines changed)


test/sparsity/test_supermask.py

Lines changed: 61 additions & 0 deletions
```
import logging
import unittest

import pytest
import torch
from torch import nn
from torch.testing._internal import common_utils

from torchao.sparsity import sparsify_

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestSupermask(common_utils.TestCase):
    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    @common_utils.parametrize("sparsity_level", [0.25, 0.5])
    @common_utils.parametrize("blocksize", [2, 4, 8])
    def test_supermask(self, sparsity_level, blocksize):
        model = (
            nn.Sequential(
                nn.Linear(16, 16, bias=False),
            )
            .half()
            .cuda()
            .eval()
        )

        from torchao.sparsity import SupermaskLinear

        M, N = model[0].weight.shape
        sparsify_(
            model,
            lambda x: SupermaskLinear.from_linear(
                x, sparsity_level=sparsity_level, blocksize=blocksize
            ),
        )
        sparsify_(model, SupermaskLinear.to_linear)
        weight_bsr = model[0].weight.to_sparse_bsr(blocksize=blocksize)

        # Test correct sparsity level
        nnz = weight_bsr._nnz()
        expected = round((M // blocksize) * (N // blocksize) * (1 - sparsity_level))
        assert nnz == expected, f"Expected {expected} nonzeros, got {nnz}"

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
    def test_from_linear(self):
        from torchao.sparsity import SupermaskLinear

        linear = nn.Linear(128, 128)
        supermask_linear = SupermaskLinear.from_linear(
            linear, sparsity_level=0.5, blocksize=4
        )
        assert supermask_linear.weight.shape == linear.weight.shape


common_utils.instantiate_parametrized_tests(TestSupermask)

if __name__ == "__main__":
    unittest.main()
```
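As a concrete instance of the expected-nnz check in `test_supermask`: with `M = N = 16`, `blocksize = 4`, and `sparsity_level = 0.5`, the weight tiles into `(16 // 4) * (16 // 4) = 16` blocks, so the test expects `round(16 * (1 - 0.5)) = 8` nonzero BSR blocks.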
