102 changes: 102 additions & 0 deletions benchmarks/benchmark_blockwise_scaled_linear_triton.py
@@ -0,0 +1,102 @@
import pandas as pd
import torch
from tqdm import tqdm
from triton.testing import do_bench

from torchao.float8.float8_utils import compute_error
from torchao.ops import rowwise_scaled_linear_cutlass_s8s4
from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
fp8_blockwise_act_quant,
fp8_blockwise_weight_quant,
)
from torchao.quantization.quant_api import (
int8_dynamic_activation_int4_weight,
quantize_,
)


def benchmark_microseconds(f, *args):
    # do_bench reports milliseconds; convert the median to microseconds
    return do_bench(lambda: f(*args), return_mode="median") * 1e3

def get_rowwise_problem(m: int, n: int, k: int):
dev = torch.device("cuda")
A = torch.randint(-128, 127, (m, k), dtype=torch.int8, device=dev)
A_scale = torch.randn((m,), dtype=torch.half, device=dev)
    # int4 weights packed two per int8 byte, hence k // 2 packed columns
    B = torch.randint(
        -128, 127, size=(n, 4 * k // 8), dtype=torch.int8, device=dev
    )
B_scale = torch.randn((n,), dtype=torch.half, device=dev)
C = None

return A, A_scale, B, B_scale, C

def get_blockwise_problem(m: int, n: int, k: int, block_size: int):
assert n % block_size == 0 and k % block_size == 0, "N and K dims must be divisible by block_size"
dev = torch.device("cuda")
    # random values spanning [-448, 448], the finite range of float8_e4m3fn
    A = (448.0 * (2 * torch.rand(m, k, device=dev) - 1)).to(torch.float8_e4m3fn)
    A_scale = torch.randn((m, k // block_size), dtype=torch.half, device=dev)
    B = (448.0 * (2 * torch.rand(n, k, device=dev) - 1)).to(torch.float8_e4m3fn)
    B_scale = torch.randn((n // block_size, k // block_size), dtype=torch.half, device=dev)

return A, A_scale, B, B_scale

def benchmark(m: int, k: int, n: int, block_size: int):
# Speed benchmark
dev = torch.device("cuda")
A_ref = torch.randn((m, k), dtype=torch.half, device=dev)
B_ref = torch.randn((n, k), dtype=torch.half, device=dev)
fp16_time = benchmark_microseconds(torch.nn.functional.linear, A_ref, B_ref)

A, A_scale, B, B_scale, C = get_rowwise_problem(m, n, k)
rowwise_scaled_linear_cutlass_s8s4_time = benchmark_microseconds(
rowwise_scaled_linear_cutlass_s8s4, A, A_scale, B, B_scale, C
)

A, A_scale, B, B_scale = get_blockwise_problem(m, n, k, block_size)
blockwise_fp8_gemm_time = benchmark_microseconds(
blockwise_fp8_gemm, A, A_scale, B, B_scale
)

# Precision benchmark
    lin = torch.nn.Linear(k, n, bias=False, device=dev, dtype=torch.half)
A = torch.randn((m, k), dtype=torch.half, device=dev)
W = lin.weight
output = A @ W.T

A_q, A_s = fp8_blockwise_act_quant(A, block_size)
W_q, W_s = fp8_blockwise_weight_quant(W, block_size)
output_blockwise_quant = blockwise_fp8_gemm(A_q, A_s, W_q, W_s)

quantize_(lin, int8_dynamic_activation_int4_weight())
output_rowwise_quant = lin(A)

error_rowwise_quant = compute_error(output, output_rowwise_quant)
error_blockwise_quant = compute_error(output, output_blockwise_quant)

return {
"m": m,
"k": k,
"n": n,
"fp16_latency (ms)": fp16_time,
"rowwise_scaled_linear_cutlass_s8s4 latency (ms)": rowwise_scaled_linear_cutlass_s8s4_time,
"rowwise s8s4 speedup (d/s)": fp16_time / rowwise_scaled_linear_cutlass_s8s4_time,
"blockwise_fp8_gemm latency (ms)": blockwise_fp8_gemm_time,
"blockwise fp8 speedup (d/s)": fp16_time / blockwise_fp8_gemm_time,
"error_rowwise_quant (dB)": error_rowwise_quant,
"error_blockwise_quant (dB)": error_blockwise_quant
}

if __name__ == "__main__":
k_vals = (8192, 8192, 8192, 28672)
n_vals = (8192, 10240, 57344, 8192)
block_size_vals = (128, 128, 128, 128)

results = []
for m in tqdm([1 << i for i in range(10)]):
for n, k, block_size in zip(n_vals, k_vals, block_size_vals):
results.append(benchmark(m, k, n, block_size))

df = pd.DataFrame(results)
df.to_csv("blockwise_scaled_linear_triton_results.csv", index=False)
print(df.to_markdown(index=False))
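
For reference, a minimal sketch of the per-block absmax scheme the blockwise helpers above are assumed to implement: weights are scaled per (block_size x block_size) tile with scale = absmax / 448, where 448 is the largest finite float8_e4m3fn value. The helper names below and the exact tiling are assumptions, not the prototype's verified internals:

import torch

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def blockwise_weight_quant_ref(w: torch.Tensor, block_size: int = 128):
    # View the (n, k) weight as a grid of (block_size x block_size) tiles,
    # compute one absmax scale per tile, and quantize each tile into fp8 range.
    n, k = w.shape
    tiles = w.reshape(n // block_size, block_size, k // block_size, block_size)
    scale = tiles.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12) / FP8_E4M3_MAX
    q = (tiles / scale).to(torch.float8_e4m3fn).reshape(n, k)
    return q, scale.reshape(n // block_size, k // block_size)

def blockwise_weight_dequant_ref(q: torch.Tensor, scale: torch.Tensor, block_size: int = 128):
    # Invert quantization by re-applying each tile's scale.
    n, k = q.shape
    tiles = q.to(torch.float32).reshape(n // block_size, block_size, k // block_size, block_size)
    return (tiles * scale.reshape(n // block_size, 1, k // block_size, 1)).reshape(n, k)

A typical invocation of the benchmark itself (assumed): python benchmarks/benchmark_blockwise_scaled_linear_triton.py
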
7 changes: 5 additions & 2 deletions benchmarks/float8/bench_matmul.py
@@ -124,10 +124,13 @@ def run(
if scaling_granularity == ScalingGranularity.TENSORWISE:
scale_a = torch.tensor([1.0], device=device)
scale_b = torch.tensor([1.0], device=device)
else:
assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
elif scaling_granularity == ScalingGranularity.AXISWISE:
scale_a = torch.ones(M, 1, device=device)
scale_b = torch.ones(1, N, device=device)
else:
assert scaling_granularity == ScalingGranularity.BLOCKWISE, "unsupported"
Contributor:
this file is benchmarking torch._scaled_mm which does not support blockwise scaling, is this change intended?

Contributor (Author):
Unintended, but I will rework this PR. There were some details that I had missed when I initially worked on it.

scale_a = torch.ones(M, N, device=device)
scale_b = torch.ones(M, N, device=device)

def do_matmul(A, B):
nonlocal scale_a
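
As context for the review thread above, a hedged sketch (shapes and block size assumed, not from this PR) of the per-operand scale shapes each granularity implies for torch._scaled_mm on an (M, K) x (K, N) matmul. torch._scaled_mm accepts the tensorwise and axiswise (rowwise) forms; as the reviewer notes, it has no argument for per-tile blockwise scales:

import torch

M, K, N, B = 1024, 2048, 4096, 128  # B: assumed block size

scale_tensorwise = torch.tensor([1.0], device="cuda")   # one scale for the whole tensor
scale_axiswise_a = torch.ones(M, 1, device="cuda")      # one scale per row of A
scale_axiswise_b = torch.ones(1, N, device="cuda")      # one scale per column of B
# A blockwise scheme carries one scale per (B x B) tile, e.g. for A:
scale_blockwise_a = torch.ones(M // B, K // B, device="cuda")  # no torch._scaled_mm argument accepts this shape
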
25 changes: 21 additions & 4 deletions benchmarks/float8/float8_roofline.py
@@ -354,14 +354,30 @@ def run(
m_fp8_dyn_axs = torch.compile(m_fp8_dyn_axs)
fp8_dyn_axs_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_axs, x)

# get the lw recipe scaling gpu kernel time
# get the float8 dynamic blockwise scaling gpu kernel time
torch._dynamo.reset()
config = recipe_name_to_linear_config(Float8LinearRecipeName.ALL_BLOCKWISE)
m_fp8_dyn_blk = convert_to_float8_training(copy.deepcopy(m_orig), config=config)
m_fp8_dyn_blk = torch.compile(m_fp8_dyn_blk)
fp8_dyn_blk_time_actual_s = get_gpu_kernel_time(m_fp8_dyn_blk, x)

# get the lw_axs recipe scaling gpu kernel time
# TODO(future PR): enable below once basic performance issues
# are fixed
# torch._dynamo.reset()
# config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP)
# m_fp8_lw = convert_to_float8_training(m_orig, config=config)
# m_fp8_lw = torch.compile(m_fp8_lw)
# fp8_lw_time_actual_s = get_gpu_kernel_time(m_fp8_lw, x)
# m_fp8_lw_axs = convert_to_float8_training(m_orig, config=config)
# m_fp8_lw_axs = torch.compile(m_fp8_lw_axs)
# fp8_lw_axs_time_actual_s = get_gpu_kernel_time(m_fp8_lw_axs, x)

# get the lw_blk recipe scaling gpu kernel time
# TODO(future PR): enable below once basic performance issues
# are fixed
# torch._dynamo.reset()
# config = recipe_name_to_linear_config(Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP)
# m_fp8_lw_blk = convert_to_float8_training(m_orig, config=config)
# m_fp8_lw_blk = torch.compile(m_fp8_lw_blk)
# fp8_lw_blk_time_actual_s = get_gpu_kernel_time(m_fp8_lw_blk, x)

results.append(
[
@@ -382,6 +398,7 @@
fp8_dyn_time_actual_s,
fp8_del_time_actual_s,
fp8_dyn_axs_time_actual_s,
fp8_dyn_blk_time_actual_s,
# fp8_lw_time_actual_s,
bf16_time_actual_s / fp8_dyn_time_actual_s,
bf16_time_actual_s / fp8_del_time_actual_s,
61 changes: 61 additions & 0 deletions test/float8/test_base.py
@@ -43,6 +43,7 @@
from torchao.float8.float8_python_api import addmm_float8_unwrapped
from torchao.float8.float8_scaling_utils import (
get_maybe_axiswise_dim,
get_maybe_blockwise_size,
hp_tensor_to_float8_dynamic,
)
from torchao.float8.float8_tensor import (
@@ -178,6 +179,22 @@ def test_axiswise_dynamic_cast(self, shape, axiswise_dim):
sqnr = compute_error(a, a_dq)
assert sqnr >= 25.0

@pytest.mark.parametrize("shape", [(8, 16), (4, 8, 16), (2, 4, 8, 16)])
@pytest.mark.parametrize("blockwise_size", [4])
def test_blockwise_dynamic_cast(self, shape, blockwise_size):
a = torch.randn(*shape, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
scaling_granularity=ScalingGranularity.BLOCKWISE,
blockwise_size=blockwise_size,
)
a_dq = a_fp8.to_original_precision()
sqnr = compute_error(a, a_dq)
assert sqnr >= 25.0

def test_axiswise_reshape(self):
a = torch.randn(3, 5, 7, dtype=torch.bfloat16)
linear_mm_config = LinearMMConfig()
@@ -272,6 +289,48 @@ def test_axiswise_gemm(self, a_shape, a_granularity, b_granularity):
sqnr = compute_error(c_ref, c_fp8_compute)
assert sqnr >= 25.0

@pytest.mark.parametrize("a_shape", [(16, 32), (2, 16, 32), (1, 2, 16, 32)])
@pytest.mark.parametrize(
"a_granularity,b_granularity",
[
(ScalingGranularity.BLOCKWISE, ScalingGranularity.BLOCKWISE),
(ScalingGranularity.BLOCKWISE, ScalingGranularity.TENSORWISE),
(ScalingGranularity.TENSORWISE, ScalingGranularity.BLOCKWISE),
],
)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
@unittest.skipIf(not is_sm_at_least_90(), "Requires CUDA capability >= 9.0")
def test_blockwise_gemm(self, a_shape, a_granularity, b_granularity):
a = torch.randn(*a_shape, dtype=torch.bfloat16, device="cuda")
b = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda")

linear_mm_config = LinearMMConfig()

a_fp8 = hp_tensor_to_float8_dynamic(
a,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.INPUT,
scaling_granularity=a_granularity,
blockwise_size=get_maybe_blockwise_size(8, a_granularity),
)
a_fp8 = a_fp8.reshape(-1, a_shape[-1])

b_fp8 = hp_tensor_to_float8_dynamic(
b,
e4m3_dtype,
linear_mm_config,
gemm_input_role=GemmInputRole.WEIGHT,
scaling_granularity=b_granularity,
blockwise_size=get_maybe_blockwise_size(8, b_granularity),
)

c_fp8_compute = torch.mm(a_fp8, b_fp8.t())
a = a.reshape(-1, a_shape[-1])
c_ref = torch.mm(a, b.t())
sqnr = compute_error(c_ref, c_fp8_compute)
assert sqnr >= 25.0


class TestFloat8Linear:
def _test_linear_impl(
@@ -417,7 +476,9 @@ def test_linear_from_config_params(
"recipe_name",
[
Float8LinearRecipeName.ALL_AXISWISE,
Float8LinearRecipeName.ALL_BLOCKWISE,
Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP,
],
)
@pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
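
The blockwise tests above pass blockwise_size through the new get_maybe_blockwise_size helper. By analogy with the existing get_maybe_axiswise_dim, its assumed behavior (a sketch, not the verified implementation; the import path is also assumed) is to forward the block size only when blockwise scaling is requested:

from torchao.float8.config import ScalingGranularity

def get_maybe_blockwise_size_sketch(blockwise_size, scaling_granularity):
    # Forward the block size only for BLOCKWISE scaling; other
    # granularities receive None and ignore it.
    if scaling_granularity is ScalingGranularity.BLOCKWISE:
        return blockwise_size
    return None
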
2 changes: 2 additions & 0 deletions test/float8/test_compile.py
@@ -223,7 +223,9 @@ def test_inductor_from_config_params(
"recipe_name",
[
Float8LinearRecipeName.ALL_AXISWISE,
Float8LinearRecipeName.ALL_BLOCKWISE,
Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP,
],
)
@unittest.skipIf(
2 changes: 2 additions & 0 deletions test/float8/test_numerics_integration.py
@@ -199,7 +199,9 @@ def test_encoder_fw_bw_from_config_params(
"recipe_name",
[
Float8LinearRecipeName.ALL_AXISWISE,
Float8LinearRecipeName.ALL_BLOCKWISE,
Float8LinearRecipeName.LW_AXISWISE_WITH_GW_HP,
Float8LinearRecipeName.LW_BLOCKWISE_WITH_GW_HP,
],
)
@pytest.mark.skipif(
51 changes: 51 additions & 0 deletions test/prototype/test_blockwise_triton.py
@@ -0,0 +1,51 @@
import pytest
import torch

from torchao.prototype.blockwise_fp8.blockwise_fp8_gemm_triton import blockwise_fp8_gemm
from torchao.prototype.blockwise_fp8.blockwise_quantization import (
fp8_blockwise_act_quant,
fp8_blockwise_weight_dequant,
fp8_blockwise_weight_quant,
)

# (M, N, K) problem sizes, borrowed from the rowwise-scaled CUTLASS tests
BLOCKWISE_FP8_SIZE_MNK = [
(2, 512, 128),
(3, 2048, 2048),
(4, 3584, 640),
(13, 8704, 8576),
(26, 18944, 1664),
(67, 6656, 1408),
]

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("_, N, K", ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK)
def test_quant_dequant(_, N, K):
x = torch.randn(N, K).cuda()
qx, s = fp8_blockwise_weight_quant(x, block_size=128)
x_reconstructed = fp8_blockwise_weight_dequant(qx, s, block_size=128)
error = torch.norm(x - x_reconstructed) / torch.norm(x)
print(f"Relative Error: {error.item():.6f}")

assert error < 0.05, "Quant-Dequant error is too high"

@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@pytest.mark.parametrize("M, N, K", ROWWISE_SCALED_LINEAR_CUTLASS_SIZE_MNK)
def test_blockwise_fp8_gemm(M, N, K):
A = torch.randn(M, K).cuda()
B = torch.randn(N, K).cuda()

C = A @ B.T

A_q, A_s = fp8_blockwise_act_quant(A, block_size=128)
B_q, B_s = fp8_blockwise_weight_quant(B, block_size=128)

    C_q = blockwise_fp8_gemm(A_q, A_s, B_q, B_s)
    error = torch.norm(C - C_q) / torch.norm(C)
    print(f"Relative Error: {error.item():.6f}")

assert error < 0.05, "Quantize gemm error is too high"


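These tests require a CUDA device. A typical invocation (command assumed, not part of the PR):

pytest test/prototype/test_blockwise_triton.py -v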