diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py
index eb26580cc3..a9324fe393 100644
--- a/benchmarks/mx_formats/cast_bench.py
+++ b/benchmarks/mx_formats/cast_bench.py
@@ -5,6 +5,9 @@
 import triton
 from torch._inductor.utils import do_bench_using_profiling
 
+from torchao.prototype.mx_formats.custom_cast import (
+    triton_to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
 torch.manual_seed(0)
@@ -49,6 +52,12 @@ def to_mx_dim0_reference(x_hp, block_size):
     return data_d0, scale_d0
 
 
+def to_mx_dim1_reference(x_hp, block_size):
+    x_hp = x_hp.t().contiguous()
+    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    return data_d1.t(), scale_d1
+
+
 def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
     """Thin wrapper around do_bench_using_profiling"""
     no_args = lambda: func(*args, **kwargs)
@@ -67,7 +76,7 @@ def run(
     print(f"torch version: {torch.__version__}")
     print(f"triton version: {triton.__version__}")
     print(f"mode: {mode}")
-    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx")
+    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")
 
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
 
@@ -144,6 +153,41 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim1_mx":
+        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
+        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.uint8
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mx_triton":
+        y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     else:
         raise AssertionError(f"unknown mode {mode}")
 
diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py
index f330829565..580bff2172 100644
--- a/test/prototype/mx_formats/test_custom_cast.py
+++ b/test/prototype/mx_formats/test_custom_cast.py
@@ -29,6 +29,8 @@
     triton_f4_to_bf16,
     triton_f6_e2m3_to_bf16,
     triton_f6_e3m2_to_bf16,
+    triton_to_mxfp8_dim1,
+    triton_to_mxfp8_dim1_reference,
     unpack_uint4,
 )
 from torchao.prototype.mx_formats.fp_format_spec import (
@@ -42,7 +44,11 @@
     sem_vals_to_f32,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100
+from torchao.utils import (
+    TORCH_VERSION_AT_LEAST_2_8,
+    is_sm_at_least_89,
+    is_sm_at_least_100,
+)
 
 torch.manual_seed(0)
 
@@ -444,3 +450,20 @@ def test_fp6_e3m2_pack_unpack():
         torch.float32
     )
     assert torch.all(orig_vals_f6_packed_unpacked == orig_vals)
+
+
+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+@pytest.mark.parametrize("M", (256, 2048))
+@pytest.mark.parametrize("K", (256, 2048))
+# @pytest.mark.parametrize("M", (256,))
+# @pytest.mark.parametrize("K", (256,))
+def test_triton_mxfp8_dim1(M, K):
+    x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)
+    x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
+    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
+    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py
index 87f7531637..5690b0f5f0 100644
--- a/torchao/prototype/mx_formats/custom_cast.py
+++ b/torchao/prototype/mx_formats/custom_cast.py
@@ -4,6 +4,8 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
+from typing import Tuple
+
 import numpy as np
 import torch
 from torch.utils._triton import has_triton
@@ -12,7 +14,7 @@
     _f32_to_floatx_unpacked,
     _floatx_unpacked_to_f32,
 )
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_8
 
 # TODO(future): if needed, make the below work on previous PyTorch versions,
 # just need to hunt down the previous location of `libdevice`. An assert
@@ -1080,3 +1082,308 @@ def _(uint8_data):
 def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor:
     # Dummy placeholder op for torch < 2.4
     raise AssertionError("fp6 packing unsupported without torch >= 2.4")
+
+
+if TORCH_VERSION_AT_LEAST_2_8 and has_triton():
+    import triton
+    import triton.language as tl
+
+    @triton.jit
+    def _triton_calculate_scale(x, axis):
+        # There is no good support for accessing globals from a jit'ed triton
+        # function, so we redefine them here. Since this is prototype code which
+        # we plan to remove after torch.compile catches up, this is fine.
+        target_max_pow2 = 8
+        e8m0_exponent_bias = 127
+        bf16_mbits = 7
+        bf16_exp_bias = 127
+        fp32_mbits = 23
+        # We use a small epsilon to avoid division by zero
+        epsilon = 1e-10
+
+        # Find the maximum absolute value for each row
+        max_abs = tl.max(x, axis=axis)
+
+        # Calculate the e8m0 scale by extracting the exponent (floor)
+        # TODO(future PR): support other exponent extraction types (ceil, RNE)
+        max_abs = max_abs + epsilon
+        max_abs = max_abs.to(tl.bfloat16)
+        max_abs_int16 = max_abs.to(tl.int16, bitcast=True)
+        extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias
+        extracted_pow2 = extracted_pow2 - target_max_pow2
+        scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16)
+
+        # Clamp to exponents that can be represented in e8m0
+        scale_e8m0_unbiased = tl.clamp(
+            scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias
+        )
+
+        # Create the biased e8m0 representation and cast it to 8 bits
+        scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias
+        scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8)
+
+        # TODO(future PR): add NaN handling here
+
+        # Calculate the scale in floating point.
+        scale_fp = (scale_e8m0_biased.to(tl.int32) << fp32_mbits).to(
+            tl.float32, bitcast=True
+        )
+
+        return scale_fp, scale_e8m0_biased
+
+    def _get_mxfp8_dim1_kernel_autotune_configs():
+        # Values to sweep over here were determined by a manual
+        # sweep over a small set of shapes, it's likely that this
+        # can be improved in the future.
+        results = []
+        for ROW_TILE_SIZE in (64, 128):
+            for COL_TILE_SIZE in (64, 128):
+                for num_warps in (1, 2, 4):
+                    config = triton.Config(
+                        {
+                            "ROW_TILE_SIZE": ROW_TILE_SIZE,
+                            "COL_TILE_SIZE": COL_TILE_SIZE,
+                        },
+                        num_warps=num_warps,
+                    )
+                    results.append(config)
+        return results
+
+    @triton.autotune(
+        configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+        key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"],
+    )
+    @triton.jit
+    def to_mxfp8_dim1_kernel(
+        x_ptr,  # pointer to input tensor
+        output_col_major_ptr,  # pointer to column-major output tensor (column-normalized)
+        col_scale_ptr,  # pointer to store the column-wise e8m0 scales
+        n_rows,  # number of rows in the tensor
+        n_cols,  # number of columns in the tensor
+        ROW_TILE_SIZE: tl.constexpr,
+        COL_TILE_SIZE: tl.constexpr,
+        INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    ):
+        """
+        Example tiling for n_rows=8, n_cols=8, ROW_TILE_SIZE=4, COL_TILE_SIZE=4,
+        INNER_BLOCK_SIZE=2, pid_row=0, pid_col=0:
+
+        Input (row-major)
+
+        cols     0  1  2  3  4  5  6  7
+        --------------------------------
+        rows 0 | 0  1  2  3
+             1 | 8  9  10 11
+             2 | 16 17 18 19
+             3 | 24 25 26 27
+             4 |
+             5 |
+             6 |
+             7 |
+
+        Output (row-major of transpose), ids are from input
+
+        cols     0  1  2  3  4  5  6  7
+        --------------------------------
+        rows 0 | 0  8  16 24
+             1 | 1  9  17 25
+             2 | 2  10 18 26
+             3 | 3  11 19 27
+             4 |
+             5 |
+             6 |
+             7 |
+
+        Output (scales), s(0, 8) means the scale used to cast elements 0 and 8
+
+        rows     0        1          ...  4        ...  31
+        ------------------------------------------------------
+                 s(0, 8)  s(16, 24)  ...  s(1, 9)  ...  s(19, 27)
+        """
+
+        BLOCKS_PER_ROW_TILE: tl.constexpr = ROW_TILE_SIZE // INNER_BLOCK_SIZE
+
+        # Get program ID
+        pid_row = tl.program_id(0)
+        pid_col = tl.program_id(1)
+
+        # Calculate starting row and column for this tile
+        start_row = pid_row * ROW_TILE_SIZE
+        start_col = pid_col * COL_TILE_SIZE
+
+        # Create offsets for the block
+        row_offsets = tl.arange(0, ROW_TILE_SIZE)
+        col_offsets = tl.arange(0, COL_TILE_SIZE)
+
+        # Compute global row/col positions
+        rows = start_row + row_offsets[:, None]  # Convert to 2D for proper broadcasting
+        cols = start_col + col_offsets[None, :]
+
+        # Create masks for out-of-bounds accesses
+        row_mask = rows < n_rows
+        col_mask = cols < n_cols
+        mask = row_mask & col_mask
+
+        # Compute memory offsets for row-major layout (rows, cols)
+        row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+
+        # Compute memory offsets for column-major layout (cols, rows)
+        col_major_offsets = (cols * n_rows + rows).to(tl.int32)
+
+        # Load the entire block in a single operation
+        # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+        x_block = tl.load(x_ptr + row_major_offsets, mask=mask)
+
+        # Transpose dim0 and dim1
+        # shape: (COL_TILE_SIZE, ROW_TILE_SIZE)
+        x_block_t = tl.trans(x_block)
+
+        # Reshape to inner tile size
+        # shape: (COL_TILE_SIZE, ROW_TILE_SIZE) -> (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE, INNER_BLOCK_SIZE)
+        x_block_t_r = x_block_t.reshape(
+            COL_TILE_SIZE * BLOCKS_PER_ROW_TILE, INNER_BLOCK_SIZE
+        )
+
+        # Calculate the absolute values of elements in the block
+        x_block_abs_t_r = tl.abs(x_block_t_r)
+
+        # Find the maximum absolute value for each column
+        # shape: (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE,)
+        col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(x_block_abs_t_r, axis=1)
+
+        # Divide each column by scale
+        # Broadcasting col_scale to match x_block's shape
+        # x_block_t_r shape (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE, INNER_BLOCK_SIZE)
+        # col_scale shape (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE,) -> (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE, 1)
+        col_normalized_t_r = x_block_t_r / col_scale_r[:, None]
+
+        # Reshape back to original tile size
+        col_normalized_t = tl.reshape(col_normalized_t_r, COL_TILE_SIZE, ROW_TILE_SIZE)
+
+        # Undo the transpose
+        col_normalized = tl.trans(col_normalized_t)
+
+        # Quantize to float8
+        col_normalized = col_normalized.to(tl.float8e4nv)
+
+        # Store the column-normalized result in column-major format
+        # TODO(future): this mask is for row-major; it likely needs to be transposed for col-major
+        tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask)
+
+        # flatten col_scale_e8m0_r into col_scale_e8m0
+        # shape stays (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE,)
+        col_scale_e8m0 = col_scale_e8m0_r.reshape(COL_TILE_SIZE * BLOCKS_PER_ROW_TILE)
+
+        col_scale_start_offsets = (
+            (pid_col * COL_TILE_SIZE * (n_rows // ROW_TILE_SIZE))
+            * BLOCKS_PER_ROW_TILE  # number of blocks seen so far
+            + pid_row * BLOCKS_PER_ROW_TILE  # increment BLOCKS_PER_ROW_TILE
+        )
+
+        col_scale_start_ptr = col_scale_ptr + col_scale_start_offsets
+
+        # calculate col_scale_indices
+        col_scale_indices = tl.arange(0, COL_TILE_SIZE * BLOCKS_PER_ROW_TILE)
+
+        # How many values are in all the other columns for this row_pid, need to jump
+        # over them for every BLOCKS_PER_ROW_TILE values
+        jump_vals_per_col = (n_rows - ROW_TILE_SIZE) // INNER_BLOCK_SIZE
+
+        # example transformation (specifics depend on tile sizes):
+        # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13]
+        col_scale_indices = col_scale_indices + (
+            (col_scale_indices // BLOCKS_PER_ROW_TILE) * jump_vals_per_col
+        )
+
+        # TODO(future): mask this store
+        tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)
+
+    def triton_to_mxfp8_dim1(
+        x, inner_block_size=32
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Input:
+        * `x` - input tensor, in row major memory layout
+        * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes
+
+        Output:
+        * `output_col_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim1
+        * `col_scale`: the `e8m0` scale values used to cast `x` to mxfp8 across dim1
+        """
+        assert x.is_contiguous(), "`x` must be contiguous"
+        assert x.dtype == torch.bfloat16
+        assert inner_block_size <= 32
+
+        # Get tensor shape
+        n_rows, n_cols = x.shape
+
+        # Masking of loads and stores is not well tested yet, so for now enforce
+        # shapes which do not need masking. Note that this condition depends on max values of
+        # ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above.
+        # TODO(future): implement and test masking and remove this restriction
+        max_row_tile_size = 128
+        max_col_tile_size = 128
+        assert n_rows % max_row_tile_size == 0, "unsupported"
+        assert n_cols % max_col_tile_size == 0, "unsupported"
+
+        # Create output tensors
+        output_col_major = torch.empty(
+            (n_cols, n_rows), dtype=torch.float8_e4m3fn, device=x.device
+        )
+
+        # Create scale tensors
+        col_scale = torch.empty(
+            (n_cols * n_rows // inner_block_size, 1), dtype=torch.uint8, device=x.device
+        )
+
+        # Calculate grid dimensions based on tile size
+        grid = lambda META: (
+            triton.cdiv(n_rows, META["ROW_TILE_SIZE"]),
+            triton.cdiv(n_cols, META["COL_TILE_SIZE"]),
+        )
+
+        # Launch the kernel
+        to_mxfp8_dim1_kernel[grid](
+            x_ptr=x,
+            output_col_major_ptr=output_col_major,
+            col_scale_ptr=col_scale,
+            n_rows=n_rows,
+            n_cols=n_cols,
+            INNER_BLOCK_SIZE=inner_block_size,
+        )
+
+        return (
+            output_col_major.t(),
+            col_scale.view(torch.float8_e8m0fnu),
+        )
+
+    def triton_to_mxfp8_dim1_reference(
+        x_hp: torch.Tensor, block_size
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        A reference version of `to_mxfp8_dim1`.
+        """
+        from torchao.prototype.mx_formats.mx_tensor import to_mx
+
+        # cast across dim1
+        x_hp_d1 = x_hp.t().contiguous()
+        scale_e8m0_dim1, x_hp_d1_normalized = to_mx(
+            x_hp_d1, torch.float8_e4m3fn, block_size
+        )
+        scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu)
+        return (
+            x_hp_d1_normalized.t(),
+            scale_e8m0_dim1,
+        )
+
+else:
+
+    def triton_to_mxfp8_dim1(
+        x, inner_block_size=32
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise AssertionError("needs torch version 2.8+ and triton")
+
+    def triton_to_mxfp8_dim1_reference(
+        x_hp: torch.Tensor, block_size
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise AssertionError("needs torch version 2.8+ and triton")
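
Usage sketch for the new APIs, mirroring the added test (a minimal example, assuming a CUDA device with compute capability 8.9+ and dims divisible by 128, per the asserts in `triton_to_mxfp8_dim1`):

import torch

from torchao.prototype.mx_formats.custom_cast import (
    triton_to_mxfp8_dim1,
    triton_to_mxfp8_dim1_reference,
)

# dims must currently be multiples of 128 (the max autotuned tile size),
# since masking of loads/stores is not yet enabled in the kernel
x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")

# triton kernel: mxfp8 data cast across dim1, plus e8m0 scales
data_t, scale_t = triton_to_mxfp8_dim1(x, inner_block_size=32)

# reference path built on `to_mx`; expected to match the kernel bitwise
data_ref, scale_ref = triton_to_mxfp8_dim1_reference(x, block_size=32)

torch.testing.assert_close(data_t, data_ref, rtol=0, atol=0)
torch.testing.assert_close(scale_t, scale_ref, rtol=0, atol=0)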