Skip to content

Jcaip/llm bsr #1601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 46 additions & 29 deletions benchmarks/benchmark_gpu_sparsity.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from typing import Callable, List, Optional, Tuple

import pandas as pd
import torch
Expand All @@ -11,7 +12,9 @@
create_block_sparse_tensor,
create_semi_structured_tensor,
)
from torchao.utils import benchmark_model
import torch.utils.benchmark as benchmark

from torchao.sparsity.blocksparse import BlockSparseTensor

torch.set_printoptions(
precision=2,
Expand All @@ -27,6 +30,17 @@ def benchmark_model_with_warmup(func, x, N_WARMUP=3):
benchmark_model(func, N_WARMUP, device_type="cuda")
return benchmark_model(func, 10, device_type="cuda")

def benchmark_torch_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
# warmup
for _ in range(1):
func(*args, **kwargs)
# t0 = benchmark.Timer(
# stmt="func(*args, **kwargs)",
# globals={"args": args, "kwargs": kwargs, "func": func},
# )
# return t0.adaptive_autorange(min_run_time=0.1).median * 1e6
return 1


def run_gpu_sparse_benchmark(m, k, n, args):
with torch.no_grad():
Expand All @@ -43,7 +57,8 @@ def run_gpu_sparse_benchmark(m, k, n, args):
A = create_block_sparse_tensor(
m, k, args.block_size, args.sparsity_level, dtype
)
A_sparse = A.to_sparse_bsr(blocksize=args.block_size)
# A_sparse = A.to_sparse_bsr(blocksize=args.block_size)
A_sparse = BlockSparseTensor.from_dense(A, args.block_size).detach()
# BSR kernel tuning
if args.bsr_autotune:
print("Tuning kernel params")
Expand All @@ -61,13 +76,16 @@ def run_gpu_sparse_benchmark(m, k, n, args):
raise ValueError(f"Unknown sparsity: {args.sparsity}")

if args.eval_fn == "linear":
b = torch.randn(m, dtype=dtype).cuda()
# b = torch.randn(m, dtype=dtype).cuda()
b = None

# can't use lambda
def dense_func():
@torch.compile(mode="max-autotune")
def dense_func(x):
return F.linear(x, A, b)

def sparse_func():
@torch.compile(mode="max-autotune")
def sparse_func(x):
return F.linear(x, A_sparse, b)

elif args.eval_fn == "mm":
Expand Down Expand Up @@ -101,20 +119,17 @@ def sparse_func():
else:
raise ValueError(f"Unknown eval_fn: {args.eval_fn}")

dense_time = benchmark_model_with_warmup(dense_func, "dense.json.gz")
sparse_time = benchmark_model_with_warmup(sparse_func, "sparse.json.gz")

dense_func_c = torch.compile(dense_func, mode="max-autotune")
dense_time_c = benchmark_model_with_warmup(
dense_func_c, "dense_compile.json.gz"
)

sparse_func_c = torch.compile(sparse_func, mode="max-autotune")
sparse_time_c = benchmark_model_with_warmup(
sparse_func_c, "sparse_compile.json.gz"
)
# print(x)
# print(A)
# print(A_sparse.crow_indices())
# print(A_sparse.col_indices())
# print(A_sparse.values())
dense_time, sparse_time = 0, 0
dense_time_c, sparse_time_c = 1, 1

torch._dynamo.reset()
#WARMUP
# dense_time_c = benchmark_torch_function_in_microseconds(dense_func, x)
sparse_time_c = benchmark_torch_function_in_microseconds(sparse_func, x)

return {
"test_function": args.eval_fn,
Expand All @@ -126,8 +141,7 @@ def sparse_func():
"dense": dense_time,
"dense_c": dense_time_c,
"sparse_c": sparse_time_c,
"speedup (d/s)": min(dense_time, dense_time_c)
/ min(sparse_time, sparse_time_c),
"speedup (d/s)": dense_time_c / sparse_time_c,
}


Expand Down Expand Up @@ -200,15 +214,18 @@ def sparse_func():
)
elif args.mode == "llama3-8b-w":
mm_shapes = [
(16, 4096, 11008),
(16, 4096, 4096),
(16, 11008, 4096),
(4096, 4096, 11008),
(4096, 4096, 4096),
(4096, 11008, 4096),
(8192, 4096, 11008),
(8192, 4096, 4096),
(8192, 11008, 4096),
# (32, 32, 16),
(4096, 14336, 1),
# (14336, 4096, 1),
# (14336, 4096, 1),
# (11008, 4096, 16),
# (16, 4096, 4096),
# (4096, 4096, 11008),
# (4096, 4096, 4096),
# (4096, 11008, 4096),
# (8192, 4096, 11008),
# (8192, 4096, 4096),
# (8192, 11008, 4096),
]
results = (
run_gpu_sparse_benchmark(m, k, n, args) for (m, k, n) in tqdm(mm_shapes)
Expand Down
41 changes: 41 additions & 0 deletions test/sparsity/test_bsr_sum_prod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import torch

import triton
import triton.language as tl
import pdb

from torchao.sparsity.utils import create_block_sparse_tensor
from torchao.sparsity.blocksparse import BlockSparseTensor
from torch.library import wrap_triton, triton_op



@torch.compile(dynamic=False, fullgraph=True)
def test(w, x):
b = x.unsqueeze(0)
out= (torch.mul(w, b)).sum(dim=1)
return out

torch.set_printoptions(profile='full', linewidth=100000)
torch.manual_seed(0)
size = 98432

with torch.no_grad():
create_block_sparse_tensor = torch.compiler.disable(create_block_sparse_tensor)
a = create_block_sparse_tensor(32, 32, 16, 0.5, torch.bfloat16).cuda() * torch.randn(32, 32, dtype=torch.bfloat16).cuda()
a[:16, :16] *= 4
a[16:, 16:] *= 4
a[16:, :16] *= 2
a[:16, 16:] *= 1
# print(a)
# print(x)
w = BlockSparseTensor.from_dense(a, 16).detach()
x = torch.arange(32).reshape((32, 1)).to(torch.bfloat16).cuda()
# expected= test(a.unsqueeze(2), x)
# print(expected)
# print("strides", w.unsqueeze(2).stride())
# print("strides", w.stride())
out = test(w.unsqueeze(2), x)
# print(out)

# torch.testing.assert_close(out, expected, rtol=1e-2, atol=1e-2)
Loading
Loading