
2:4 activation sparsity packing kernels #2012


Merged: 33 commits, May 12, 2025

Commits
f62745f
wip to get sample op working
jcaip Mar 24, 2025
02b65de
test
jcaip Mar 26, 2025
cf503aa
wip
jcaip Mar 31, 2025
5a5ca43
update kernel for metadat
jcaip Mar 31, 2025
9cbfed0
wip
jcaip Apr 3, 2025
268b74d
almost working!
jcaip Apr 14, 2025
1a933f9
working but not on random inputs
jcaip Apr 14, 2025
f3b67b0
cleaned up cuda files
jcaip Apr 14, 2025
9f4346d
packing is working now!
jcaip Apr 15, 2025
0b62fba
update
jcaip Apr 15, 2025
f505851
Merge branch 'main' into jcaip/actiation24
jcaip Apr 15, 2025
e8a5d5b
wip
jcaip Apr 16, 2025
b168814
updated to not sort 1x16 at a time
jcaip Apr 16, 2025
559ca90
checkpoint
jcaip Apr 16, 2025
91f18da
removed a lot of templating to try and merge index creation and packing
jcaip Apr 16, 2025
a5b13a0
srelu speedups
jcaip Apr 18, 2025
6957add
wip integrating xformers kernels
jcaip Apr 18, 2025
0bf05ec
update namespare
jcaip Apr 19, 2025
4e91722
remove extra CUTLASS files
jcaip Apr 19, 2025
65d3c1e
more cleanup
jcaip Apr 19, 2025
88bec35
clean up op registration
jcaip Apr 19, 2025
3d4aa93
added test for srelu linear
jcaip Apr 22, 2025
a5b9cab
cleanup
jcaip Apr 22, 2025
645576b
updated benchmarks + cleaned up prototype folder some more
jcaip Apr 22, 2025
35263ae
added ruff
jcaip Apr 22, 2025
a31dcd5
fixed setup
jcaip Apr 22, 2025
7cdd43a
Merge branch 'main' into jcaip/actiation24
jcaip Apr 29, 2025
9ff58a4
fix ruff for utils
jcaip May 12, 2025
f1e9eb1
ruff fix
jcaip May 12, 2025
3172f09
Merge remote-tracking branch 'refs/remotes/origin/jcaip/actiation24' …
jcaip May 12, 2025
46b19e8
Merge branch 'main' into jcaip/actiation24
jcaip May 12, 2025
5b99cd8
fix ruff
jcaip May 12, 2025
126166f
ruff format
jcaip May 12, 2025
133 changes: 133 additions & 0 deletions benchmarks/benchmark_e2e_fp8_sparse_linear.py
@@ -0,0 +1,133 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import pandas as pd
import torch
from torch import nn
from tqdm import tqdm
from triton.testing import do_bench

from torchao.prototype.sparsity.activation.srelu_linear import (
SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig,
)
from torchao.prototype.sparsity.activation.utils import SquaredReLU
from torchao.quantization import (
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
Float8MMConfig,
PerRow,
quantize_,
)


def benchmark_microseconds(f, *args):
return do_bench(lambda: f(*args), return_mode="median") * 1e3


def benchmark(num_tokens, hidden_size=8192, intermediate_size=8192):
ffn_ref = (
nn.Sequential(
nn.Linear(hidden_size, intermediate_size, bias=False),
SquaredReLU(),
nn.Linear(intermediate_size, hidden_size, bias=False),
)
.to(torch.bfloat16)
.cuda()
)

input_tensor = torch.randn(num_tokens, hidden_size).to(torch.bfloat16).cuda()
fp16_time = benchmark_microseconds(ffn_ref, input_tensor)

# bf16
ffn_clone = (
nn.Sequential(
nn.Linear(hidden_size, intermediate_size, bias=False),
SquaredReLU(),
nn.Linear(intermediate_size, hidden_size, bias=False),
)
.to(torch.bfloat16)
.cuda()
)
ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
fp16_c_time = benchmark_microseconds(ffn_clone, input_tensor)

# fp8
ffn_clone = (
nn.Sequential(
nn.Linear(hidden_size, intermediate_size, bias=False),
SquaredReLU(),
nn.Linear(intermediate_size, hidden_size, bias=False),
)
.to(torch.bfloat16)
.cuda()
)
quantize_(
ffn_clone,
Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
),
)
ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
fp8_c_time = benchmark_microseconds(ffn_clone, input_tensor)

# fp8 sparse
ffn_clone = (
nn.Sequential(
nn.Linear(hidden_size, intermediate_size, bias=False),
SquaredReLU(),
nn.Linear(intermediate_size, hidden_size, bias=False),
)
.to(torch.bfloat16)
.cuda()
)
quantize_(ffn_clone, Float8DynamicActivationFloat8SemiSparseWeightConfig())
ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
fp8_c_sparse_time = benchmark_microseconds(ffn_clone, input_tensor)

# activation fp8 sparse
ffn_clone = (
nn.Sequential(
nn.Linear(hidden_size, intermediate_size, bias=False),
# no Squared RELU since it will be fused into the second linear
nn.Linear(intermediate_size, hidden_size, bias=False),
)
.to(torch.bfloat16)
.cuda()
)
quantize_(
ffn_clone[0],
Float8DynamicActivationFloat8WeightConfig(
granularity=PerRow(), mm_config=Float8MMConfig(use_fast_accum=True)
),
)
quantize_(
ffn_clone,
SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(),
filter_fn=lambda mod, fqn: "1" in fqn,
)
ffn_clone.forward = torch.compile(ffn_clone.forward, fullgraph=True)
fp8_c_activation_sparse_time = benchmark_microseconds(ffn_clone, input_tensor)

return {
"num_tokens": num_tokens,
"bf16_latency (us)": fp16_time,
"bf16_c_latency (us)": fp16_c_time,
"fp8_c_time (us)": fp8_c_time,
"fp8_c_sparse_time (us)": fp8_c_sparse_time,
"fp8_c_activation_sparse_time (us)": fp8_c_activation_sparse_time,
"speedup": fp8_c_time / fp8_c_activation_sparse_time,
}


if __name__ == "__main__":
with torch.no_grad():
results = []
for num_tokens in tqdm([64, 128, 256, 512, 1024, 2048, 4096]):
results.append(benchmark(num_tokens))
torch.compiler.reset()

df = pd.DataFrame(results)
df.to_csv("e2e_fp8_sparse.csv", index=False)
print(df.to_markdown(index=False))
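Note: the last configuration above is the path added by this PR. A minimal sketch of that path in isolation, assuming a CUDA GPU and a torchao build that includes these kernels (imports, config name, and filter_fn are the same as in the benchmark; torch.compile and the first-layer FP8 config are omitted here):

```python
import torch
from torch import nn

from torchao.prototype.sparsity.activation.srelu_linear import (
    SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig,
)
from torchao.quantization import quantize_

# Two-layer FFN with no explicit SquaredReLU module: the squared ReLU and the
# 2:4 sparsification of its output are fused into the second linear.
ffn = (
    nn.Sequential(
        nn.Linear(8192, 8192, bias=False),
        nn.Linear(8192, 8192, bias=False),
    )
    .to(torch.bfloat16)
    .cuda()
)
quantize_(
    ffn,
    SRELUFloat8SemiSparseDynamicActivationFloat8WeightConfig(),
    filter_fn=lambda mod, fqn: "1" in fqn,  # only the second projection
)

x = torch.randn(2048, 8192, dtype=torch.bfloat16, device="cuda")
with torch.no_grad():
    y = ffn(x)
```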
54 changes: 46 additions & 8 deletions benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py
@@ -16,7 +16,7 @@
from torchao.sparsity.utils import create_semi_structured_tensor

dtype = torch.bfloat16
-dtypeq_X = torch.float8_e5m2
+dtypeq_X = torch.float8_e4m3fn
dtypeq_W = torch.float8_e4m3fn
device = torch.device("cuda")

@@ -25,7 +25,7 @@ def benchmark_microseconds(f, *args):
return do_bench(lambda: f(*args), return_mode="median") * 1e3


-def get_problem(m: int, n: int, k: int):
+def get_problem_cutlass(m: int, n: int, k: int):
X_ref = torch.randn((m, k), dtype=dtype, device=device)
W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

@@ -45,30 +45,68 @@ def get_problem(m: int, n: int, k: int):
return (X_ref, W_ref), (Xq, X_scale, Wq_sparse, W_meta, W_scale, bias, out_dtype)


def get_problem_cusparselt(m: int, n: int, k: int):
X_ref = torch.randn((m, k), dtype=dtype, device=device)
W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

Xq = X_ref.to(dtypeq_W)
Wq = W_ref.to(dtypeq_W)

Wqs = torch._cslt_compress(Wq)

alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(
Wqs, Xq.t(), None, None, None, False
)

return (Wqs, Xq.t(), None, None, dtype, False, alg_id, split_k, split_k_one_kernel)


def get_problem_scaled_mm(m: int, n: int, k: int):
X_ref = torch.randn((m, k), dtype=dtype, device=device)
W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

X_aqt = _float8_cutlass_quant(X_ref, dtypeq_W)
W_aqt = _float8_cutlass_quant(W_ref, dtypeq_W)

Xq = X_aqt.tensor_impl.float8_data
Wq = W_aqt.tensor_impl.float8_data
X_scale = X_aqt.tensor_impl.scale.unsqueeze(0)
W_scale = W_aqt.tensor_impl.scale.unsqueeze(-1)

return (Wq, Xq.t(), W_scale, X_scale, None, None, dtype)


def benchmark(m: int, k: int, n: int):
-ref_args, args = get_problem(m, n, k)
+ref_args, args = get_problem_cutlass(m, n, k)
fp16_time = benchmark_microseconds(torch.nn.functional.linear, *ref_args)
rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
rowwise_scaled_linear_sparse_cutlass_f8f8, *args
)

cslt_args = get_problem_cusparselt(m, n, k)
cusparselt_time = benchmark_microseconds(torch._cslt_sparse_mm, *cslt_args)

fp8_args = get_problem_scaled_mm(m, n, k)
fp8_time = benchmark_microseconds(torch._scaled_mm, *fp8_args)

return {
"m": m,
"k": k,
"n": n,
"fp16_latency (ms)": fp16_time,
"fp8_latency (ms)": fp8_time,
"rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
"f8f8 speedup (d/s)": fp16_time
/ rowwise_scaled_linear_sparse_cutlass_f8f8_time,
"cusparselt latency (ms)": cusparselt_time,
"f8f8 speedup (d/s)": fp8_time / rowwise_scaled_linear_sparse_cutlass_f8f8_time,
}


if __name__ == "__main__":
-k_vals = (8192, 8192, 8192, 28672)
-n_vals = (8192, 10240, 57344, 8192)
+k_vals = (8192,)
+n_vals = (8192,)

results = []
-for m in tqdm([1 << i for i in range(10)]):
+for m in tqdm([2048, 4096, 8192]):
for n, k in zip(n_vals, k_vals):
results.append(benchmark(m, k, n))

124 changes: 124 additions & 0 deletions benchmarks/benchmark_sparse_conversion_cutlass.py
@@ -0,0 +1,124 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import pandas as pd
import torch
from triton.testing import do_bench

from torchao.ops import (
to_sparse_semi_structured_cutlass_sm9x_f8,
)
from torchao.quantization.quant_api import (
_float8_cutlass_quant,
_float8_cutlass_quant_sparse,
)
from torchao.sparsity.utils import create_semi_structured_tensor

dtype = torch.bfloat16
dtypeq_X = torch.float8_e4m3fn
dtypeq_W = torch.float8_e4m3fn
device = torch.device("cuda")


def benchmark_microseconds(f, *args):
return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem_cutlass(m: int, n: int, k: int):
X_ref = torch.randn((m, k), dtype=dtype, device=device)
W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

X_quant_func = _float8_cutlass_quant
W_quant_func = _float8_cutlass_quant_sparse
X_aqt = X_quant_func(X_ref, dtypeq_X)
W_aqt = W_quant_func(W_ref, dtypeq_W)

Xq = X_aqt.tensor_impl.float8_data
X_scale = X_aqt.tensor_impl.scale
Wq_sparse = W_aqt.tensor_impl.sparse
W_meta = W_aqt.tensor_impl.meta
W_scale = W_aqt.tensor_impl.scale
bias = None
out_dtype = dtype

return (X_ref, W_ref), (Xq, X_scale, Wq_sparse, W_meta, W_scale, bias, out_dtype)


def get_problem_cusparselt(m: int, n: int, k: int):
X_ref = torch.randn((m, k), dtype=dtype, device=device)
W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

Xq = X_ref.to(dtypeq_W)
Wq = W_ref.to(dtypeq_W)

Wqs = torch._cslt_compress(Wq)

alg_id, split_k, split_k_one_kernel, _ = torch._C._cusparselt.mm_search(
Wqs, Xq.t(), None, None, None, False
)

return (Wqs, Xq.t(), None, None, dtype, False, alg_id, split_k, split_k_one_kernel)


def get_problem_scaled_mm(m: int, n: int, k: int):
X_ref = torch.randn((m, k), dtype=dtype, device=device)
W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

X_aqt = _float8_cutlass_quant(X_ref, dtypeq_W)
W_aqt = _float8_cutlass_quant(W_ref, dtypeq_W)

Xq = X_aqt.tensor_impl.float8_data
Wq = W_aqt.tensor_impl.float8_data
X_scale = X_aqt.tensor_impl.scale.unsqueeze(0)
W_scale = W_aqt.tensor_impl.scale.unsqueeze(-1)

return (Wq, Xq.t(), W_scale, X_scale, None, None, dtype)


def benchmark(m, k):
torch.manual_seed(123)
W_ref = create_semi_structured_tensor(m, k, dtype=torch.float8_e4m3fn).cuda()

# packed, meta = torch.ops.torchao.sparse_semi_structured_tile.default(W_ref, "", True)
cutlass_reference_args = (W_ref,)
cutlass_custom_args = (W_ref, "", True)

cutlass_reference_compression_time = benchmark_microseconds(
to_sparse_semi_structured_cutlass_sm9x_f8, *cutlass_reference_args
)
cutlass_custom_compression_time = benchmark_microseconds(
torch.ops.torchao.sparse_semi_structured_tile.default, *cutlass_custom_args
)

return {
"cutlass_reference (ms)": cutlass_reference_compression_time,
"cutlass_custom (ms)": cutlass_custom_compression_time,
}


def profile():
torch.manual_seed(123)
W_ref = create_semi_structured_tensor(8192, 8192, dtype=torch.float8_e4m3fn).cuda()

# clear cache
new_val = torch.empty(10000, 10000, device="cuda")
new_val[:, :] = 0

packed, meta = torch.ops.torchao.sparse_semi_structured_tile.default(
W_ref, "", True
)


if __name__ == "__main__":
results = []
for m in (2048, 4096, 8192):
results.append(benchmark(m, 8192))

df = pd.DataFrame(results)
df.to_csv("rowwise_scaled_linear_sparse_cutlass_time_results.csv", index=False)
print(df.to_markdown(index=False))

# print("PROFILING")
# profile()
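Note: a minimal sketch of calling the two compression paths benchmarked above directly, assuming the ops built by this PR are available; the call signatures mirror benchmark() and profile() in this file:

```python
import torch

from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8
from torchao.sparsity.utils import create_semi_structured_tensor

torch.manual_seed(123)
W = create_semi_structured_tensor(8192, 8192, dtype=torch.float8_e4m3fn).cuda()

# Reference CUTLASS (sm90+) compression of an FP8 2:4-sparse weight.
ref_out = to_sparse_semi_structured_cutlass_sm9x_f8(W)

# New packing kernel from this PR, called the same way as in profile() above.
packed, meta = torch.ops.torchao.sparse_semi_structured_tile.default(W, "", True)
```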
8 changes: 8 additions & 0 deletions e2e_fp8_sparse.csv
@@ -0,0 +1,8 @@
num_tokens,bf16_latency (us),bf16_c_latency (us),fp8_c_time (us),fp8_c_sparse_time (us),fp8_c_activation_sparse_time (us),speedup
64,166.81599617004395,163.03999722003937,103.00800204277039,74.30399954319,102.81600058078766,1.0018674278409796
128,156.25600516796112,151.5199989080429,99.93600100278854,75.45600086450577,102.04800218343735,0.9793038458817415
256,172.28800058364868,159.58400070667267,114.07999694347382,82.43200182914734,111.07199639081955,1.0270815385551393
512,218.87999773025513,204.6079933643341,144.0960019826889,114.56000059843063,139.48799669742584,1.0330351384661336
1024,394.4000005722046,392.5440013408661,251.10399723052979,196.4160054922104,227.90400683879852,1.1017972027501084
2048,764.6080255508423,734.8160147666931,480.70400953292847,381.1520040035248,426.68798565864563,1.1265937305239622
4096,1658.8159799575806,1623.5840320587158,901.3440012931824,779.0079712867737,843.392014503479,1.0687129896811043
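As a sanity check, the speedup column is fp8_c_time / fp8_c_activation_sparse_time from the benchmark above; for the 2048-token row, 480.704 / 426.688 ≈ 1.1266, matching the reported value. A small snippet to recompute it from the CSV written by the benchmark script:

```python
import pandas as pd

df = pd.read_csv("e2e_fp8_sparse.csv")
# Recompute the speedup from the raw timings; any differences are rounding only.
recomputed = df["fp8_c_time (us)"] / df["fp8_c_activation_sparse_time (us)"]
print((recomputed - df["speedup"]).abs().max())
```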
4 changes: 4 additions & 0 deletions rowwise_scaled_linear_sparse_cutlass_time_results.csv
@@ -0,0 +1,4 @@
m,k,n,fp16_latency (ms),fp8_latency (ms),rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms),cusparselt latency (ms),f8f8 speedup (d/s)
2048,8192,8192,345.7919955253601,243.13600361347198,159.7760021686554,634.2080235481262,1.5217304245528933
4096,8192,8192,756.3199996948242,500.2880096435547,363.647997379303,628.7999749183655,1.3757480124982768
8192,8192,8192,1433.568000793457,982.5279712677002,895.3920006752014,859.935998916626,1.0973160029649482
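As with the CSV above, the "f8f8 speedup (d/s)" column here follows the updated benchmark(): fp8_latency divided by the rowwise_scaled_linear_sparse_cutlass_f8f8 latency. For the m=2048 row, 243.136 / 159.776 ≈ 1.522, matching the table (all latencies are do_bench medians scaled by 1e3 in benchmark_microseconds).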
1 change: 1 addition & 0 deletions setup.py
@@ -382,6 +382,7 @@ def get_extensions():
"to_sparse_semi_structured_cutlass_sm9x",
"to_sparse_semi_structured_cutlass_sm9x_f8.cu",
),
+os.path.join(extensions_cuda_dir, "activation24", "sparsify24.cu"),
]
for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
cutlass_90a_sources.append(