Skip to content

[WIP] 2:4 activation sparsity #2012

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 46 additions & 7 deletions benchmarks/benchmark_rowwise_scaled_linear_sparse_cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torchao.sparsity.utils import create_semi_structured_tensor

dtype = torch.bfloat16
dtypeq_X = torch.float8_e5m2
dtypeq_X = torch.float8_e4m3fn
dtypeq_W = torch.float8_e4m3fn
device = torch.device("cuda")

Expand All @@ -25,7 +25,7 @@ def benchmark_microseconds(f, *args):
return do_bench(lambda: f(*args), return_mode="median") * 1e3


def get_problem(m: int, n: int, k: int):
def get_problem_cutlass(m: int, n: int, k: int):
X_ref = torch.randn((m, k), dtype=dtype, device=device)
W_ref = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

Expand All @@ -44,31 +44,70 @@ def get_problem(m: int, n: int, k: int):

return (X_ref, W_ref), (Xq, X_scale, Wq_sparse, W_meta, W_scale, bias, out_dtype)

def get_problem_cusparselt(m: int, n: int, k: int):
    """Build the positional arguments used to benchmark ``torch._cslt_sparse_mm``.

    A random ``(m, k)`` activation and a semi-structured weight created with
    ``create_semi_structured_tensor(n, k)`` are both cast to ``dtypeq_W``.
    The weight is compressed with ``torch._cslt_compress`` and
    ``torch._C._cusparselt.mm_search`` is run once up front so that its
    ``alg_id`` / ``split_k`` / ``split_k_one_kernel`` results can be passed
    to the benchmarked call.

    Returns:
        A 9-tuple: compressed weight, transposed quantized activation, two
        ``None`` placeholders, the module-level ``dtype``, ``False``, and the
        three values selected by ``mm_search``.
    """
    activation = torch.randn((m, k), dtype=dtype, device=device)
    weight = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

    # Note: the activation is cast with dtypeq_W as well, not dtypeq_X.
    activation_q_t = activation.to(dtypeq_W).t()
    weight_compressed = torch._cslt_compress(weight.to(dtypeq_W))

    # Only the first three search results are forwarded; the fourth is dropped.
    search_result = torch._C._cusparselt.mm_search(
        weight_compressed, activation_q_t, None, None, None, False
    )
    alg_id, split_k, split_k_one_kernel, _ = search_result

    return (
        weight_compressed,
        activation_q_t,
        None,
        None,
        dtype,
        False,
        alg_id,
        split_k,
        split_k_one_kernel,
    )

def get_problem_scaled_mm(m: int, n: int, k: int):
    """Build the positional arguments used to benchmark ``torch._scaled_mm``.

    A random ``(m, k)`` activation and a semi-structured weight created with
    ``create_semi_structured_tensor(n, k)`` are quantized with
    ``_float8_cutlass_quant`` (both with ``dtypeq_W``). The FP8 payloads and
    scales are read from each result's ``tensor_impl``.

    Returns:
        A 7-tuple: FP8 weight, transposed FP8 activation, weight scale with a
        trailing unit dimension, activation scale with a leading unit
        dimension, two ``None`` placeholders, and the module-level ``dtype``.
    """
    activation = torch.randn((m, k), dtype=dtype, device=device)
    weight = create_semi_structured_tensor(n, k, dtype=dtype).to(device)

    # Note: the activation is quantized with dtypeq_W as well, not dtypeq_X.
    activation_impl = _float8_cutlass_quant(activation, dtypeq_W).tensor_impl
    weight_impl = _float8_cutlass_quant(weight, dtypeq_W).tensor_impl

    # The weight is passed as the first matmul operand, so its scale gets the
    # trailing unit dim and the activation scale gets the leading one.
    activation_scale = activation_impl.scale.unsqueeze(0)
    weight_scale = weight_impl.scale.unsqueeze(-1)

    return (
        weight_impl.float8_data,
        activation_impl.float8_data.t(),
        weight_scale,
        activation_scale,
        None,
        None,
        dtype,
    )


def benchmark(m: int, k: int, n: int):
ref_args, args = get_problem(m, n, k)
ref_args, args = get_problem_cutlass(m, n, k)
fp16_time = benchmark_microseconds(torch.nn.functional.linear, *ref_args)
rowwise_scaled_linear_sparse_cutlass_f8f8_time = benchmark_microseconds(
rowwise_scaled_linear_sparse_cutlass_f8f8, *args
)

cslt_args = get_problem_cusparselt(m, n, k)
cusparselt_time = benchmark_microseconds(
torch._cslt_sparse_mm, *cslt_args
)

fp8_args = get_problem_scaled_mm(m, n, k)
fp8_time = benchmark_microseconds(torch._scaled_mm, *fp8_args)

return {
"m": m,
"k": k,
"n": n,
"fp16_latency (ms)": fp16_time,
"fp8_latency (ms)": fp8_time,
"rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms)": rowwise_scaled_linear_sparse_cutlass_f8f8_time,
"f8f8 speedup (d/s)": fp16_time
"cusparselt latency (ms)": cusparselt_time,
"f8f8 speedup (d/s)": fp8_time
/ rowwise_scaled_linear_sparse_cutlass_f8f8_time,
}


if __name__ == "__main__":
k_vals = (8192, 8192, 8192, 28672)
n_vals = (8192, 10240, 57344, 8192)
k_vals = (8192,)
n_vals = (8192,)

results = []
for m in tqdm([1 << i for i in range(10)]):
for m in tqdm([2048, 4096, 8192]):
for n, k in zip(n_vals, k_vals):
results.append(benchmark(m, k, n))

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
m,k,n,fp16_latency (ms),fp8_latency (ms),rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms),cusparselt latency (ms),f8f8 speedup (d/s)
2048,8192,8192,358.36800932884216,239.26399648189545,171.1360067129135,643.8400149345398,1.3980926695529887
4096,8192,8192,728.0960083007812,499.29600954055786,400.86400508880615,645.9199786186218,1.2455496208245123
8192,8192,8192,1560.3519678115845,988.9280200004578,912.992000579834,869.1520094871521,1.0831727105740219
4 changes: 4 additions & 0 deletions rowwise_scaled_linear_sparse_cutlass_time_results.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
m,k,n,fp16_latency (ms),fp8_latency (ms),rowwise_scaled_linear_sparse_cutlass_f8f8 latency (ms),cusparselt latency (ms),f8f8 speedup (d/s)
2048,8192,8192,358.94399881362915,242.78399348258972,169.63200271129608,635.2959871292114,1.4312393274976192
4096,8192,8192,760.1919770240784,499.61599707603455,402.8159976005554,629.8239827156067,1.2403082301896782
8192,8192,8192,1576.416015625,994.4959878921509,916.9920086860657,860.5120182037354,1.084519797852043
8 changes: 8 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ def get_extensions():
"-O3" if not debug_mode else "-O0",
"-t=0",
"-std=c++17",
"-w",
],
}

Expand Down Expand Up @@ -395,6 +396,11 @@ def get_extensions():
"to_sparse_semi_structured_cutlass_sm9x",
"to_sparse_semi_structured_cutlass_sm9x_f8.cu",
),
os.path.join(
extensions_cuda_dir,
"activation24",
"SparseSemiStructuredTile_cutlass.cu"
),
]
for dtypes in ["e4m3e4m3", "e4m3e5m2", "e5m2e4m3", "e5m2e5m2"]:
cutlass_90a_sources.append(
Expand All @@ -404,6 +410,7 @@ def get_extensions():
"rowwise_scaled_linear_sparse_cutlass_" + dtypes + ".cu",
)
)

sources = [s for s in sources if s not in cutlass_90a_sources]
else:
# Remove CUTLASS-based kernels from the sources list. An
Expand All @@ -416,6 +423,7 @@ def get_extensions():
)
sources = [s for s in sources if s not in cutlass_sources]

print("CUDA sources: ", sources)
ext_modules = []
if len(sources) > 0:
ext_modules.append(
Expand Down
137 changes: 137 additions & 0 deletions test/sparsity/test_activation24.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@

import torch
import torchao
import torch.nn.functional as F

from torchao.ops import to_sparse_semi_structured_cutlass_sm9x_f8
from torchao.quantization.quant_api import (
_float8_cutlass_quant,
_float8_cutlass_quant_sparse
)
torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = True

from torchao.sparsity.utils import create_semi_structured_tensor
from torch.sparse import to_sparse_semi_structured

from torch.testing._internal import common_utils

dtype = torch.float16
device = torch.device("cuda")
dtypeq_X = torch.float8_e4m3fn
dtypeq_W = torch.float8_e4m3fn
torch.set_printoptions(profile="full")
torch.set_printoptions(linewidth=10000)


torch.manual_seed(32)

# class TestActivation24(common_utils.TestCase):

# @common_utils.parametrize("pattern", [[1, 1, 0, 0], [1, 0, 1, 0], [1, 0, 0, 1], [0, 1, 1, 0], [0, 1, 0, 1], [0, 0, 1, 1]])
def test_correctness():
    """
    Debug helper that maps FP8 sparse-metadata bytes back to weight positions.

    Starts from a 32x64 weight whose every group of four is [0, 0, 1, 1] and
    takes the metadata of its ``to_sparse_semi_structured`` form as a table.
    Then, for each row ``i`` and each 8-wide column block ``j``, it overwrites
    that block with two [1, 1, 0, 0] groups, quantizes the weight with
    ``_float8_cutlass_quant_sparse`` and labels every table cell where the
    metadata byte equals 68 with the block that was changed.

    The labelled table is printed and the test then fails unconditionally.
    """
    # Baseline weight; the original notes describe its metadata byte as "238 in binary".
    base_weight = (
        torch.Tensor([0, 0, 1, 1])
        .to(device=device, dtype=dtype)
        .tile((32, 64 // 4))
        .contiguous()
    )
    reference_sparse = to_sparse_semi_structured(base_weight)
    meta_labels = reference_sparse.meta.view(torch.uint8).tolist()

    pattern = [1, 1, 0, 0]  # 68
    num_per_tb = 8
    # Two copies of the pattern fill one 8-wide block.
    probe_block = (
        torch.Tensor(pattern).to(device=device, dtype=dtype).tile((1, 2)).contiguous()
    )

    for i in range(32):
        for j in range(8):
            block_start = j * num_per_tb
            W_ref = base_weight.clone()
            W_ref[i, block_start:block_start + num_per_tb] = probe_block

            quantized = _float8_cutlass_quant_sparse(W_ref, dtypeq_W)
            W_meta = quantized.tensor_impl.meta[:32, :8]

            # Every metadata byte equal to 68 is attributed to the block just modified.
            for r, c in (W_meta == 68).nonzero():
                meta_labels[r][c] = f"a[{i:2d}, {j*num_per_tb:2d}:{(j+1)*num_per_tb:2d}]"

    for row in meta_labels:
        print(row[:4])
        print(row[4:])

    # Unconditional failure; presumably so the printed table is shown by the
    # test runner. No real assertion is made yet.
    assert False


def test_fast_rowwise_packing():
    """
    Compare the packed values of the torchao tile op with the reference
    ``to_sparse_semi_structured`` packing for a 128x128 semi-structured weight.

    Only the multiset of packed values is compared (unique values and their
    counts), not their positions.
    """
    W_ref = create_semi_structured_tensor(128, 128, dtype=dtype).to(device)
    W_subclass_sparse = to_sparse_semi_structured(W_ref)

    # `packed` must be produced before it is compared; it comes from the
    # torchao tile op, called the same way as in the FP8 tests of this file.
    # NOTE(review): assumes the op accepts `dtype` (float16) input and returns
    # a (packed, meta) pair as it does for FP8 input -- confirm.
    packed, packed_meta = torch.ops.torchao.sparse_semi_structured_tile.default(
        W_ref, "", True
    )

    # Test packed
    vc_mine = torch.unique(packed, return_counts=True)
    vc_ref = torch.unique(W_subclass_sparse.packed, return_counts=True)
    print(packed[:16, :16])
    print(W_subclass_sparse.packed[:16, :16])
    torch.testing.assert_close(vc_mine, vc_ref)

    # Test meta (not enabled yet)
    # vc_mine = torch.unique(packed_meta, return_counts=True)
    # vc_ref = torch.unique(W_subclass_sparse.meta, return_counts=True)
    # torch.testing.assert_close(vc_mine, vc_ref)


def test_packed_fp8():
    """
    The packed values produced by the torchao tile op must match those of
    ``to_sparse_semi_structured_cutlass_sm9x_f8`` for a hand-written FP8 weight.

    The weight is a 4x16 tile repeated to 128x128; every group of four in the
    tile has exactly two non-zero entries. Metadata outputs are not checked.
    """
    tile_rows = [
        [2, 3, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 8, 0, 0],
        [0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 5, 6, 0, 0, 7, 8],
        [1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0],
        [0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8],
    ]
    base_tile = torch.Tensor(tile_rows).to(device=device)
    W_ref = base_tile.tile((128 // 4, 128 // 16)).contiguous().to(torch.float8_e4m3fn)

    packed_reference, _ = to_sparse_semi_structured_cutlass_sm9x_f8(W_ref)
    packed, _ = torch.ops.torchao.sparse_semi_structured_tile.default(W_ref, "", True)

    # Both sides are cast to float16 before the comparison.
    torch.testing.assert_close(
        packed.to(torch.float16), packed_reference.to(torch.float16)
    )


def test_meta_fp8_fixed():
    """
    The metadata produced by the torchao tile op must match the metadata of
    ``to_sparse_semi_structured_cutlass_sm9x_f8`` for a seeded random 128x128
    FP8 semi-structured weight. Packed values are not checked here.
    """
    torch.manual_seed(123)
    W_ref = create_semi_structured_tensor(128, 128, dtype=torch.float8_e4m3fn).to(
        device
    )

    reference_outputs = to_sparse_semi_structured_cutlass_sm9x_f8(W_ref)
    packed_reference, meta_reference = reference_outputs
    tile_outputs = torch.ops.torchao.sparse_semi_structured_tile.default(
        W_ref, "", True
    )
    packed, packed_meta = tile_outputs

    # Value counts are computed but not asserted on.
    vc_mine = torch.unique(packed_meta, return_counts=True)
    vc_ref = torch.unique(meta_reference, return_counts=True)

    torch.testing.assert_close(packed_meta, meta_reference)


# common_utils.instantiate_parametrized_tests(TestActivation24)
#

# pprint(garbanzo_beans)



# print(W_meta)

# breakpoint()
# print(W_subclass_sparse.meta.view(torch.uint8) == W_meta)
# print("CUTLASS REFERENCE")
# print(W_meta)
# print(W_meta.shape)
# print(packed_meta)
# packed, packed_meta, packed_t, packed_t_meta , bitmask = torch.ops.torchao.sparse_semi_structured_tile.default(W_ref, "", True)
# print(W_meta)
111 changes: 111 additions & 0 deletions torchao/csrc/cuda/activation24/ComputeSparseTile.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
#pragma once

#include "SparseSemiStructuredPack.cuh"
#include "StaticSort.h"
#include <cutlass/bfloat16.h>
#include <cutlass/half.h>
#include <cutlass/platform/platform.h>
#include <cutlass/version.h>
// Basic FP8 type definitions
#include <cutlass/numeric_conversion.h>
#include <cutlass/numeric_types.h>

// // For FP8 E4M3 format (4 exponent bits, 3 mantissa bits)
// #include <cutlass/float8_e4m3.h>

// // For FP8 E5M2 format (5 exponent bits, 2 mantissa bits)
// #include <cutlass/float8_e5m2.h>

// Given a tile of values, computes the selected indices that will remain after
// 2:4 sparsification, as a bitmask.
// NOTE: The algorithm defined in this file (LargestValuesRowwise) works on a
// 1x16 tile split into four 1x4 strips and keeps 2 values per strip.

namespace torchao {

// A tile value bundled with two 2-bit position fields, ordered by
// `Pointwise::apply(value)`. `raw` shares storage with `parts` through the
// anonymous union.
template <typename Element, typename Pointwise>
struct TileValueOrderedT {
  union {
    struct {
      Element value;
      uint2b_t inner_index;
      uint2b_t outer_index;
    } parts;
    uint32_t raw;
  };

  // Empty constructor: no member is given an initial value here.
  CUTLASS_DEVICE TileValueOrderedT() {}

  // Compares only the ranking keys; the index fields do not participate.
  CUTLASS_DEVICE bool operator<(TileValueOrderedT const &rhs) const {
    auto const lhs_key = Pointwise::apply(parts.value);
    auto const rhs_key = Pointwise::apply(rhs.parts.value);
    return lhs_key < rhs_key;
  }
};

// Operations that we can apply to rank the values
struct IdentityOp {
template <typename T> static T CUTLASS_HOST_DEVICE apply(T const &x) {
return x;
}
};

// Given a 1x16 tile (read through `values.at(0, c)` for c in [0, 16)),
// computes the selected indices that will remain after 2:4 sparsification,
// as a bitmask. The tile is handled as four consecutive 1x4 strips:
// `outer_index` is the strip and `inner_index` the position inside it.
// We have 1 constraint: (1) Exactly 2 values per 1x4 strip.
// ALGO: We use a simple algorithm that sorts all 16 values by
// `Op::apply(value)` and keeps the 2 largest values of every strip.
// Bit `inner_index + 4 * outer_index` of the result is set for a kept value.
// NOTE: RF are not indexable, so we shouldn't rely on indexing
// values at any point, otherwise they will be stored in local memory.
template <typename Op = IdentityOp> struct LargestValuesRowwise {
  // Returns negative infinity of T. The name suggests it is the value used
  // for out-of-bounds tile elements; the caller is not visible in this file.
  // NOTE(review): relies on `numeric_limits<T>::infinity()` being a real
  // infinity for T -- confirm this holds for the FP8 element types.
  template <typename T> static CUTLASS_DEVICE T outOfBoundsFillValue() {
    return -cutlass::platform::numeric_limits<T>::infinity();
  }

  // Computes the selection bitmask for one 1x16 tile accessed via `values`.
  template <typename Tile1x16Accessor>
  CUTLASS_DEVICE Indices1x16 operator()(Tile1x16Accessor values) {
    using TileValueOrdered =
        TileValueOrderedT<typename Tile1x16Accessor::Element, Op>;
    using TileValuesFragment = cutlass::Array<TileValueOrdered, 4 * 4>;

    Indices1x16 indices;
    TileValuesFragment values_ordered;
    // Tag every value with its strip (i) and its position in the strip (j)
    // so the position survives the sort below.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < 4; ++i) {
      CUTLASS_PRAGMA_UNROLL
      for (int j = 0; j < 4; ++j) {
        TileValueOrdered &v = values_ordered[i * 4 + j];
        v.parts.value = values.at(0, i * 4 + j).get();
        v.parts.inner_index = uint2b_t(j);
        v.parts.outer_index = uint2b_t(i);
      }
    }
    // Use a sorting network (aka without branches) to avoid
    // warp divergence
    StaticSort<TileValuesFragment::kElements> sorter;
    sorter(values_ordered);

    // bitmask to store how many we have selected on a given 1x4 strip
    // (2 bits per strip)
    // 0 selected: (numPer1x4Strip >> 2*strip) = 00 (0)
    // 1 selected: (numPer1x4Strip >> 2*strip) = 01 (1)
    // 2 selected: (numPer1x4Strip >> 2*strip) = 11 (3)
    uint32_t numPer1x4Strip = 0;
    indices = 0;

    // Take as many as we can, starting with the largest values
    CUTLASS_PRAGMA_UNROLL
    for (int i = values_ordered.size() - 1; i >= 0; i--) {
      auto &e = values_ordered[i];

      // 2-bit counter of this value's strip, using the encoding above.
      uint32_t rcount = uint2b_t(numPer1x4Strip >> 2 * e.parts.outer_index);
      // NOTE: With the encoding above rcount is 0, 1 or 3, so this is
      // equivalent to `rcount != 3`: the strip still has room.
      bool selected = rcount <= 2;
      indices |= selected << (e.parts.inner_index + 4 * e.parts.outer_index);

      // OR-ing in (rcount + selected) advances the counter 00 -> 01 -> 11
      // and leaves 11 unchanged once the strip is full.
      numPer1x4Strip |= (rcount + selected) << 2 * e.parts.outer_index;
    }
    return indices;
  }
};

template <typename T> void named_algorithms(T callback) {
// default one
callback(LargestValuesRowwise<IdentityOp>(), "");
}

} // namespace torchao
Loading
Loading