
Commit ceceea5

promote blocksparse from prototype, make it faster (#1734)
This PR promotes block sparsity out of prototype in torchao. Chiefly, it ports the triton addmm blocksparse kernels over from core and makes several performance improvements to them. All numbers reported below are for an H100 with blocksize=64 and sparsity_level=0.9; the default dense baseline is 134 tok/s.

1) Adds padding support to the triton kernel for dense matrices with a dimension < 16, like those we run into during decoding. (214 -> 218 tok/s)
2) Changes the default [num_stages](triton-lang/triton#512) parameter from 1 to 4. This has a large effect on performance, and the default kernel autotuning appears to either leave this parameter untouched or deem it unimportant. (218 -> 263 tok/s)
3) Adds an environment variable, BSR_AUTOTUNE, that users can set to run kernel autotuning on top of the default parameters. (263 -> 266 tok/s) This matters more for bs=n compute-bound workloads, where I see a reduction from 0.3855 s to 0.3745 s on bs=8192 prefill (roughly 3%).

In total we are seeing a **1.985x** speedup over the dense baseline 🚀

I've also updated the documentation to no longer reference prototype; I'm planning on updating the diagram in a subsequent PR.

### Testing

I added a new test case for the padded inputs and moved the test file out of prototype:

```
python test/sparsity/test_sparse_api.py
```
1 parent ed16fe7 commit ceceea5
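For orientation, here is a minimal, hedged usage sketch of the promoted API. `sparsify_` and `block_sparse_weight` appear in the diffs below, and `BSR_AUTOTUNE` is the env var added by this PR; the toy model, the env var value, and setting it before import are illustrative assumptions rather than part of the PR.

```python
import os

# BSR_AUTOTUNE is the autotuning env var added in this PR; the value "1" and setting
# it before importing torchao are assumptions about how it is consumed.
os.environ["BSR_AUTOTUNE"] = "1"

import torch
from torch import nn
from torchao.sparsity import block_sparse_weight, sparsify_

# Toy stand-in for the benchmarked llama FFN layers (fp16 on CUDA, as in the test).
model = nn.Sequential(nn.Linear(1024, 2048), nn.Linear(2048, 1024)).half().cuda()

# Swap dense nn.Linear weights for block-sparse (BSR) weights backed by the
# promoted triton addmm kernels.
sparsify_(model, block_sparse_weight(blocksize=64))

# Prefill-shaped input; the test diff below also covers the decode-shaped (1, 1024) case.
out = model(torch.rand(1024, 1024, dtype=torch.float16, device="cuda"))
```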

File tree: 9 files changed (+843 −43 lines)


test/prototype/test_sparse_api.py renamed to test/sparsity/test_sparse_api.py

Lines changed: 4 additions & 5 deletions
```diff
@@ -132,8 +132,9 @@ class TestBlockSparseWeight(common_utils.TestCase):
         )
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     @common_utils.parametrize("compile", [True, False])
-    def test_sparse(self, compile):
-        input = torch.rand((1024, 1024)).half().cuda()
+    @common_utils.parametrize("input_shape", [1, 1024])
+    def test_sparse(self, compile, input_shape):
+        input = torch.rand((input_shape, 1024)).half().cuda()
         model = (
             nn.Sequential(
                 nn.Linear(1024, 2048),
@@ -152,9 +153,7 @@ def test_sparse(self, compile):
         model[1].weight.data = create_block_sparse_tensor(M, N, 64, 0.5, torch.float16)
         dense_result = model(input)

-        from torchao.prototype.sparsity.superblock.blocksparse import (
-            block_sparse_weight,
-        )
+        from torchao.sparsity import block_sparse_weight

         sparsify_(model, block_sparse_weight(blocksize=64))
         # if compile:
```
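The new `input_shape=1` case exercises the decode-time padding path from item 1 of the commit message (dense operands with a dimension < 16). A standalone sketch of that shape, assuming the same layer size and blocksize as the test:

```python
import torch
from torch import nn
from torchao.sparsity import block_sparse_weight, sparsify_

# A single block-sparse linear layer is enough to reach the triton addmm kernel.
layer = nn.Sequential(nn.Linear(1024, 2048)).half().cuda()
sparsify_(layer, block_sparse_weight(blocksize=64))

# Single-token activation, as seen during decoding: its leading dimension (1) is < 16,
# so the kernel's new padding path should be taken.
decode_input = torch.rand((1, 1024), dtype=torch.float16, device="cuda")
out = layer(decode_input)
```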

torchao/_models/llama/generate.py

Lines changed: 35 additions & 4 deletions
```diff
@@ -793,9 +793,37 @@ def ffn_or_attn_only(mod, fqn):
         from torchao.sparsity import semi_sparse_weight, sparsify_

         if "semi" in sparsity:
-            # TODO there is a bug here, need to fix
+            # Fixed sparsity level for 2:4
             sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only)

+        if "bsr" in sparsity:
+            from torchao.sparsity import SupermaskLinear, block_sparse_weight
+
+            # parse "bsr-0.9-64"
+            _, sparsity_level, blocksize = sparsity.split("-")
+            sparsity_level, blocksize = float(sparsity_level), int(blocksize)
+            sparsify_(
+                model,
+                lambda x: SupermaskLinear.from_linear(
+                    x,
+                    sparsity_level=sparsity_level,
+                    blocksize=blocksize,
+                ),
+                filter_fn=ffn_only,
+            )
+            print(model)
+            sparsify_(
+                model,
+                SupermaskLinear.to_linear,
+                filter_fn=ffn_only,
+            )
+            print(model)
+
+            # Accelerate with triton bsr kernels
+            sparsify_(
+                model, block_sparse_weight(blocksize=blocksize), filter_fn=ffn_only
+            )
+
     model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9

     if save:
@@ -810,7 +838,10 @@ def ffn_or_attn_only(mod, fqn):
         print("Compiling Model")
         global decode_one_token, prefill
         decode_one_token = torch.compile(
-            decode_one_token, mode="reduce-overhead", fullgraph=True
+            decode_one_token,
+            mode="reduce-overhead",
+            fullgraph=True,
+            dynamic=True,
         )

     if compile_prefill:
@@ -849,7 +880,7 @@ def ffn_or_attn_only(mod, fqn):
         prompt = f"{B_INST} {prompt.strip()} {E_INST}"
         encoded = encode_tokens(tokenizer, prompt, bos=True, device=device)

-        if interactive and i >= 0:
+        if interactive and i >= 0 and prefill_size is None:
             buffer = []
             period_id = tokenizer.encode(".")[0]
             done_generating = False
@@ -919,7 +950,7 @@ def callback(x):
             device_sync(device=device)  # MKG
             t = time.perf_counter() - t0

-        if not interactive and demo_summarize_prompt is None:
+        if not interactive and demo_summarize_prompt is None and prefill_size is None:
             tok_list = y[0].tolist()
             # truncate text after end of string token
             tokens = (
```
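The generate.py change applies block sparsity in three passes: `SupermaskLinear.from_linear` imposes a block-sparse mask at the requested sparsity level, `SupermaskLinear.to_linear` folds the mask back into plain linear weights, and `block_sparse_weight` converts the pruned weights to the accelerated BSR representation. A minimal sketch of that flow on a toy model, using the same calls as the diff; the toy model and the explicit filter in step 2 (standing in for `filter_fn=ffn_only`) are assumptions for illustration.

```python
import torch
from torch import nn
from torchao.sparsity import SupermaskLinear, block_sparse_weight, sparsify_

# Same values as the benchmarked "bsr-0.9-64" configuration.
sparsity_level, blocksize = 0.9, 64

model = nn.Sequential(nn.Linear(1024, 2048), nn.Linear(2048, 1024)).half().cuda()

# 1) Wrap each linear with SupermaskLinear to impose a block-sparse mask.
sparsify_(
    model,
    lambda x: SupermaskLinear.from_linear(
        x, sparsity_level=sparsity_level, blocksize=blocksize
    ),
)

# 2) Fold the mask back into ordinary nn.Linear weights (now containing zeroed blocks).
sparsify_(
    model,
    SupermaskLinear.to_linear,
    filter_fn=lambda mod, fqn: isinstance(mod, SupermaskLinear),
)

# 3) Convert the pruned weights to BSR so matmuls hit the triton kernels.
sparsify_(model, block_sparse_weight(blocksize=blocksize))
```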

torchao/kernel/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,6 +1,8 @@
+from torchao.kernel.bsr_triton_ops import bsr_dense_addmm
 from torchao.kernel.intmm import int_scaled_matmul, safe_int_mm

 __all__ = [
+    "bsr_dense_addmm",
     "safe_int_mm",
     "int_scaled_matmul",
 ]
```
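For completeness, a hedged sketch of calling the newly exported kernel directly. The `(input, bsr, dense)` argument order is an assumption carried over from the core `torch.sparse._triton_ops.bsr_dense_addmm` kernel this file was ported from; most users should go through `block_sparse_weight` instead.

```python
import torch
from torchao.kernel import bsr_dense_addmm

M, K, N, blocksize = 1024, 1024, 2048, 64

# Block-sparse left operand in BSR layout plus a dense right operand (fp16 on CUDA).
bsr = torch.rand(M, K, dtype=torch.float16, device="cuda").to_sparse_bsr(blocksize)
dense = torch.rand(K, N, dtype=torch.float16, device="cuda")
bias = torch.zeros(M, N, dtype=torch.float16, device="cuda")

# Assumed semantics, mirroring torch.addmm: out = bias + bsr @ dense.
out = bsr_dense_addmm(bias, bsr, dense)
```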
