From af6ae2f0ec344dccb974d48917a95038bbb12cc3 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 06:59:05 -0700 Subject: [PATCH 01/20] Update [ghstack-poisoned] --- benchmarks/mx_formats/cast_bench.py | 46 +- test/prototype/mx_formats/test_custom_cast.py | 16 + torchao/prototype/mx_formats/custom_cast.py | 497 ++++++++++++++++++ 3 files changed, 558 insertions(+), 1 deletion(-) diff --git a/benchmarks/mx_formats/cast_bench.py b/benchmarks/mx_formats/cast_bench.py index eb26580cc3..1d7498b970 100644 --- a/benchmarks/mx_formats/cast_bench.py +++ b/benchmarks/mx_formats/cast_bench.py @@ -5,6 +5,9 @@ import triton from torch._inductor.utils import do_bench_using_profiling +from torchao.prototype.mx_formats.custom_cast import ( + to_mxfp8_dim1, +) from torchao.prototype.mx_formats.mx_tensor import to_mx torch.manual_seed(0) @@ -49,6 +52,12 @@ def to_mx_dim0_reference(x_hp, block_size): return data_d0, scale_d0 +def to_mx_dim1_reference(x_hp, block_size): + x_hp = x_hp.t().contiguous() + scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size) + return data_d1.t(), scale_d1 + + def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float: """Thin wrapper around do_bench_using_profiling""" no_args = lambda: func(*args, **kwargs) @@ -67,7 +76,7 @@ def run( print(f"torch version: {torch.__version__}") print(f"triton version: {triton.__version__}") print(f"mode: {mode}") - assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx") + assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton") x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000 @@ -144,6 +153,41 @@ def run( bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8 bps = (bytes_r + bytes_w) / (time_us / 1e6) + elif mode == "dim1_mx": + to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference) + y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE) + + for _ in range(2): + __ = to_mx_dim1_reference_c(x, BLOCK_SIZE) + time_us = benchmark_cuda_function_in_microseconds( + lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE), + x, + BLOCK_SIZE, + ) + + assert y_d1.dtype == torch.float8_e4m3fn + assert s_d1.dtype == torch.uint8 + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + + elif mode == "dim1_mx_triton": + y_d1, s_d1 = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE) + + for _ in range(2): + __ = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE) + time_us = benchmark_cuda_function_in_microseconds( + lambda x, b: to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE), + x, + BLOCK_SIZE, + ) + + assert y_d1.dtype == torch.float8_e4m3fn + assert s_d1.dtype == torch.float8_e8m0fnu + bytes_r = x.numel() * bytes_per_el_bf16 + bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8 + bps = (bytes_r + bytes_w) / (time_us / 1e6) + else: raise AssertionError(f"unknown mode {mode}") diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index f330829565..58ccb7f076 100644 --- a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -26,6 +26,9 @@ get_bits, pack_uint4, pack_uint6, + # TODO(before land): better name? 
+ to_mxfp8_dim1, + to_mxfp8_dim1_reference, triton_f4_to_bf16, triton_f6_e2m3_to_bf16, triton_f6_e3m2_to_bf16, @@ -444,3 +447,16 @@ def test_fp6_e3m2_pack_unpack(): torch.float32 ) assert torch.all(orig_vals_f6_packed_unpacked == orig_vals) + + +# TODO(before land): skip before sm89 +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(not has_triton(), reason="unsupported without triton") +def test_triton_mxfp8_dim1(): + M, K = 1024, 2048 + x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + x_mx_ref, x_s_ref = to_mxfp8_dim1_reference(x, block_size=32) + x_mx_t, x_s_t = to_mxfp8_dim1(x, inner_block_size=32) + torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0) + torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0) + print("done") diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 87f7531637..ff4351833f 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -4,6 +4,8 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. +from typing import Tuple + import numpy as np import torch from torch.utils._triton import has_triton @@ -1080,3 +1082,498 @@ def _(uint8_data): def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: # Dummy placeholder op for torch < 2.4 raise AssertionError("fp6 packing unsupported without torch >= 2.4") + + +if TORCH_VERSION_AT_LEAST_2_4 and has_triton(): + import triton + import triton.language as tl + + @triton.jit + def _triton_calculate_scale(x, axis): + # We use a small epsilon to avoid division by zero + epsilon = 1e-10 + + # TODO(before land): reuse the constants below instead of hardcoding + target_max_pow2 = 8 + e8m0_exponent_bias = 127 + # bf16_mbits = 7 + # bf16_exp_bias = 127 + + # Find the maximum absolute value for each row + max_abs = tl.max(x, axis=axis) + # return max_abs, max_abs.to(tl.uint8) + + # TODO 1: rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1908/files + # before: 1.7 TB/s + # after: ? + scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2 + # max_abs = max_abs + epsilon + # max_abs = max_abs.to(tl.bfloat16) + # max_abs_int16 = max_abs.to(tl.int16, bitcast=True) + # extracted_pow2 = ((max_abs_int16 >> bf16_mbits) & 0b11111111) - bf16_exp_bias + # extracted_pow2 = extracted_pow2 - target_max_pow2 + # scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16) + + # Clamp to exponents that can be represented in e8m0 + scale_e8m0_unbiased = tl.clamp( + scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias + ) + + # Create the biased e8m0 representation and cast it to 8 bits + scale_e8m0_biased = scale_e8m0_unbiased + e8m0_exponent_bias + scale_e8m0_biased = scale_e8m0_biased.to(tl.uint8) + + # TODO(future PR): add NaN handling here + + # For now, calculate the scale in floating point. + # TODO(future) audit if there is a need to bit shift exponents instead. 
+ # TODO 2: rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1910/files + scale_fp = tl.exp2(scale_e8m0_unbiased.to(tl.float32)) + + return scale_fp, scale_e8m0_biased + + @triton.jit + def to_mxfp8_across_dim0_and_dim1_kernel( + x_ptr, # pointer to input tensor + output_row_major_ptr, # pointer to row-major output tensor (row-normalized) + output_col_major_ptr, # pointer to column-major output tensor (column-normalized) + row_scale_ptr, # pointer to store row-wise maximum absolute values + col_scale_ptr, # pointer to store column-wise maximum absolute values + n_rows, # number of rows in the tensor + n_cols, # number of columns in the tensor + TILE_SIZE: tl.constexpr, # tile size as a compile-time constant + ): + """ + credit: mostly Claude, some Vasiliy + """ + + # Get program ID + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + # Calculate starting row and column for this tile + start_row = pid_row * TILE_SIZE + start_col = pid_col * TILE_SIZE + + # Create offsets for the block + row_offsets = tl.arange(0, TILE_SIZE) + col_offsets = tl.arange(0, TILE_SIZE) + + # Compute global row/col positions + rows = start_row + row_offsets[:, None] # Convert to 2D for proper broadcasting + cols = start_col + col_offsets[None, :] + + # Create masks for out-of-bounds accesses + row_mask = rows < n_rows + col_mask = cols < n_cols + mask = row_mask & col_mask + + # Compute memory offsets for row-major layout (rows, cols) + row_major_offsets = (rows * n_cols + cols).to(tl.int32) + + # Compute memory offsets for column-major layout (cols, rows) + col_major_offsets = (cols * n_rows + rows).to(tl.int32) + + # Load the entire block in a single operation + x_block = tl.load(x_ptr + row_major_offsets, mask=mask) + + # ---------------------------------------------------- + # Row-wise normalization + # ---------------------------------------------------- + # Calculate the absolute values of elements in the block + x_block_abs = tl.abs(x_block) + + # Find the maximum absolute value for each row + row_scale, row_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=1) + + # Normalize each row by its maximum absolute value + # Broadcasting row_scale to match x_block's shape + row_normalized = x_block / row_scale[:, None] + + # quant to float8 + row_normalized = row_normalized.to(tl.float8e4nv) + + # ---------------------------------------------------- + # Column-wise normalization + # ---------------------------------------------------- + # Find the maximum absolute value for each column + col_scale, col_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=0) + + # Normalize each column by its maximum absolute value + # Broadcasting col_scale to match x_block's shape + col_normalized = x_block / col_scale[None, :] + + # quant to float8 + col_normalized = col_normalized.to(tl.float8e4nv) + + # Store the row-normalized result in row-major format + tl.store(output_row_major_ptr + row_major_offsets, row_normalized, mask=mask) + + # Store the column-normalized result in column-major format + tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask) + + # Create 1D ranges for storing row and column max values + row_indices = start_row + tl.arange(0, TILE_SIZE) + col_indices = start_col + tl.arange(0, TILE_SIZE) + + # Create masks for valid rows and columns + row_mask = row_indices < n_rows + col_mask = col_indices < n_cols + + # Vasiliy - deviating from Claude here for much simpler code + row_scale_start_ptr = row_scale_ptr + (pid_row * n_cols) + pid_col + row_scale_indices 
= tl.arange(0, TILE_SIZE) * (n_cols // TILE_SIZE) + # TODO(future): mask + tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0) + + # Vasiliy - deviating from Claude here for much simpler code + col_scale_start_ptr = col_scale_ptr + (pid_col * n_rows) + pid_row + col_scale_indices = tl.arange(0, TILE_SIZE) * (n_rows // TILE_SIZE) + # TODO(future): mask + tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0) + + def to_mxfp8_across_dim0_and_dim1(x, tile_size=32): + """ + This is a single fused triton kernel to cast `x` to MX across dim0 and dim1. + This is useful for MX training with the mxfp8 recipe family. + + The kernel loads data in 2d tiles, and performs the necessary casting across both + dim0 and dim1 for each tile. + + Note that for now, there is only one level of tiling (32 for MX). In the future, + we expect that adding an outer tile (of size up to 128 on B200s) can provide a + further speedup. + + Input: + * `x` - input tensor, in row major memory layout + * `tile_size` - size of tiles to normalize across, default is 32 for MX recipes + + Output: + * `output_row_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 + * `output_col_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim1 + * `row_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0 + * `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1 + """ + assert x.is_contiguous(), "`x` must be contiguous" + assert x.dtype == torch.bfloat16 + # Get tensor shape + n_rows, n_cols = x.shape + + # Create output tensors (both row-major and column-major) + output_row_major = torch.empty_like(x, dtype=torch.float8_e4m3fn) + output_col_major = torch.empty( + (n_cols, n_rows), dtype=torch.float8_e4m3fn, device=x.device + ) + + # Create tensors for row-wise and column-wise maximum absolute values + row_scale = torch.empty( + n_rows, n_cols // tile_size, dtype=torch.uint8, device=x.device + ) + col_scale = torch.empty( + n_cols, n_rows // tile_size, dtype=torch.uint8, device=x.device + ) + + # Calculate grid dimensions based on tile size + grid_rows = triton.cdiv(n_rows, tile_size) + grid_cols = triton.cdiv(n_cols, tile_size) + + # Launch the kernel + to_mxfp8_across_dim0_and_dim1_kernel[(grid_rows, grid_cols)]( + x_ptr=x, + output_row_major_ptr=output_row_major, + output_col_major_ptr=output_col_major, + row_scale_ptr=row_scale, + col_scale_ptr=col_scale, + n_rows=n_rows, + n_cols=n_cols, + TILE_SIZE=tile_size, + ) + + return ( + output_row_major, + output_col_major.t(), + row_scale.reshape(-1, 1).view(torch.float8_e8m0fnu), + col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu), + ) + + def to_mxfp8_across_dim0_and_dim1_reference( + x_hp: torch.Tensor, block_size + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + A reference version of `to_mxfp8_across_dim0_and_dim1`. 
+ """ + from torchao.prototype.mx_formats.mx_tensor import to_mx + + # cast across dim0 + scale_e8m0_dim0, x_hp_d0_normalized = to_mx( + x_hp, torch.float8_e4m3fn, block_size + ) + scale_e8m0_dim0 = scale_e8m0_dim0.unsqueeze(1).view(torch.float8_e8m0fnu) + # cast across dim1 + x_hp_d1 = x_hp.t().contiguous() + scale_e8m0_dim1, x_hp_d1_normalized = to_mx( + x_hp_d1, torch.float8_e4m3fn, block_size + ) + scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu) + return ( + x_hp_d0_normalized, + x_hp_d1_normalized.t(), + scale_e8m0_dim0, + scale_e8m0_dim1, + ) + + @triton.jit + def to_mxfp8_dim1_kernel( + x_ptr, # pointer to input tensor + output_col_major_ptr, # pointer to column-major output tensor (column-normalized) + col_scale_ptr, # pointer to store column-wise maximum absolute values + n_rows, # number of rows in the tensor + n_cols, # number of columns in the tensor + ROW_TILE_SIZE: tl.constexpr, # can be autotuned + COL_TILE_SIZE: tl.constexpr, # can be autotuned + RENAME_ME_TILE_SIZE: tl.constexpr, + INNER_BLOCK_SIZE: tl.constexpr, # should be 32 for MX + ): + # Get program ID + pid_row = tl.program_id(0) + pid_col = tl.program_id(1) + + # Calculate starting row and column for this tile + start_row = pid_row * ROW_TILE_SIZE + start_col = pid_col * COL_TILE_SIZE + + # Create offsets for the block + row_offsets = tl.arange(0, ROW_TILE_SIZE) + col_offsets = tl.arange(0, COL_TILE_SIZE) + + # Compute global row/col positions + rows = start_row + row_offsets[:, None] # Convert to 2D for proper broadcasting + cols = start_col + col_offsets[None, :] + + # Create masks for out-of-bounds accesses + row_mask = rows < n_rows + col_mask = cols < n_cols + mask = row_mask & col_mask + + # Compute memory offsets for row-major layout (rows, cols) + row_major_offsets = (rows * n_cols + cols).to(tl.int32) + + # Compute memory offsets for column-major layout (cols, rows) + col_major_offsets = (cols * n_rows + rows).to(tl.int32) + + # Load the entire block in a single operation + # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) + x_block = tl.load(x_ptr + row_major_offsets, mask=mask) + + # Transpose dim0 and dim1 + # shape: (COL_TILE_SIZE, ROW_TILE_SIZE) + x_block_t = tl.trans(x_block) + + # Reshape to inner tile size + # TODO: make this generic and nice + # inner_block_size = 32 + # shape: (COL_TILE_SIZE, ROW_TILE_SIZE) -> (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, INNER_BLOCK_SIZE) + x_block_t_r = x_block_t.reshape(RENAME_ME_TILE_SIZE, INNER_BLOCK_SIZE) + + # Calculate the absolute values of elements in the block + x_block_abs_t_r = tl.abs(x_block_t_r) + + # Find the maximum absolute value for each column + # col_scale, col_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=0) + # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) + col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(x_block_abs_t_r, axis=1) + # tl.device_print("col_scale.shape", col_scale.shape[0]) + + # Divide each column by scale + # Broadcasting col_scale to match x_block's shape + # x_block shape (n_rows, n_cols) + # col_scale shape (n_cols,) -> (1, n_cols) + # col_normalized = x_block / col_scale[None, :] + + # x_block_t shape (COL_TILE_SIZE, ROW_TILE_SIZE) + # x_block_t_r shape (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, INNER_BLOCK_SIZE) + # col_scale shape (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, 1) + col_normalized_t_r = x_block_t_r / col_scale_r[:, None] + + # Reshape back to original tile size + col_normalized_t = 
tl.reshape(col_normalized_t_r, COL_TILE_SIZE, ROW_TILE_SIZE) + + # Undo the transpose + col_normalized = tl.trans(col_normalized_t) + + # Quantize to float8 + col_normalized = col_normalized.to(tl.float8e4nv) + + # Store the column-normalized result in column-major format + tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask) + + # Create 1D ranges for storing row and column max values + # col_indices = start_col + tl.arange(0, COL_TILE_SIZE) + + # Create masks for valid rows and columns + # col_mask = col_indices < n_cols + + # reshape col_scale_e8m0_r to col_scale_e8m0 + # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE,) + # col_scale_e8m0 = col_scale_e8m0_r.reshape(COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE) + col_scale_e8m0 = col_scale_e8m0_r.reshape( + COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE + ) + # col_scale_e8m0 = col_scale_e8m0_r.ravel() + + # col_scale_start_offsets = ( + # ( + # pid_col * COL_TILE_SIZE * (n_rows // ROW_TILE_SIZE) + # ) # number of blocks seen so far + # + pid_row # increment ROW_TILE_SIZE + # ) + + factor = ROW_TILE_SIZE // INNER_BLOCK_SIZE + + col_scale_start_offsets = ( + (pid_col * COL_TILE_SIZE * (n_rows // ROW_TILE_SIZE)) + * factor # number of blocks seen so far + + pid_row * factor # increment ROW_TILE_SIZE + ) + + # tl.device_print("offset", col_scale_start_offsets) + col_scale_start_ptr = col_scale_ptr + col_scale_start_offsets + # col_scale_start_ptr = col_scale_ptr + + # col_scale_indices = tl.arange(0, COL_TILE_SIZE) * (n_rows // ROW_TILE_SIZE) + + # calculate col_scale_indices, this is a bit convoluted + # start with a sequential index [0, COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE] + # from example: [0, 1, 2, 3, 4, 5, 6, 7] + col_scale_indices = tl.arange( + 0, COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE + ) + # add offset for inner blocks + factor = ROW_TILE_SIZE // INNER_BLOCK_SIZE + + # needs better name + jump_vals_per_col = (n_rows - ROW_TILE_SIZE) // INNER_BLOCK_SIZE + # tl.device_print("jump_vals_per_col", jump_vals_per_col) + + # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13] + col_scale_indices = col_scale_indices + ( + tl.floor(col_scale_indices / factor) * jump_vals_per_col + ).to(tl.int32) + # tl.static_print(col_scale_indices) + # tl.device_print("indices", col_scale_indices) + + # TODO(future): mask + tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0) + + def to_mxfp8_dim1(x, inner_block_size=32): + """ + Input: + * `x` - input tensor, in row major memory layout + * `inner_block_size` - size of tiles to scale across, default is 32 for MX recipes + + Output: + * `output_col_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim1 + * `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1 + """ + assert x.is_contiguous(), "`x` must be contiguous" + assert x.dtype == torch.bfloat16 + # Get tensor shape + n_rows, n_cols = x.shape + + # can be autotuned + # row_tile_size = 32 + row_tile_size = 128 + # row_tile_size = 4 + + assert row_tile_size >= inner_block_size + + # TODO autotune col_tile_size + # input 16k by 16k + # triton scaling mostly commented out + # row_tile_size=256, col_tile_size=64: 3.47 TB/s + # TODO(next): make calculations work with ^ + col_tile_size = 128 + # col_tile_size = 4 + + # Create output tensors + output_col_major = torch.empty( + (n_cols, n_rows), dtype=torch.float8_e4m3fn, device=x.device + ) + + # Create scale tensors + # TODO(before 
land): switch back to empty + col_scale = torch.empty( + n_cols, n_rows // inner_block_size, dtype=torch.uint8, device=x.device + ) + + # Calculate grid dimensions based on tile size + grid_rows = triton.cdiv(n_rows, row_tile_size) + grid_cols = triton.cdiv(n_cols, col_tile_size) + + # inner_block_size = 32 + rename_me_tile_size = row_tile_size * col_tile_size // inner_block_size + + # Launch the kernel + to_mxfp8_dim1_kernel[(grid_rows, grid_cols)]( + x_ptr=x, + output_col_major_ptr=output_col_major, + col_scale_ptr=col_scale, + n_rows=n_rows, + n_cols=n_cols, + COL_TILE_SIZE=col_tile_size, + ROW_TILE_SIZE=row_tile_size, + RENAME_ME_TILE_SIZE=rename_me_tile_size, + INNER_BLOCK_SIZE=inner_block_size, + ) + + return ( + output_col_major.t(), + col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu), + ) + + def to_mxfp8_dim0_reference( + x_hp: torch.Tensor, block_size + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + A reference version of `to_mxfp8_dim0`. + """ + from torchao.prototype.mx_formats.mx_tensor import to_mx + + # cast across dim0 + scale_e8m0_dim0, x_hp_d0_normalized = to_mx( + x_hp, torch.float8_e4m3fn, block_size + ) + scale_e8m0_dim0 = scale_e8m0_dim0.unsqueeze(1).view(torch.float8_e8m0fnu) + return ( + x_hp_d0_normalized, + scale_e8m0_dim0, + ) + + def to_mxfp8_dim1_reference( + x_hp: torch.Tensor, block_size + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + A reference version of `to_mxfp8_dim1`. + """ + from torchao.prototype.mx_formats.mx_tensor import to_mx + + # cast across dim1 + x_hp_d1 = x_hp.t().contiguous() + scale_e8m0_dim1, x_hp_d1_normalized = to_mx( + x_hp_d1, torch.float8_e4m3fn, block_size + ) + scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu) + return ( + x_hp_d1_normalized.t(), + scale_e8m0_dim1, + ) + +else: + + def to_mxfp8_across_dim0_and_dim1(x, tile_size=32): + raise AssertionError("needs torch version 2.4+ and triton") + + def scale_dim0_dim1_reference( + x_hp: torch.Tensor, block_size + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + raise AssertionError("needs torch version 2.4+ and triton") From 45120de489e1cb13325c3c4e9095450d3ec7fa5d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 07:16:21 -0700 Subject: [PATCH 02/20] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/custom_cast.py | 208 -------------------- 1 file changed, 208 deletions(-) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index ff4351833f..6dba900abe 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -1132,196 +1132,6 @@ def _triton_calculate_scale(x, axis): return scale_fp, scale_e8m0_biased - @triton.jit - def to_mxfp8_across_dim0_and_dim1_kernel( - x_ptr, # pointer to input tensor - output_row_major_ptr, # pointer to row-major output tensor (row-normalized) - output_col_major_ptr, # pointer to column-major output tensor (column-normalized) - row_scale_ptr, # pointer to store row-wise maximum absolute values - col_scale_ptr, # pointer to store column-wise maximum absolute values - n_rows, # number of rows in the tensor - n_cols, # number of columns in the tensor - TILE_SIZE: tl.constexpr, # tile size as a compile-time constant - ): - """ - credit: mostly Claude, some Vasiliy - """ - - # Get program ID - pid_row = tl.program_id(0) - pid_col = tl.program_id(1) - - # Calculate starting row and column for this tile - start_row = pid_row * TILE_SIZE - start_col = pid_col * TILE_SIZE - - # Create offsets 
for the block - row_offsets = tl.arange(0, TILE_SIZE) - col_offsets = tl.arange(0, TILE_SIZE) - - # Compute global row/col positions - rows = start_row + row_offsets[:, None] # Convert to 2D for proper broadcasting - cols = start_col + col_offsets[None, :] - - # Create masks for out-of-bounds accesses - row_mask = rows < n_rows - col_mask = cols < n_cols - mask = row_mask & col_mask - - # Compute memory offsets for row-major layout (rows, cols) - row_major_offsets = (rows * n_cols + cols).to(tl.int32) - - # Compute memory offsets for column-major layout (cols, rows) - col_major_offsets = (cols * n_rows + rows).to(tl.int32) - - # Load the entire block in a single operation - x_block = tl.load(x_ptr + row_major_offsets, mask=mask) - - # ---------------------------------------------------- - # Row-wise normalization - # ---------------------------------------------------- - # Calculate the absolute values of elements in the block - x_block_abs = tl.abs(x_block) - - # Find the maximum absolute value for each row - row_scale, row_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=1) - - # Normalize each row by its maximum absolute value - # Broadcasting row_scale to match x_block's shape - row_normalized = x_block / row_scale[:, None] - - # quant to float8 - row_normalized = row_normalized.to(tl.float8e4nv) - - # ---------------------------------------------------- - # Column-wise normalization - # ---------------------------------------------------- - # Find the maximum absolute value for each column - col_scale, col_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=0) - - # Normalize each column by its maximum absolute value - # Broadcasting col_scale to match x_block's shape - col_normalized = x_block / col_scale[None, :] - - # quant to float8 - col_normalized = col_normalized.to(tl.float8e4nv) - - # Store the row-normalized result in row-major format - tl.store(output_row_major_ptr + row_major_offsets, row_normalized, mask=mask) - - # Store the column-normalized result in column-major format - tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask) - - # Create 1D ranges for storing row and column max values - row_indices = start_row + tl.arange(0, TILE_SIZE) - col_indices = start_col + tl.arange(0, TILE_SIZE) - - # Create masks for valid rows and columns - row_mask = row_indices < n_rows - col_mask = col_indices < n_cols - - # Vasiliy - deviating from Claude here for much simpler code - row_scale_start_ptr = row_scale_ptr + (pid_row * n_cols) + pid_col - row_scale_indices = tl.arange(0, TILE_SIZE) * (n_cols // TILE_SIZE) - # TODO(future): mask - tl.store(row_scale_start_ptr + row_scale_indices, row_scale_e8m0) - - # Vasiliy - deviating from Claude here for much simpler code - col_scale_start_ptr = col_scale_ptr + (pid_col * n_rows) + pid_row - col_scale_indices = tl.arange(0, TILE_SIZE) * (n_rows // TILE_SIZE) - # TODO(future): mask - tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0) - - def to_mxfp8_across_dim0_and_dim1(x, tile_size=32): - """ - This is a single fused triton kernel to cast `x` to MX across dim0 and dim1. - This is useful for MX training with the mxfp8 recipe family. - - The kernel loads data in 2d tiles, and performs the necessary casting across both - dim0 and dim1 for each tile. - - Note that for now, there is only one level of tiling (32 for MX). In the future, - we expect that adding an outer tile (of size up to 128 on B200s) can provide a - further speedup. 
- - Input: - * `x` - input tensor, in row major memory layout - * `tile_size` - size of tiles to normalize across, default is 32 for MX recipes - - Output: - * `output_row_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 - * `output_col_major`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim1 - * `row_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0 - * `col_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim1 - """ - assert x.is_contiguous(), "`x` must be contiguous" - assert x.dtype == torch.bfloat16 - # Get tensor shape - n_rows, n_cols = x.shape - - # Create output tensors (both row-major and column-major) - output_row_major = torch.empty_like(x, dtype=torch.float8_e4m3fn) - output_col_major = torch.empty( - (n_cols, n_rows), dtype=torch.float8_e4m3fn, device=x.device - ) - - # Create tensors for row-wise and column-wise maximum absolute values - row_scale = torch.empty( - n_rows, n_cols // tile_size, dtype=torch.uint8, device=x.device - ) - col_scale = torch.empty( - n_cols, n_rows // tile_size, dtype=torch.uint8, device=x.device - ) - - # Calculate grid dimensions based on tile size - grid_rows = triton.cdiv(n_rows, tile_size) - grid_cols = triton.cdiv(n_cols, tile_size) - - # Launch the kernel - to_mxfp8_across_dim0_and_dim1_kernel[(grid_rows, grid_cols)]( - x_ptr=x, - output_row_major_ptr=output_row_major, - output_col_major_ptr=output_col_major, - row_scale_ptr=row_scale, - col_scale_ptr=col_scale, - n_rows=n_rows, - n_cols=n_cols, - TILE_SIZE=tile_size, - ) - - return ( - output_row_major, - output_col_major.t(), - row_scale.reshape(-1, 1).view(torch.float8_e8m0fnu), - col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu), - ) - - def to_mxfp8_across_dim0_and_dim1_reference( - x_hp: torch.Tensor, block_size - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - """ - A reference version of `to_mxfp8_across_dim0_and_dim1`. - """ - from torchao.prototype.mx_formats.mx_tensor import to_mx - - # cast across dim0 - scale_e8m0_dim0, x_hp_d0_normalized = to_mx( - x_hp, torch.float8_e4m3fn, block_size - ) - scale_e8m0_dim0 = scale_e8m0_dim0.unsqueeze(1).view(torch.float8_e8m0fnu) - # cast across dim1 - x_hp_d1 = x_hp.t().contiguous() - scale_e8m0_dim1, x_hp_d1_normalized = to_mx( - x_hp_d1, torch.float8_e4m3fn, block_size - ) - scale_e8m0_dim1 = scale_e8m0_dim1.unsqueeze(1).view(torch.float8_e8m0fnu) - return ( - x_hp_d0_normalized, - x_hp_d1_normalized.t(), - scale_e8m0_dim0, - scale_e8m0_dim1, - ) - @triton.jit def to_mxfp8_dim1_kernel( x_ptr, # pointer to input tensor @@ -1531,24 +1341,6 @@ def to_mxfp8_dim1(x, inner_block_size=32): col_scale.reshape(-1, 1).view(torch.float8_e8m0fnu), ) - def to_mxfp8_dim0_reference( - x_hp: torch.Tensor, block_size - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - A reference version of `to_mxfp8_dim0`. 
- """ - from torchao.prototype.mx_formats.mx_tensor import to_mx - - # cast across dim0 - scale_e8m0_dim0, x_hp_d0_normalized = to_mx( - x_hp, torch.float8_e4m3fn, block_size - ) - scale_e8m0_dim0 = scale_e8m0_dim0.unsqueeze(1).view(torch.float8_e8m0fnu) - return ( - x_hp_d0_normalized, - scale_e8m0_dim0, - ) - def to_mxfp8_dim1_reference( x_hp: torch.Tensor, block_size ) -> Tuple[torch.Tensor, torch.Tensor]: From 5527e722eb7fa3fc0209780b2197997140b558c0 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 07:23:20 -0700 Subject: [PATCH 03/20] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_custom_cast.py | 13 ++++--- torchao/prototype/mx_formats/custom_cast.py | 35 +------------------ 2 files changed, 10 insertions(+), 38 deletions(-) diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index 58ccb7f076..f31d485b43 100644 --- a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -45,7 +45,11 @@ sem_vals_to_f32, ) from torchao.prototype.mx_formats.mx_tensor import MXTensor -from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100 +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_8, + is_sm_at_least_89, + is_sm_at_least_100, +) torch.manual_seed(0) @@ -449,9 +453,11 @@ def test_fp6_e3m2_pack_unpack(): assert torch.all(orig_vals_f6_packed_unpacked == orig_vals) -# TODO(before land): skip before sm89 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif(not has_triton(), reason="unsupported without triton") +@pytest.mark.skipif( + not is_sm_at_least_89(), + reason="float8 in triton requires CUDA capability 8.9 or greater", +) def test_triton_mxfp8_dim1(): M, K = 1024, 2048 x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") @@ -459,4 +465,3 @@ def test_triton_mxfp8_dim1(): x_mx_t, x_s_t = to_mxfp8_dim1(x, inner_block_size=32) torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0) torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0) - print("done") diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 6dba900abe..4b0e7b9ff4 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -1189,17 +1189,11 @@ def to_mxfp8_dim1_kernel( x_block_abs_t_r = tl.abs(x_block_t_r) # Find the maximum absolute value for each column - # col_scale, col_scale_e8m0 = _triton_calculate_scale(x_block_abs, axis=0) # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(x_block_abs_t_r, axis=1) - # tl.device_print("col_scale.shape", col_scale.shape[0]) # Divide each column by scale # Broadcasting col_scale to match x_block's shape - # x_block shape (n_rows, n_cols) - # col_scale shape (n_cols,) -> (1, n_cols) - # col_normalized = x_block / col_scale[None, :] - # x_block_t shape (COL_TILE_SIZE, ROW_TILE_SIZE) # x_block_t_r shape (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, INNER_BLOCK_SIZE) # col_scale shape (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, 1) @@ -1217,40 +1211,21 @@ def to_mxfp8_dim1_kernel( # Store the column-normalized result in column-major format tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask) - # Create 1D ranges for storing row and column max values - # col_indices = start_col + tl.arange(0, COL_TILE_SIZE) - - # Create masks for 
valid rows and columns - # col_mask = col_indices < n_cols - # reshape col_scale_e8m0_r to col_scale_e8m0 # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE,) # col_scale_e8m0 = col_scale_e8m0_r.reshape(COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE) col_scale_e8m0 = col_scale_e8m0_r.reshape( COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE ) - # col_scale_e8m0 = col_scale_e8m0_r.ravel() - - # col_scale_start_offsets = ( - # ( - # pid_col * COL_TILE_SIZE * (n_rows // ROW_TILE_SIZE) - # ) # number of blocks seen so far - # + pid_row # increment ROW_TILE_SIZE - # ) factor = ROW_TILE_SIZE // INNER_BLOCK_SIZE - col_scale_start_offsets = ( (pid_col * COL_TILE_SIZE * (n_rows // ROW_TILE_SIZE)) * factor # number of blocks seen so far + pid_row * factor # increment ROW_TILE_SIZE ) - # tl.device_print("offset", col_scale_start_offsets) col_scale_start_ptr = col_scale_ptr + col_scale_start_offsets - # col_scale_start_ptr = col_scale_ptr - - # col_scale_indices = tl.arange(0, COL_TILE_SIZE) * (n_rows // ROW_TILE_SIZE) # calculate col_scale_indices, this is a bit convoluted # start with a sequential index [0, COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE] @@ -1263,14 +1238,11 @@ def to_mxfp8_dim1_kernel( # needs better name jump_vals_per_col = (n_rows - ROW_TILE_SIZE) // INNER_BLOCK_SIZE - # tl.device_print("jump_vals_per_col", jump_vals_per_col) # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13] col_scale_indices = col_scale_indices + ( tl.floor(col_scale_indices / factor) * jump_vals_per_col ).to(tl.int32) - # tl.static_print(col_scale_indices) - # tl.device_print("indices", col_scale_indices) # TODO(future): mask tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0) @@ -1287,6 +1259,7 @@ def to_mxfp8_dim1(x, inner_block_size=32): """ assert x.is_contiguous(), "`x` must be contiguous" assert x.dtype == torch.bfloat16 + # Get tensor shape n_rows, n_cols = x.shape @@ -1299,11 +1272,7 @@ def to_mxfp8_dim1(x, inner_block_size=32): # TODO autotune col_tile_size # input 16k by 16k - # triton scaling mostly commented out - # row_tile_size=256, col_tile_size=64: 3.47 TB/s - # TODO(next): make calculations work with ^ col_tile_size = 128 - # col_tile_size = 4 # Create output tensors output_col_major = torch.empty( @@ -1311,7 +1280,6 @@ def to_mxfp8_dim1(x, inner_block_size=32): ) # Create scale tensors - # TODO(before land): switch back to empty col_scale = torch.empty( n_cols, n_rows // inner_block_size, dtype=torch.uint8, device=x.device ) @@ -1320,7 +1288,6 @@ def to_mxfp8_dim1(x, inner_block_size=32): grid_rows = triton.cdiv(n_rows, row_tile_size) grid_cols = triton.cdiv(n_cols, col_tile_size) - # inner_block_size = 32 rename_me_tile_size = row_tile_size * col_tile_size // inner_block_size # Launch the kernel From 478b9e1413eb4f4fbcd36dd3bcadc906ca12cbf6 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 08:44:24 -0700 Subject: [PATCH 04/20] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_custom_cast.py | 5 +- torchao/prototype/mx_formats/custom_cast.py | 49 +++++++++++-------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/test/prototype/mx_formats/test_custom_cast.py b/test/prototype/mx_formats/test_custom_cast.py index f31d485b43..1bdfb477c4 100644 --- a/test/prototype/mx_formats/test_custom_cast.py +++ b/test/prototype/mx_formats/test_custom_cast.py @@ -458,8 +458,9 @@ def test_fp6_e3m2_pack_unpack(): not is_sm_at_least_89(), reason="float8 in triton requires CUDA 
capability 8.9 or greater", ) -def test_triton_mxfp8_dim1(): - M, K = 1024, 2048 +@pytest.mark.parametrize("M", (256, 2048)) +@pytest.mark.parametrize("K", (256, 2048)) +def test_triton_mxfp8_dim1(M, K): x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") x_mx_ref, x_s_ref = to_mxfp8_dim1_reference(x, block_size=32) x_mx_t, x_s_t = to_mxfp8_dim1(x, inner_block_size=32) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 4b0e7b9ff4..7e0b760745 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -1132,6 +1132,25 @@ def _triton_calculate_scale(x, axis): return scale_fp, scale_e8m0_biased + def _get_mxfp8_dim1_kernel_autotune_configs(): + results = [] + for ROW_TILE_SIZE in (64, 128): + for COL_TILE_SIZE in (64, 128): + for num_warps in (1, 2, 4): + config = triton.Config( + { + "ROW_TILE_SIZE": ROW_TILE_SIZE, + "COL_TILE_SIZE": COL_TILE_SIZE, + }, + num_warps=num_warps, + ) + results.append(config) + return results + + @triton.autotune( + configs=_get_mxfp8_dim1_kernel_autotune_configs(), + key=["n_rows", "n_cols", "INNER_BLOCK_SIZE"], + ) @triton.jit def to_mxfp8_dim1_kernel( x_ptr, # pointer to input tensor @@ -1141,9 +1160,12 @@ def to_mxfp8_dim1_kernel( n_cols, # number of columns in the tensor ROW_TILE_SIZE: tl.constexpr, # can be autotuned COL_TILE_SIZE: tl.constexpr, # can be autotuned - RENAME_ME_TILE_SIZE: tl.constexpr, INNER_BLOCK_SIZE: tl.constexpr, # should be 32 for MX ): + RENAME_ME_TILE_SIZE: tl.constexpr = ( + ROW_TILE_SIZE * COL_TILE_SIZE // INNER_BLOCK_SIZE + ) + # Get program ID pid_row = tl.program_id(0) pid_col = tl.program_id(1) @@ -1259,21 +1281,11 @@ def to_mxfp8_dim1(x, inner_block_size=32): """ assert x.is_contiguous(), "`x` must be contiguous" assert x.dtype == torch.bfloat16 + assert inner_block_size <= 32 # Get tensor shape n_rows, n_cols = x.shape - # can be autotuned - # row_tile_size = 32 - row_tile_size = 128 - # row_tile_size = 4 - - assert row_tile_size >= inner_block_size - - # TODO autotune col_tile_size - # input 16k by 16k - col_tile_size = 128 - # Create output tensors output_col_major = torch.empty( (n_cols, n_rows), dtype=torch.float8_e4m3fn, device=x.device @@ -1285,21 +1297,18 @@ def to_mxfp8_dim1(x, inner_block_size=32): ) # Calculate grid dimensions based on tile size - grid_rows = triton.cdiv(n_rows, row_tile_size) - grid_cols = triton.cdiv(n_cols, col_tile_size) - - rename_me_tile_size = row_tile_size * col_tile_size // inner_block_size + grid = lambda META: ( + triton.cdiv(n_rows, META["ROW_TILE_SIZE"]), + triton.cdiv(n_cols, META["COL_TILE_SIZE"]), + ) # Launch the kernel - to_mxfp8_dim1_kernel[(grid_rows, grid_cols)]( + to_mxfp8_dim1_kernel[grid]( x_ptr=x, output_col_major_ptr=output_col_major, col_scale_ptr=col_scale, n_rows=n_rows, n_cols=n_cols, - COL_TILE_SIZE=col_tile_size, - ROW_TILE_SIZE=row_tile_size, - RENAME_ME_TILE_SIZE=rename_me_tile_size, INNER_BLOCK_SIZE=inner_block_size, ) From 571775d428202e972a5052ca0eeb230cc7c8cafb Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 08:54:10 -0700 Subject: [PATCH 05/20] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/custom_cast.py | 30 ++++++++++----------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 7e0b760745..baa2594ad2 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -14,7 
+14,7 @@ _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, ) -from torchao.utils import TORCH_VERSION_AT_LEAST_2_4 +from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, TORCH_VERSION_AT_LEAST_2_8 # TODO(future): if needed, make the below work on previous PyTorch versions, # just need to hunt down the previous location of `libdevice`. An assert @@ -1084,7 +1084,7 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: raise AssertionError("fp6 packing unsupported without torch >= 2.4") -if TORCH_VERSION_AT_LEAST_2_4 and has_triton(): +if TORCH_VERSION_AT_LEAST_2_8 and has_triton(): import triton import triton.language as tl @@ -1101,11 +1101,8 @@ def _triton_calculate_scale(x, axis): # Find the maximum absolute value for each row max_abs = tl.max(x, axis=axis) - # return max_abs, max_abs.to(tl.uint8) - # TODO 1: rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1908/files - # before: 1.7 TB/s - # after: ? + # TODO(future): rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1908/files scale_e8m0_unbiased = tl.floor(tl.log2(max_abs + epsilon)) - target_max_pow2 # max_abs = max_abs + epsilon # max_abs = max_abs.to(tl.bfloat16) @@ -1126,8 +1123,7 @@ def _triton_calculate_scale(x, axis): # TODO(future PR): add NaN handling here # For now, calculate the scale in floating point. - # TODO(future) audit if there is a need to bit shift exponents instead. - # TODO 2: rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1910/files + # TODO(future): rewrite as bit shifts, see https://github.com/pytorch/ao/pull/1910/files scale_fp = tl.exp2(scale_e8m0_unbiased.to(tl.float32)) return scale_fp, scale_e8m0_biased @@ -1162,6 +1158,10 @@ def to_mxfp8_dim1_kernel( COL_TILE_SIZE: tl.constexpr, # can be autotuned INNER_BLOCK_SIZE: tl.constexpr, # should be 32 for MX ): + # TODO(future): better name + BLOCKS_PER_ROW_TILE: tl.constexpr = ROW_TILE_SIZE // INNER_BLOCK_SIZE + + # TODO(future): better name RENAME_ME_TILE_SIZE: tl.constexpr = ( ROW_TILE_SIZE * COL_TILE_SIZE // INNER_BLOCK_SIZE ) @@ -1240,11 +1240,10 @@ def to_mxfp8_dim1_kernel( COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE ) - factor = ROW_TILE_SIZE // INNER_BLOCK_SIZE col_scale_start_offsets = ( (pid_col * COL_TILE_SIZE * (n_rows // ROW_TILE_SIZE)) - * factor # number of blocks seen so far - + pid_row * factor # increment ROW_TILE_SIZE + * BLOCKS_PER_ROW_TILE # number of blocks seen so far + + pid_row * BLOCKS_PER_ROW_TILE # increment ROW_TILE_SIZE ) col_scale_start_ptr = col_scale_ptr + col_scale_start_offsets @@ -1255,15 +1254,14 @@ def to_mxfp8_dim1_kernel( col_scale_indices = tl.arange( 0, COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE ) - # add offset for inner blocks - factor = ROW_TILE_SIZE // INNER_BLOCK_SIZE # needs better name jump_vals_per_col = (n_rows - ROW_TILE_SIZE) // INNER_BLOCK_SIZE + # example transformation (specifics depend on tile sizes): # [0, 1, 2, 3, 4, 5, 6, 7] -> [0, 1, 4, 5, 8, 9, 12, 13] col_scale_indices = col_scale_indices + ( - tl.floor(col_scale_indices / factor) * jump_vals_per_col + tl.floor(col_scale_indices / BLOCKS_PER_ROW_TILE) * jump_vals_per_col ).to(tl.int32) # TODO(future): mask @@ -1339,9 +1337,9 @@ def to_mxfp8_dim1_reference( else: def to_mxfp8_across_dim0_and_dim1(x, tile_size=32): - raise AssertionError("needs torch version 2.4+ and triton") + raise AssertionError("needs torch version 2.8+ and triton") def scale_dim0_dim1_reference( x_hp: torch.Tensor, block_size ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - raise AssertionError("needs 
torch version 2.4+ and triton") + raise AssertionError("needs torch version 2.8+ and triton") From fd3055853c988aa413862a4f232742fb8ba36b4e Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 08:56:59 -0700 Subject: [PATCH 06/20] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/custom_cast.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index baa2594ad2..05a4d3a4bd 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -1161,11 +1161,6 @@ def to_mxfp8_dim1_kernel( # TODO(future): better name BLOCKS_PER_ROW_TILE: tl.constexpr = ROW_TILE_SIZE // INNER_BLOCK_SIZE - # TODO(future): better name - RENAME_ME_TILE_SIZE: tl.constexpr = ( - ROW_TILE_SIZE * COL_TILE_SIZE // INNER_BLOCK_SIZE - ) - # Get program ID pid_row = tl.program_id(0) pid_col = tl.program_id(1) @@ -1205,7 +1200,9 @@ def to_mxfp8_dim1_kernel( # TODO: make this generic and nice # inner_block_size = 32 # shape: (COL_TILE_SIZE, ROW_TILE_SIZE) -> (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, INNER_BLOCK_SIZE) - x_block_t_r = x_block_t.reshape(RENAME_ME_TILE_SIZE, INNER_BLOCK_SIZE) + x_block_t_r = x_block_t.reshape( + COL_TILE_SIZE * BLOCKS_PER_ROW_TILE, INNER_BLOCK_SIZE + ) # Calculate the absolute values of elements in the block x_block_abs_t_r = tl.abs(x_block_t_r) From b0cd056d1f77692c1f6d512ab86c543904e31baf Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 09:00:08 -0700 Subject: [PATCH 07/20] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/custom_cast.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 05a4d3a4bd..6e1a27e1c0 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -1233,9 +1233,7 @@ def to_mxfp8_dim1_kernel( # reshape col_scale_e8m0_r to col_scale_e8m0 # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE,) # col_scale_e8m0 = col_scale_e8m0_r.reshape(COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE) - col_scale_e8m0 = col_scale_e8m0_r.reshape( - COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE - ) + col_scale_e8m0 = col_scale_e8m0_r.reshape(COL_TILE_SIZE * BLOCKS_PER_ROW_TILE) col_scale_start_offsets = ( (pid_col * COL_TILE_SIZE * (n_rows // ROW_TILE_SIZE)) @@ -1248,9 +1246,7 @@ def to_mxfp8_dim1_kernel( # calculate col_scale_indices, this is a bit convoluted # start with a sequential index [0, COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE] # from example: [0, 1, 2, 3, 4, 5, 6, 7] - col_scale_indices = tl.arange( - 0, COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE - ) + col_scale_indices = tl.arange(0, COL_TILE_SIZE * BLOCKS_PER_ROW_TILE) # needs better name jump_vals_per_col = (n_rows - ROW_TILE_SIZE) // INNER_BLOCK_SIZE From 26b49fd1154cb9a6cf5a60b904e87a3d6a7cf5b7 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 09:19:47 -0700 Subject: [PATCH 08/20] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/custom_cast.py | 65 +++++++++++++++------ 1 file changed, 48 insertions(+), 17 deletions(-) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 6e1a27e1c0..aeb5081c8b 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ 
-1154,11 +1154,47 @@ def to_mxfp8_dim1_kernel( col_scale_ptr, # pointer to store column-wise maximum absolute values n_rows, # number of rows in the tensor n_cols, # number of columns in the tensor - ROW_TILE_SIZE: tl.constexpr, # can be autotuned - COL_TILE_SIZE: tl.constexpr, # can be autotuned + ROW_TILE_SIZE: tl.constexpr, + COL_TILE_SIZE: tl.constexpr, INNER_BLOCK_SIZE: tl.constexpr, # should be 32 for MX ): - # TODO(future): better name + """ + Example tiling for n_rows==8, n_cols=8, ROW_TILE_SIZE=4, COL_TILE_SIZE=4, INNER_BLOCK_SIZE=2, + pid_row=0, pid_col=0: + + Input (row-major) + + cols 0 1 2 3 4 5 6 7 + -------------------------------- + rows 0 | 0 1 2 3 + 1 | 8 9 10 11 + 2 | 16 17 18 19 + 3 | 24 25 26 27 + 4 | + 5 | + 6 | + 7 | + + Output (row-major of transpose), ids are from input + + cols 0 1 2 3 4 5 6 7 + -------------------------------- + rows 0 | 0 8 16 24 + 1 | 1 9 17 25 + 2 | 2 10 18 26 + 3 | 3 11 19 27 + 4 | + 5 | + 6 | + 7 | + + Output (scales), s(0, 8) means the scale used to cast elements 0 and 8 + + rows 0 1 2 ... 31 + ---------------------------------------------------- + s(0, 8) s(16, 24) s(1, 9) ... s(19, 27) + """ + BLOCKS_PER_ROW_TILE: tl.constexpr = ROW_TILE_SIZE // INNER_BLOCK_SIZE # Get program ID @@ -1197,9 +1233,7 @@ def to_mxfp8_dim1_kernel( x_block_t = tl.trans(x_block) # Reshape to inner tile size - # TODO: make this generic and nice - # inner_block_size = 32 - # shape: (COL_TILE_SIZE, ROW_TILE_SIZE) -> (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, INNER_BLOCK_SIZE) + # shape: (COL_TILE_SIZE, ROW_TILE_SIZE) -> (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE, INNER_BLOCK_SIZE) x_block_t_r = x_block_t.reshape( COL_TILE_SIZE * BLOCKS_PER_ROW_TILE, INNER_BLOCK_SIZE ) @@ -1208,14 +1242,13 @@ def to_mxfp8_dim1_kernel( x_block_abs_t_r = tl.abs(x_block_t_r) # Find the maximum absolute value for each column - # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) + # shape: (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE,) col_scale_r, col_scale_e8m0_r = _triton_calculate_scale(x_block_abs_t_r, axis=1) # Divide each column by scale # Broadcasting col_scale to match x_block's shape - # x_block_t shape (COL_TILE_SIZE, ROW_TILE_SIZE) - # x_block_t_r shape (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, INNER_BLOCK_SIZE) - # col_scale shape (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE, 1) + # x_block_t_r shape (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE, INNER_BLOCK_SIZE) + # col_scale shape (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE,) -> (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE, 1) col_normalized_t_r = x_block_t_r / col_scale_r[:, None] # Reshape back to original tile size @@ -1231,24 +1264,22 @@ def to_mxfp8_dim1_kernel( tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask) # reshape col_scale_e8m0_r to col_scale_e8m0 - # shape: (COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE,) -> (COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE,) - # col_scale_e8m0 = col_scale_e8m0_r.reshape(COL_TILE_SIZE, ROW_TILE_SIZE // INNER_BLOCK_SIZE) + # shape: (COL_TILE_SIZE * BLOCKS_PER_ROW_TILE,) -> (COL_TILE_SIZE, BLOCKS_PER_ROW_TILE,) col_scale_e8m0 = col_scale_e8m0_r.reshape(COL_TILE_SIZE * BLOCKS_PER_ROW_TILE) col_scale_start_offsets = ( (pid_col * COL_TILE_SIZE * (n_rows // ROW_TILE_SIZE)) * BLOCKS_PER_ROW_TILE # number of blocks seen so far - + pid_row * BLOCKS_PER_ROW_TILE # increment ROW_TILE_SIZE + + pid_row * BLOCKS_PER_ROW_TILE # increment BLOCKS_PER_ROW_TILE ) col_scale_start_ptr = col_scale_ptr + 
col_scale_start_offsets - # calculate col_scale_indices, this is a bit convoluted - # start with a sequential index [0, COL_TILE_SIZE * ROW_TILE_SIZE // INNER_BLOCK_SIZE] - # from example: [0, 1, 2, 3, 4, 5, 6, 7] + # calculate col_scale_indices col_scale_indices = tl.arange(0, COL_TILE_SIZE * BLOCKS_PER_ROW_TILE) - # needs better name + # How many values are in all the other columns for this row_pid, need to jump + # over them for every BLOCKS_PER_ROW_TILE values jump_vals_per_col = (n_rows - ROW_TILE_SIZE) // INNER_BLOCK_SIZE # example transformation (specifics depend on tile sizes): From ba10a02cbc386d1e236a03149c87d1931ac64a68 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 09:22:42 -0700 Subject: [PATCH 09/20] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/custom_cast.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index aeb5081c8b..eb1e079fb7 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -1190,9 +1190,9 @@ def to_mxfp8_dim1_kernel( Output (scales), s(0, 8) means the scale used to cast elements 0 and 8 - rows 0 1 2 ... 31 - ---------------------------------------------------- - s(0, 8) s(16, 24) s(1, 9) ... s(19, 27) + rows 0 1 ... 4 ... 31 + ------------------------------------------------------ + s(0, 8) s(16, 24) ... s(1, 9) ... s(19, 27) """ BLOCKS_PER_ROW_TILE: tl.constexpr = ROW_TILE_SIZE // INNER_BLOCK_SIZE From 483cdfd2b6cf02bacf17af20eb9bfd66878b245d Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 09:30:35 -0700 Subject: [PATCH 10/20] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/custom_cast.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index eb1e079fb7..4e97b09711 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -1129,6 +1129,9 @@ def _triton_calculate_scale(x, axis): return scale_fp, scale_e8m0_biased def _get_mxfp8_dim1_kernel_autotune_configs(): + # Values to sweep over here were determined by a manual + # sweep over a small set of shapes, it's likely that this + # can be improved in the future. results = [] for ROW_TILE_SIZE in (64, 128): for COL_TILE_SIZE in (64, 128): @@ -1261,6 +1264,7 @@ def to_mxfp8_dim1_kernel( col_normalized = col_normalized.to(tl.float8e4nv) # Store the column-normalized result in column-major format + # TODO(future): this mask is for row-major likely need to transpose it for col-major tl.store(output_col_major_ptr + col_major_offsets, col_normalized, mask=mask) # reshape col_scale_e8m0_r to col_scale_e8m0 @@ -1288,7 +1292,7 @@ def to_mxfp8_dim1_kernel( tl.floor(col_scale_indices / BLOCKS_PER_ROW_TILE) * jump_vals_per_col ).to(tl.int32) - # TODO(future): mask + # TODO(future): mask this store tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0) def to_mxfp8_dim1(x, inner_block_size=32): @@ -1308,6 +1312,15 @@ def to_mxfp8_dim1(x, inner_block_size=32): # Get tensor shape n_rows, n_cols = x.shape + # Masking of loads and stores is not well tested yet, so for now enforce + # shapes which do not need masking. Note that this condition depends on max values of + # ROW_TILE_SIZE and COL_TILE_SIZE, which are autotuned above. 
+ # TODO(future): implement and test masking and remove this restriction + max_row_tile_size = 128 + max_col_tile_size = 128 + assert n_rows % max_row_tile_size == 0, "unsupported" + assert n_cols % max_col_tile_size == 0, "unsupported" + # Create output tensors output_col_major = torch.empty( (n_cols, n_rows), dtype=torch.float8_e4m3fn, device=x.device From 32005c911ea52c68486107cdada01ee003cdc483 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 21 Mar 2025 09:44:57 -0700 Subject: [PATCH 11/20] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/custom_cast.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 4e97b09711..0b16d38008 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -1295,7 +1295,7 @@ def to_mxfp8_dim1_kernel( # TODO(future): mask this store tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0) - def to_mxfp8_dim1(x, inner_block_size=32): + def to_mxfp8_dim1(x, inner_block_size=32) -> Tuple[torch.Tensor, torch.Tensor]: """ Input: * `x` - input tensor, in row major memory layout @@ -1373,10 +1373,10 @@ def to_mxfp8_dim1_reference( else: - def to_mxfp8_across_dim0_and_dim1(x, tile_size=32): + def to_mxfp8_dim1(x, inner_block_size=32) -> Tuple[torch.Tensor, torch.Tensor]: raise AssertionError("needs torch version 2.8+ and triton") - def scale_dim0_dim1_reference( + def to_mxfp8_dim1_reference( x_hp: torch.Tensor, block_size - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: raise AssertionError("needs torch version 2.8+ and triton") From 7ecd79f7300a15447b9684f6cb07a80a48498c4a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 24 Mar 2025 07:46:26 -0700 Subject: [PATCH 12/20] Update [ghstack-poisoned] --- torchao/prototype/mx_formats/custom_cast.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index c74dfa630a..5690b0f5f0 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -1090,15 +1090,16 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: @triton.jit def _triton_calculate_scale(x, axis): - # We use a small epsilon to avoid division by zero - epsilon = 1e-10 - - # TODO(before land): reuse the constants below instead of hardcoding + # There is no good support for accessing globals from a jit'ed triton + # function, so we redefine them here. Since this is prototype code which + # we plan to remove after torch.compile catches up, this is fine. 
target_max_pow2 = 8 e8m0_exponent_bias = 127 bf16_mbits = 7 bf16_exp_bias = 127 fp32_mbits = 23 + # We use a small epsilon to avoid division by zero + epsilon = 1e-10 # Find the maximum absolute value for each row max_abs = tl.max(x, axis=axis) From ca3c4cfa8a3e39f3cac2dddba3e54db75a6a9081 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 24 Mar 2025 11:00:40 -0700 Subject: [PATCH 13/20] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 49 +++++++++-- torchao/prototype/mx_formats/config.py | 4 + torchao/prototype/mx_formats/custom_cast.py | 6 +- torchao/prototype/mx_formats/mx_linear.py | 96 ++++++++++++++++----- torchao/prototype/mx_formats/mx_tensor.py | 4 + 5 files changed, 127 insertions(+), 32 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 326e68beb1..1776dfb52a 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -51,19 +51,31 @@ def run_around_tests(): "elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3) ) @pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)]) -def test_linear_eager(elem_dtype, bias, input_shape): +@pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)]) +@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True]) +def test_linear_eager_vs_hp( + elem_dtype, bias, input_shape, use_fp8_dim1_cast_triton_kernel +): """ Smoke test for training linear module with mx weight, compares the following: * baseline: float32 * experiment: emulated MX """ + if use_fp8_dim1_cast_triton_kernel and elem_dtype != ( + torch.float8_e4m3fn, + torch.float8_e4m3fn, + torch.float8_e4m3fn, + ): + pytest.skip("unsupported configuration") + if use_fp8_dim1_cast_triton_kernel and (not is_sm_at_least_89()): + pytest.skip("CUDA capability >= 8.9 required for float8 in triton") + # elem_dtype is a tuple of (input, weight, gradient) dtypes. 
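Editor's note: the `use_fp8_dim1_cast_triton_kernel` flag exercised by this test is surfaced through `MXLinearConfig`. Below is a minimal usage sketch, assuming an sm89+ GPU and the module-swap API available at this point in the stack; dimensions are chosen as multiples of 128 to satisfy the Triton kernel's current no-masking shape asserts.

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.config import MXLinearConfig
from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear

# toy model; dims are multiples of 128 to satisfy the kernel's shape asserts
m = nn.Sequential(nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16))

config = MXLinearConfig(
    block_size=32,
    elem_dtype=torch.float8_e4m3fn,
    # route the dim1 (columnwise) casts through the handwritten Triton kernel
    use_fp8_dim1_cast_triton_kernel=True,
)
swap_linear_with_mx_linear(m, config=config)

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m(x).sum().backward()  # the Triton dim1 cast runs in the backward pass
```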
grad_shape = list(input_shape) - grad_shape[-1] = 8 + grad_shape[-1] = 256 m = nn.Sequential( - nn.Linear(8, 8, bias=bias, device="cuda"), + nn.Linear(256, 256, bias=bias, device="cuda"), ) m_mx = copy.deepcopy(m) config = MXLinearConfig( @@ -71,6 +83,7 @@ def test_linear_eager(elem_dtype, bias, input_shape): elem_dtype=elem_dtype[0], elem_dtype_weight_override=elem_dtype[1], elem_dtype_grad_output_override=elem_dtype[2], + use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel, ) swap_linear_with_mx_linear(m_mx, config=config) @@ -169,6 +182,7 @@ def test_activation_checkpointing(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize( "recipe_name", [ @@ -182,7 +196,8 @@ def test_activation_checkpointing(): @pytest.mark.parametrize("bias", [False, True]) # TODO(future PR): figure out why torch.compile does not match eager when # autocast is on -def test_linear_compile(recipe_name, bias): +@pytest.mark.parametrize("use_fp8_dim1_cast_triton_kernel", [False, True]) +def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_kernel): """ Verify that compile does not change numerics of MX linear fw + bw """ @@ -198,20 +213,36 @@ def test_linear_compile(recipe_name, bias): # TODO(future PR): fix this, things are clearly broken with bias=True pytest.skip("this test is broken for non-emulated recipes with bias=True") + if use_fp8_dim1_cast_triton_kernel: + if recipe_name not in ("mxfp8_emulated", "mxfp8_cublas", "mxfp8_cutlass"): + pytest.skip("unsupported configuration") + if not is_sm_at_least_89(): + pytest.skip("CUDA capability >= 8.9 required for float8 in triton") + if hp_dtype != torch.bfloat16: + pytest.skip("unsupported configuration") + + if hp_dtype == torch.bfloat16 and recipe_name != "mxfp8_cublas": + # TODO(future PR): properly enable float32 + bfloat16 for every + # recipe, this needs a cleanup of out_dtype (needs to match in-hp-dtype, even + # if the underlying gemm kernel only supports bf16 output) + pytest.skip("unsupported configuration") + M, K, N = 128, 256, 512 input_shape = (M, K) grad_shape = (M, N) m_mx = nn.Sequential( - nn.Linear(K, N, bias=bias, device="cuda"), + nn.Linear(K, N, bias=bias, device="cuda", dtype=hp_dtype), ) config = MXLinearConfig.from_recipe_name(recipe_name) + config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel + swap_linear_with_mx_linear(m_mx, config=config) m_mx_c = copy.deepcopy(m_mx) m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor") - x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() + x_ref = torch.randn(*input_shape, device="cuda", dtype=hp_dtype).requires_grad_() x = copy.deepcopy(x_ref) - g = torch.randn(*grad_shape, device="cuda") + g = torch.randn(*grad_shape, device="cuda", dtype=hp_dtype) y_ref = m_mx(x_ref) y = m_mx_c(x) @@ -283,7 +314,7 @@ def test_inference_compile_simple(elem_dtype): if elem_dtype is torch.float8_e4m3fn: assert sqnr >= 20.0 else: - assert sqnr >= 13.5 + assert sqnr >= 11.5 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 6eeb05889a..1fd1e2470c 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -58,6 +58,10 @@ class MXLinearConfig: # on the given hardware an exception will be thrown gemm_kernel_choice: MXGemmKernelChoice = 
MXGemmKernelChoice.EMULATED + # If True, uses a custom triton kernel for cast to mxfp8 across dim1 + # TODO(before land): link issue number + use_fp8_dim1_cast_triton_kernel: bool = False + # If True, uses a custom triton kernel for fp4 dequantize use_fp4_custom_triton_dequant_kernel: bool = False diff --git a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index 5690b0f5f0..a78d980c37 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -1087,6 +1087,7 @@ def pack_uint6(uint8_data: torch.Tensor) -> torch.Tensor: if TORCH_VERSION_AT_LEAST_2_8 and has_triton(): import triton import triton.language as tl + from torch.library import triton_op, wrap_triton @triton.jit def _triton_calculate_scale(x, axis): @@ -1298,8 +1299,9 @@ def to_mxfp8_dim1_kernel( # TODO(future): mask this store tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0) + @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={}) def triton_to_mxfp8_dim1( - x, inner_block_size=32 + x: torch.Tensor, inner_block_size: int = 32 ) -> Tuple[torch.Tensor, torch.Tensor]: """ Input: @@ -1343,7 +1345,7 @@ def triton_to_mxfp8_dim1( ) # Launch the kernel - to_mxfp8_dim1_kernel[grid]( + wrap_triton(to_mxfp8_dim1_kernel)[grid]( x_ptr=x, output_col_major_ptr=output_col_major, col_scale_ptr=col_scale, diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index a17c171f84..4879f67681 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig +from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1 from torchao.prototype.mx_formats.mx_tensor import MXTensor @@ -37,6 +38,7 @@ def forward( grad_elem_dtype: Any, block_size: int, gemm_kernel_choice: MXGemmKernelChoice, + use_fp8_dim1_cast_triton_kernel: bool, ): ctx.save_for_backward(input_hp, weight_hp) ctx.in_elem_dtype = in_elem_dtype @@ -44,6 +46,7 @@ def forward( ctx.grad_elem_dtype = grad_elem_dtype ctx.block_size = block_size ctx.gemm_kernel_choice = gemm_kernel_choice + ctx.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel # input @ weight_t = output input_orig_shape = input_hp.shape @@ -63,12 +66,12 @@ def forward( @staticmethod def backward(ctx, grad_output_hp: torch.Tensor): input_hp, weight_hp = ctx.saved_tensors - weight_hp_t_c = weight_hp.t().contiguous() in_elem_dtype = ctx.in_elem_dtype w_elem_dtype = ctx.w_elem_dtype grad_elem_dtype = ctx.grad_elem_dtype block_size = ctx.block_size gemm_kernel_choice = ctx.gemm_kernel_choice + use_fp8_dim1_cast_triton_kernel = ctx.use_fp8_dim1_cast_triton_kernel grad_output_orig_shape = grad_output_hp.shape grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1]) @@ -83,34 +86,84 @@ def backward(ctx, grad_output_hp: torch.Tensor): block_size, gemm_kernel_choice=gemm_kernel_choice, ) - weight_mx_dim1 = MXTensor.to_mx( - weight_hp_t_c, - w_elem_dtype, - block_size, - gemm_kernel_choice=gemm_kernel_choice, - ) + + if use_fp8_dim1_cast_triton_kernel: + weight_mx_dim1_data, weight_mx_dim1_scale = triton_to_mxfp8_dim1( + weight_hp, block_size + ) + weight_mx_dim1 = MXTensor( + weight_mx_dim1_scale.view(torch.uint8).reshape(-1), + weight_mx_dim1_data.t(), + w_elem_dtype, + block_size, + weight_hp.dtype, + False, + gemm_kernel_choice, + False, + ) + + else: + weight_hp_t_c = 
weight_hp.t().contiguous() + weight_mx_dim1 = MXTensor.to_mx( + weight_hp_t_c, + w_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, + ) grad_input = torch.mm(grad_output_mx_dim0, weight_mx_dim1.t()) grad_input = grad_input.reshape( *grad_output_orig_shape[:-1], grad_input.shape[-1] ) # input_t @ grad_output = grad_weight - grad_output_mx_dim1 = MXTensor.to_mx( - grad_output_hp_r.t().contiguous(), - grad_elem_dtype, - block_size, - gemm_kernel_choice=gemm_kernel_choice, - ) - input_t_mx_dim0_tmp = MXTensor.to_mx( - input_hp_r.t().contiguous(), - in_elem_dtype, - block_size, - gemm_kernel_choice=gemm_kernel_choice, - ) - input_t_mx_dim0 = input_t_mx_dim0_tmp.t() + if use_fp8_dim1_cast_triton_kernel: + grad_output_mx_dim1_data, grad_output_mx_dim1_scale = triton_to_mxfp8_dim1( + grad_output_hp_r, block_size + ) + grad_output_mx_dim1 = MXTensor( + grad_output_mx_dim1_scale.view(torch.uint8).reshape(-1), + grad_output_mx_dim1_data.t(), + grad_elem_dtype, + block_size, + grad_output_hp_r.dtype, + False, + gemm_kernel_choice, + False, + ) + else: + grad_output_mx_dim1 = MXTensor.to_mx( + grad_output_hp_r.t().contiguous(), + grad_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, + ) + + if use_fp8_dim1_cast_triton_kernel: + input_t_mx_dim0_tmp_data, input_t_mx_dim0_tmp_scale = triton_to_mxfp8_dim1( + input_hp_r, block_size + ) + input_t_mx_dim0_tmp = MXTensor( + input_t_mx_dim0_tmp_scale.view(torch.uint8).reshape(-1), + input_t_mx_dim0_tmp_data.t(), + in_elem_dtype, + block_size, + input_hp_r.dtype, + False, + gemm_kernel_choice, + False, + ) + input_t_mx_dim0 = input_t_mx_dim0_tmp.t() + else: + input_t_mx_dim0_tmp = MXTensor.to_mx( + input_hp_r.t().contiguous(), + in_elem_dtype, + block_size, + gemm_kernel_choice=gemm_kernel_choice, + ) + input_t_mx_dim0 = input_t_mx_dim0_tmp.t() grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0) - return grad_input, grad_weight, None, None, None, None, None + return grad_input, grad_weight, None, None, None, None, None, None class MXLinear(torch.nn.Linear): @@ -154,6 +207,7 @@ def forward(self, x): config.elem_dtype_grad_output_override or config.elem_dtype, config.block_size, config.gemm_kernel_choice, + config.use_fp8_dim1_cast_triton_kernel, ) if self.bias is not None: y = y + self.bias diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 5d908d8d50..bfb0032afa 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -247,6 +247,7 @@ def to_mx( # need to reshape at the end to help inductor fuse things data_lp = data_lp.reshape(orig_shape) elif elem_dtype == DTYPE_FP6_E2M3: + # assert data_hp.dtype == torch.float32, f"dtype {data_hp.dtype} is not supported in this codepath yet" data_lp = f32_to_f6_e2m3_unpacked(data_lp) if pack_fp6: orig_shape = [*orig_shape[:-1], 3 * orig_shape[-1] // 4] @@ -254,6 +255,7 @@ def to_mx( # need to reshape at the end to help inductor fuse things data_lp = data_lp.reshape(orig_shape) elif elem_dtype == DTYPE_FP6_E3M2: + # assert data_hp.dtype == torch.float32, f"dtype {data_hp.dtype} is not supported in this codepath yet" data_lp = f32_to_f6_e3m2_unpacked(data_lp) if pack_fp6: orig_shape = [*orig_shape[:-1], 3 * orig_shape[-1] // 4] @@ -261,6 +263,8 @@ def to_mx( # need to reshape at the end to help inductor fuse things data_lp = data_lp.reshape(orig_shape) elif elem_dtype == DTYPE_FP4: + # TODO(future PR): add bfloat16 support here + # assert data_hp.dtype == torch.float32, f"dtype 
{data_hp.dtype} is not supported in this codepath yet" # can't reshape at the end without handling it in the packing code, # punt until later since we'll need to rethink the torch.compile # approach for fp4x2 in any case From 0de11cfe10bb10ad5376c39682c1da14cabe59ec Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 24 Mar 2025 13:10:01 -0700 Subject: [PATCH 14/20] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 57 +++++++++++++-------- torchao/prototype/mx_formats/config.py | 3 +- torchao/prototype/mx_formats/mx_tensor.py | 4 -- 3 files changed, 39 insertions(+), 25 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 1776dfb52a..425767a0ee 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import copy -import itertools import pytest import torch @@ -16,7 +15,12 @@ MXLinearConfig, MXLinearRecipeName, ) -from torchao.prototype.mx_formats.constants import DTYPE_FP4, SUPPORTED_ELEM_DTYPES +from torchao.prototype.mx_formats.constants import ( + DTYPE_FP4, + DTYPE_FP6_E2M3, + DTYPE_FP6_E3M2, + SUPPORTED_ELEM_DTYPES, +) from torchao.prototype.mx_formats.mx_linear import ( MXInferenceLinear, MXLinear, @@ -48,7 +52,16 @@ def run_around_tests(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize( - "elem_dtype", itertools.product(SUPPORTED_ELEM_DTYPES, repeat=3) + "elem_dtype", + ( + # test each dtype + (torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float8_e4m3fn), + (DTYPE_FP6_E3M2, DTYPE_FP6_E3M2, DTYPE_FP6_E3M2), + (DTYPE_FP6_E2M3, DTYPE_FP6_E2M3, DTYPE_FP6_E2M3), + (DTYPE_FP4, DTYPE_FP4, DTYPE_FP4), + # only test one type of mixed-dtype overrides, to save testing time + (torch.float8_e4m3fn, DTYPE_FP4, DTYPE_FP4), + ), ) @pytest.mark.parametrize("bias", [True, False]) @pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)]) @@ -61,21 +74,22 @@ def test_linear_eager_vs_hp( * baseline: float32 * experiment: emulated MX """ - if use_fp8_dim1_cast_triton_kernel and elem_dtype != ( - torch.float8_e4m3fn, - torch.float8_e4m3fn, - torch.float8_e4m3fn, - ): - pytest.skip("unsupported configuration") - if use_fp8_dim1_cast_triton_kernel and (not is_sm_at_least_89()): - pytest.skip("CUDA capability >= 8.9 required for float8 in triton") + if use_fp8_dim1_cast_triton_kernel: + if elem_dtype != ( + torch.float8_e4m3fn, + torch.float8_e4m3fn, + torch.float8_e4m3fn, + ): + pytest.skip("unsupported configuration") + elif not is_sm_at_least_89(): + pytest.skip("CUDA capability >= 8.9 required for float8 in triton") # elem_dtype is a tuple of (input, weight, gradient) dtypes. 
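Editor's note: for contrast with the Triton path wired up above, the non-Triton fallback casts along dim1 by transposing to a contiguous row-major tensor and reusing the rowwise cast, as in the `weight_hp.t().contiguous()` branch earlier in this patch. A minimal sketch of that pattern (illustrative only):

```python
import torch
from torchao.prototype.mx_formats.mx_tensor import MXTensor

def mx_cast_dim1_via_transpose(w_hp: torch.Tensor, elem_dtype, block_size: int):
    # Make the original columns contiguous rows, so the rowwise (dim0) cast
    # effectively produces per-column scales of the original tensor.
    w_hp_t_c = w_hp.t().contiguous()
    return MXTensor.to_mx(w_hp_t_c, elem_dtype, block_size)

# usage sketch
w = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda")
w_mx_dim1 = mx_cast_dim1_via_transpose(w, torch.float8_e4m3fn, block_size=32)
```

The Triton kernel added in this stack is intended as a faster replacement for this transpose-and-cast fallback until torch.compile produces an equally fast kernel.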
grad_shape = list(input_shape) grad_shape[-1] = 256 m = nn.Sequential( - nn.Linear(256, 256, bias=bias, device="cuda"), + nn.Linear(256, 256, bias=bias, device="cuda", dtype=torch.bfloat16), ) m_mx = copy.deepcopy(m) config = MXLinearConfig( @@ -87,12 +101,16 @@ def test_linear_eager_vs_hp( ) swap_linear_with_mx_linear(m_mx, config=config) - x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() + x_ref = torch.randn( + *input_shape, device="cuda", dtype=torch.bfloat16 + ).requires_grad_() x = copy.deepcopy(x_ref) g = torch.randn(*grad_shape, device="cuda") - with torch.autocast("cuda", dtype=torch.bfloat16): - y_ref = m(x_ref) - y_mx = m_mx(x) + + y_ref = m(x_ref) + y_mx = m_mx(x) + + assert y_mx.dtype == x.dtype y_ref.backward(g) y_mx.backward(g) @@ -125,7 +143,6 @@ def test_linear_eager_vs_hp( ) @pytest.mark.parametrize("mkn", [(128, 256, 512), (256, 512, 128), (512, 128, 256)]) def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn): - M, K, N = 128, 128, 128 M, K, N = mkn x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda").requires_grad_() @@ -156,9 +173,9 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn): y_sqnr = compute_error(y_real, y_emulated) w_sqnr = compute_error(m_real[0].weight.grad, m_emulated[0].weight.grad) g_sqnr = compute_error(x_copy.grad, x.grad) - assert y_sqnr > 100.0, f"y_sqnr {y_sqnr} too low!" - assert w_sqnr > 100.0, f"w_sqnr {w_sqnr} too low!" - assert g_sqnr > 100.0, f"g_sqnr {g_sqnr} too low!" + assert y_sqnr > 90.0, f"y_sqnr {y_sqnr} too low!" + assert w_sqnr > 90.0, f"w_sqnr {w_sqnr} too low!" + assert g_sqnr > 90.0, f"g_sqnr {g_sqnr} too low!" # TODO(future): enable compile support diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 1fd1e2470c..5951be12e6 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -59,7 +59,8 @@ class MXLinearConfig: gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.EMULATED # If True, uses a custom triton kernel for cast to mxfp8 across dim1 - # TODO(before land): link issue number + # TODO(1945): remove this config option once torch.compile gives us + # a fast kernel use_fp8_dim1_cast_triton_kernel: bool = False # If True, uses a custom triton kernel for fp4 dequantize diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index bfb0032afa..5d908d8d50 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -247,7 +247,6 @@ def to_mx( # need to reshape at the end to help inductor fuse things data_lp = data_lp.reshape(orig_shape) elif elem_dtype == DTYPE_FP6_E2M3: - # assert data_hp.dtype == torch.float32, f"dtype {data_hp.dtype} is not supported in this codepath yet" data_lp = f32_to_f6_e2m3_unpacked(data_lp) if pack_fp6: orig_shape = [*orig_shape[:-1], 3 * orig_shape[-1] // 4] @@ -255,7 +254,6 @@ def to_mx( # need to reshape at the end to help inductor fuse things data_lp = data_lp.reshape(orig_shape) elif elem_dtype == DTYPE_FP6_E3M2: - # assert data_hp.dtype == torch.float32, f"dtype {data_hp.dtype} is not supported in this codepath yet" data_lp = f32_to_f6_e3m2_unpacked(data_lp) if pack_fp6: orig_shape = [*orig_shape[:-1], 3 * orig_shape[-1] // 4] @@ -263,8 +261,6 @@ def to_mx( # need to reshape at the end to help inductor fuse things data_lp = data_lp.reshape(orig_shape) elif elem_dtype == DTYPE_FP4: - # TODO(future PR): add bfloat16 support here - # assert data_hp.dtype == torch.float32, 
f"dtype {data_hp.dtype} is not supported in this codepath yet" # can't reshape at the end without handling it in the packing code, # punt until later since we'll need to rethink the torch.compile # approach for fp4x2 in any case From 912e4dc59017df988631595c5e772943195338b7 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 24 Mar 2025 15:16:03 -0700 Subject: [PATCH 15/20] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 18 ++++++++++++++++++ torchao/prototype/mx_formats/config.py | 20 ++++++++++++++++++++ torchao/prototype/mx_formats/constants.py | 8 ++++++++ torchao/prototype/mx_formats/mx_linear.py | 8 ++++++++ 4 files changed, 54 insertions(+) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 425767a0ee..6de1d39389 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -401,3 +401,21 @@ def test_filter_fn(): swap_linear_with_mx_inference_linear(m2, config=config, filter_fn=filter_fn) # noqa: E501 assert type(m2[0]) == MXInferenceLinear assert type(m2[1]) == torch.nn.Linear + + +def test_training_print_str(): + m = nn.Sequential(nn.Linear(32, 32)) + config = MXLinearConfig() + swap_linear_with_mx_linear(m, config=config) + s = str(m) + assert "bl_sz=32" in s + assert "kernel=emulated" in s + + +def test_inference_print_str(): + m = nn.Sequential(nn.Linear(32, 32)) + config = MXLinearConfig() + swap_linear_with_mx_inference_linear(m, config=config) + s = str(m) + assert "bl_sz=32" in s + assert "kernel=emulated" in s diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 5951be12e6..438a6a9293 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -12,6 +12,7 @@ from torchao.prototype.mx_formats.constants import ( DTYPE_FP4, + DTYPE_TO_SHORT_STR, SUPPORTED_ELEM_DTYPES, ) @@ -143,3 +144,22 @@ def from_recipe_name( ) else: raise AssertionError(f"unknown recipe_name {recipe_name}") + + def short_str(self) -> str: + """ + Returns a concise representation of the current config. 
+ """ + s = f"bl_sz={self.block_size}, lp_dtype={DTYPE_TO_SHORT_STR[self.elem_dtype]}" + if self.elem_dtype_weight_override is not None: + s += ( + f", lp_w_override={DTYPE_TO_SHORT_STR[self.elem_dtype_weight_override]}" + ) + if self.elem_dtype_grad_output_override is not None: + s += f", lp_go_override={DTYPE_TO_SHORT_STR[self.elem_dtype_grad_output_override]}" + s += f", kernel={self.gemm_kernel_choice.value}" + if self.use_fp8_dim1_cast_triton_kernel: + s += ", use_fp8_dim1_cast_triton_kernel=True" + if self.use_fp4_custom_triton_dequant_kernel: + s += ", use_fp4_custom_triton_dequant_kernel=True" + # TODO(future PR): split training from inference and add fp6 here + return s diff --git a/torchao/prototype/mx_formats/constants.py b/torchao/prototype/mx_formats/constants.py index d2f536e78a..94c63b11e5 100644 --- a/torchao/prototype/mx_formats/constants.py +++ b/torchao/prototype/mx_formats/constants.py @@ -22,6 +22,14 @@ DTYPE_FP4, ] +DTYPE_TO_SHORT_STR = { + torch.float8_e4m3fn: "f8e4m3", + torch.float8_e5m2: "f8e5m2", + DTYPE_FP6_E2M3: "f6e2m3", + DTYPE_FP6_E3M2: "f6e3m2", + DTYPE_FP4: "f4e2m1", +} + F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 F8E5M2_MAX = torch.finfo(torch.float8_e5m2).max # 57344.0 diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 4879f67681..378cc4909a 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -213,6 +213,10 @@ def forward(self, x): y = y + self.bias return y + def extra_repr(self): + s = f"{super().extra_repr()}, {self.config.short_str()}" + return s + class MXInferenceLinear(torch.nn.Linear): """ @@ -255,6 +259,10 @@ def forward(self, x): y = F.linear(x, w_hp, self.bias) return y + def extra_repr(self): + s = f"{super().extra_repr()}, {self.config.short_str()}" + return s + def replace_with_custom_fn_if_matches_filter( model, replacement_fn, filter_fn, cur_fqn="" From fb5662a34b75504bde3035c5269e8fa19de84019 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 25 Mar 2025 13:29:23 -0700 Subject: [PATCH 16/20] Update [ghstack-poisoned] --- torchao/testing/float8/roofline_utils.py | 29 ++++++++---------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/torchao/testing/float8/roofline_utils.py b/torchao/testing/float8/roofline_utils.py index 6578f1721f..92becb9b94 100644 --- a/torchao/testing/float8/roofline_utils.py +++ b/torchao/testing/float8/roofline_utils.py @@ -183,27 +183,16 @@ def get_tensor_memory_traffic_ovhd_s( "mxfp8_cutlass", "mxfp8_cublas", ), "unsupported" - - if tensor_role == "weight": - # x_bf16 = ... - # kernel 1: x_bf16 -> x_mxfp8_dim0 - # kernel 2: x_bf16 -> x_mxfp8_dim1 - if fuse_with_prev: - kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel - else: - kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel - kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel - res_bytes = [kernel_1_rw, kernel_2_rw] + # For now, assume that we can't profitably fuse kernel 1 and kernel 2 + # x_bf16 = ... + # kernel 1: x_bf16 -> x_mxfp8_dim0 + # kernel 2: x_bf16 -> x_mxfp8_dim1 + if fuse_with_prev: + kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel else: - # x_bf16 = ... 
- # kernel 1: x_bf16 -> x_mxfp8_dim0, x_mxfp8_dim1 - if fuse_with_prev: - kernel_1_rw = 0 + BYTES_PER_EL_FLOAT8 * numel * 2 - else: - kernel_1_rw = ( - BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel * 2 - ) - res_bytes = [kernel_1_rw] + kernel_1_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + kernel_2_rw = BYTES_PER_EL_BF16 * numel + BYTES_PER_EL_FLOAT8 * numel + res_bytes = [kernel_1_rw, kernel_2_rw] # convert from bytes to seconds res_s = [ From 4c2ad8c515baf8a2c009a4306598a821efbbdbf9 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 26 Mar 2025 13:55:48 -0700 Subject: [PATCH 17/20] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 52 --------------------- torchao/prototype/mx_formats/mx_ops.py | 36 ++------------ 2 files changed, 3 insertions(+), 85 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 6de1d39389..1dce269c57 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -10,7 +10,6 @@ import torch import torch.nn as nn -from torchao.float8.float8_utils import is_row_major from torchao.prototype.mx_formats.config import ( MXLinearConfig, MXLinearRecipeName, @@ -334,57 +333,6 @@ def test_inference_compile_simple(elem_dtype): assert sqnr >= 11.5 -@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") -@pytest.mark.skipif( - not is_sm_at_least_100(), - reason="MX gemms require CUDA capability 10.0", -) -def test_scaled_mm_wrapper(): - # today, e8m0 isn't supported in torchinductor or triton - # for now, work around this by creating a wrapper around torch._scaled_mm - # which takes uint8 scales, and reinterprets them as e8m0 inside the wrapper - from torchao.prototype.mx_formats.mx_ops import _scaled_mm_with_uint8_scales - - M, K, N = 128, 256, 512 - BLOCK_SIZE = 32 - a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn) - b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn) - - a_scale = torch.ones(M, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu) - b_scale = torch.ones(N, K // BLOCK_SIZE, device="cuda", dtype=torch.float8_e8m0fnu) - - out = torch._scaled_mm(a, b.t(), a_scale, b_scale, out_dtype=torch.bfloat16) - - def wrapped(a, b, a_scale, b_scale, out_dtype): - if is_row_major(b.stride()): - b = b.t().contiguous().t() - res = _scaled_mm_with_uint8_scales(a, b, a_scale, b_scale, out_dtype=out_dtype) - return res - - wrapped = torch.compile(wrapped) - - # correct memory format of `b` - out2 = wrapped( - a, - b.t(), - a_scale.view(torch.uint8), - b_scale.view(torch.uint8), - out_dtype=torch.bfloat16, - ) - torch.testing.assert_close(out, out2, atol=0, rtol=0) - - # incorrect memory format of `b` - b_col_major = b.t().contiguous().t() - out3 = wrapped( - a, - b_col_major.t(), - a_scale.view(torch.uint8), - b_scale.view(torch.uint8), - out_dtype=torch.bfloat16, - ) - torch.testing.assert_close(out, out3, atol=0, rtol=0) - - def test_filter_fn(): m1 = nn.Sequential( nn.Linear(32, 32), diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index c2fb37dacb..c5d60a33de 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -35,41 +35,11 @@ tensor_size_hpx3_to_fp6x4, ) from torchao.prototype.mx_formats.utils import to_blocked -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 aten = torch.ops.aten MX_OPS_TABLE: Dict[Any, Any] = {} -if TORCH_VERSION_AT_LEAST_2_5: - - 
@torch.library.custom_op("mylib::_scaled_mm_with_uint8_scales", mutates_args=()) - def _scaled_mm_with_uint8_scales( - a: torch.Tensor, - b: torch.Tensor, - a_scale: torch.Tensor, - b_scale: torch.Tensor, - out_dtype: torch.dtype, - ) -> torch.Tensor: - """ - Until https://github.com/pytorch/pytorch/issues/147873 is done, we need to - work around the lack of support for `torch.float8_e8m0fnu` in - torchinductor. We do so by hiding the cast of scales to e8m0 inside a - custom op. - """ - # cast back to e8m0 where torchinductor can't see it - a_scale = a_scale.view(torch.float8_e8m0fnu) - b_scale = b_scale.view(torch.float8_e8m0fnu) - res = torch._scaled_mm(a, b, a_scale, b_scale, out_dtype=out_dtype) - return res - - @_scaled_mm_with_uint8_scales.register_fake - def _(a, b, a_scale, b_scale, out_dtype): - m, k = a.shape - k2, n = b.shape - res = torch.empty(m, n, dtype=out_dtype, device=a.device) - return res - def implements(aten_ops): """Register aten ops to the mx op table""" @@ -119,11 +89,11 @@ def mx_mm(aten_op, args, kwargs=None): if a._elem_dtype == torch.float8_e4m3fn: assert b._elem_dtype == torch.float8_e4m3fn if a._gemm_kernel_choice is MXGemmKernelChoice.CUBLAS: - res = _scaled_mm_with_uint8_scales( + res = torch._scaled_mm( a._data, b._data, - a_scale_block, - b_scale_block, + a_scale_block.view(torch.float8_e8m0fnu), + b_scale_block.view(torch.float8_e8m0fnu), out_dtype=torch.bfloat16, ) else: From c1ceef19fa782c9bd3ed9066eeb2031f381ad874 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 26 Mar 2025 13:55:51 -0700 Subject: [PATCH 18/20] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_tensor.py | 14 ++++++++------ torchao/prototype/mx_formats/custom_cast.py | 3 +++ torchao/prototype/mx_formats/mx_linear.py | 6 +++--- torchao/prototype/mx_formats/mx_tensor.py | 6 +++++- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_tensor.py b/test/prototype/mx_formats/test_mx_tensor.py index 1055272ed2..46b196fb49 100644 --- a/test/prototype/mx_formats/test_mx_tensor.py +++ b/test/prototype/mx_formats/test_mx_tensor.py @@ -18,7 +18,6 @@ ) from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6 from torchao.prototype.mx_formats.mx_tensor import ( - E8M0_EXPONENT_NAN_VAL, MXTensor, ScaleCalculationMode, to_dtype, @@ -117,8 +116,8 @@ def test_exponent_nan_in(elem_dtype): ) block_size = 4 tensor_mx = MXTensor.to_mx(tensor_hp, elem_dtype, block_size) - assert torch.all(tensor_mx._scale_e8m0[0] == E8M0_EXPONENT_NAN_VAL) - assert not torch.any(tensor_mx._scale_e8m0[1:] == E8M0_EXPONENT_NAN_VAL) + assert torch.all(torch.isnan(tensor_mx._scale_e8m0[0])) + assert not torch.any(torch.isnan(tensor_mx._scale_e8m0[1:])) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @@ -128,8 +127,11 @@ def test_exponent_nan_out(elem_dtype, pack_fp6): """ If block exponent value is NaN, the MX tensor block value is NaN """ - scale_e8m0_bits = torch.tensor( - [E8M0_EXPONENT_NAN_VAL, 23], dtype=torch.uint8, device="cuda" + if pack_fp6 and elem_dtype not in (DTYPE_FP6_E2M3, DTYPE_FP6_E3M2): + pytest.skip("invalid configuration") + + scale_e8m0 = torch.tensor( + [float("nan"), 1.0], dtype=torch.float8_e8m0fnu, device="cuda" ) block_size = 4 @@ -155,7 +157,7 @@ def test_exponent_nan_out(elem_dtype, pack_fp6): block_size = 4 use_fp4_custom_triton_dequant_kernel = False tensor_mx = MXTensor( - scale_e8m0_bits, + scale_e8m0, data_bits, elem_dtype, block_size, diff --git 
a/torchao/prototype/mx_formats/custom_cast.py b/torchao/prototype/mx_formats/custom_cast.py index a78d980c37..c3c987baf9 100644 --- a/torchao/prototype/mx_formats/custom_cast.py +++ b/torchao/prototype/mx_formats/custom_cast.py @@ -745,6 +745,7 @@ def triton_f4_to_scaled_bf16( size is currently assumed to be 32. Output: a tensor of bfloat16 values, multiplied by the encoded scale """ + s_e8m0 = s_e8m0.view(torch.uint8) assert TORCH_VERSION_AT_LEAST_2_4, "unsupported" new_shape = (*x.shape[:-1], x.shape[-1] * 2) output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16) @@ -861,6 +862,7 @@ def triton_f6_e2m3_to_scaled_bf16( size is currently assumed to be 32. Output: a tensor of bfloat16 values, multiplied by the encoded scale """ + s_e8m0 = s_e8m0.view(torch.uint8) packed_mx_block_size = 3 * mx_block_size // 4 @@ -902,6 +904,7 @@ def triton_f6_e3m2_to_scaled_bf16( size is currently assumed to be 32. Output: a tensor of bfloat16 values, multiplied by the encoded scale """ + s_e8m0 = s_e8m0.view(torch.uint8) packed_mx_block_size = 3 * mx_block_size // 4 diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 378cc4909a..888be1e436 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -92,7 +92,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): weight_hp, block_size ) weight_mx_dim1 = MXTensor( - weight_mx_dim1_scale.view(torch.uint8).reshape(-1), + weight_mx_dim1_scale.reshape(-1), weight_mx_dim1_data.t(), w_elem_dtype, block_size, @@ -121,7 +121,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): grad_output_hp_r, block_size ) grad_output_mx_dim1 = MXTensor( - grad_output_mx_dim1_scale.view(torch.uint8).reshape(-1), + grad_output_mx_dim1_scale.reshape(-1), grad_output_mx_dim1_data.t(), grad_elem_dtype, block_size, @@ -143,7 +143,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): input_hp_r, block_size ) input_t_mx_dim0_tmp = MXTensor( - input_t_mx_dim0_tmp_scale.view(torch.uint8).reshape(-1), + input_t_mx_dim0_tmp_scale.reshape(-1), input_t_mx_dim0_tmp_data.t(), in_elem_dtype, block_size, diff --git a/torchao/prototype/mx_formats/mx_tensor.py b/torchao/prototype/mx_formats/mx_tensor.py index 5d908d8d50..9949ee1a21 100644 --- a/torchao/prototype/mx_formats/mx_tensor.py +++ b/torchao/prototype/mx_formats/mx_tensor.py @@ -271,10 +271,12 @@ def to_mx( else: raise AssertionError("unsupported") + scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) return scale_e8m0_biased, data_lp def get_fp_scale(scale_e8m0): + scale_e8m0 = scale_e8m0.view(torch.uint8) s_offset = scale_e8m0.to(torch.int16) - E8M0_EXPONENT_BIAS # TODO(later): it would be nice if there was a way to do the 2^x operation # in PyTorch without creating a tensor of twos @@ -507,7 +509,9 @@ def __new__( dtype=orig_dtype, device=data_bits.device, ) - assert scale_e8m0_bits.dtype == torch.uint8, "unsupported" + assert ( + scale_e8m0_bits.dtype == torch.float8_e8m0fnu + ), f"scale_e8m0_bits.dtype must be `torch.float8_e8m0fnu`, got {scale_e8m0_bits.dtype}" assert len(scale_e8m0_bits.shape) == 1, "unsupported" assert data_bits.dtype in ( torch.float8_e4m3fn, From 65bfff079cfd3ce6b714683125b5b3c79dae537f Mon Sep 17 00:00:00 2001 From: vasiliy Date: Wed, 26 Mar 2025 15:05:44 -0700 Subject: [PATCH 19/20] Update [ghstack-poisoned] --- benchmarks/float8/float8_roofline.py | 6 ++--- benchmarks/float8/profile_lowp_training.py | 4 +-- test/prototype/mx_formats/test_mx_linear.py | 20 ++++++++------- 
torchao/prototype/mx_formats/README.md | 20 +++++++-------- torchao/prototype/mx_formats/__init__.py | 15 ++++++++++++ torchao/prototype/mx_formats/config.py | 3 ++- torchao/prototype/mx_formats/mx_linear.py | 27 ++++++--------------- 7 files changed, 50 insertions(+), 45 deletions(-) diff --git a/benchmarks/float8/float8_roofline.py b/benchmarks/float8/float8_roofline.py index 16de03f957..137563940c 100644 --- a/benchmarks/float8/float8_roofline.py +++ b/benchmarks/float8/float8_roofline.py @@ -61,8 +61,8 @@ Float8LinearConfig, convert_to_float8_training, ) -from torchao.prototype.mx_formats.config import MXLinearConfig -from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear +from torchao.prototype.mx_formats import MXLinearConfig +from torchao.quantization import quantize_ from torchao.testing.float8.roofline_utils import ( get_float8_mem_sympy, get_gemm_time_sympy, @@ -391,7 +391,7 @@ def run( assert mx_recipe_name is not None config = MXLinearConfig.from_recipe_name(mx_recipe_name) m_fp8_dyn = copy.deepcopy(m_orig) - swap_linear_with_mx_linear(m_fp8_dyn, config=config) + quantize_(m_fp8_dyn, config=config) m_fp8_dyn = torch.compile(m_fp8_dyn) b_fp8_e2e_time_s = get_gpu_kernel_time(m_fp8_dyn, x, grad_output) diff --git a/benchmarks/float8/profile_lowp_training.py b/benchmarks/float8/profile_lowp_training.py index d4a3079360..a1ad38df09 100644 --- a/benchmarks/float8/profile_lowp_training.py +++ b/benchmarks/float8/profile_lowp_training.py @@ -46,9 +46,9 @@ convert_to_float8_training, ) from torchao.prototype.mx_formats.config import MXLinearConfig -from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.prototype.mx_formats.utils import to_blocked +from torchao.quantization import quantize_ # don't truncate long kernel names pd.options.display.max_colwidth = 100 @@ -379,7 +379,7 @@ def main( if mx_recipe_name is None: convert_to_float8_training(m_lowp, config=config) else: - swap_linear_with_mx_linear(m_lowp, config=config) + quantize_(m_lowp, config=config) # this function is only used for cast_only to_mx_func = MXTensor.to_mx diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 1dce269c57..a18adf5d64 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -24,8 +24,8 @@ MXInferenceLinear, MXLinear, swap_linear_with_mx_inference_linear, - swap_linear_with_mx_linear, ) +from torchao.quantization import quantize_ from torchao.quantization.utils import compute_error from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_8, @@ -98,7 +98,7 @@ def test_linear_eager_vs_hp( elem_dtype_grad_output_override=elem_dtype[2], use_fp8_dim1_cast_triton_kernel=use_fp8_dim1_cast_triton_kernel, ) - swap_linear_with_mx_linear(m_mx, config=config) + quantize_(m_mx, config) x_ref = torch.randn( *input_shape, device="cuda", dtype=torch.bfloat16 @@ -159,8 +159,8 @@ def test_linear_eager_emulated_vs_real_gemm(recipe_name, mkn): config_emulated = MXLinearConfig(block_size=32, elem_dtype=elem_dtype) config_real = MXLinearConfig.from_recipe_name(recipe_name) - swap_linear_with_mx_linear(m_emulated, config=config_emulated) - swap_linear_with_mx_linear(m_real, config=config_real) + quantize_(m_emulated, config=config_emulated) + quantize_(m_real, config=config_real) y_emulated = m_emulated(x) y_emulated.backward(g) @@ -189,7 +189,7 @@ def test_activation_checkpointing(): nn.Linear(8, 8, 
bias=True, device="cuda"), ) config = MXLinearConfig(block_size=4, elem_dtype=elem_dtype) - swap_linear_with_mx_linear(m, config=config) + quantize_(m, config=config) x = torch.randn(*input_shape, device="cuda").requires_grad_() g = torch.randn(*grad_shape, device="cuda") @@ -252,7 +252,7 @@ def test_linear_compile(hp_dtype, recipe_name, bias, use_fp8_dim1_cast_triton_ke config = MXLinearConfig.from_recipe_name(recipe_name) config.use_fp8_dim1_cast_triton_kernel = use_fp8_dim1_cast_triton_kernel - swap_linear_with_mx_linear(m_mx, config=config) + quantize_(m_mx, config=config) m_mx_c = copy.deepcopy(m_mx) m_mx_c = torch.compile(m_mx_c, fullgraph=True, backend="inductor") @@ -339,10 +339,12 @@ def test_filter_fn(): nn.Linear(32, 32), ) m2 = copy.deepcopy(m1) - filter_fn = lambda mod, fqn: fqn != "1" # noqa: E731 + filter_fn = lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "1" # noqa: E731 config = MXLinearConfig(block_size=32) - swap_linear_with_mx_linear(m1, config=config, filter_fn=filter_fn) + print("before", m1) + quantize_(m1, config=config, filter_fn=filter_fn) + print("after", m1) assert type(m1[0]) == MXLinear assert type(m1[1]) == torch.nn.Linear @@ -354,7 +356,7 @@ def test_filter_fn(): def test_training_print_str(): m = nn.Sequential(nn.Linear(32, 32)) config = MXLinearConfig() - swap_linear_with_mx_linear(m, config=config) + quantize_(m, config=config) s = str(m) assert "bl_sz=32" in s assert "kernel=emulated" in s diff --git a/torchao/prototype/mx_formats/README.md b/torchao/prototype/mx_formats/README.md index 4689bae7ae..4eaa937a78 100644 --- a/torchao/prototype/mx_formats/README.md +++ b/torchao/prototype/mx_formats/README.md @@ -40,17 +40,16 @@ x_hp = x_mx.to_dtype(torch.float) This is a module to do MX training, the MX matmul is currently emulated. ```python -from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear -from torchao.prototype.mx_formats.config import MXLinearConfig, MXGemmKernelChoice +import torch +from torchao.quantization import quantize_ +from torchao.prototype.mx_formats import MXLinearConfig, MXGemmKernelChoice -# early prototype: on MX-enabled hardware, you can use the real MX gemm backed by -# torchao's CUTLASS kernels. In the future, we will also add cuBLAS kernel support. -gemm_kernel_choice = MXGemmKernelChoice.EMULATED - -# on NVIDIA Blackwell GPUs, you can also use cuBLAS or CUTLASS mxfp8 kernels -# note: torch.compile support for both of these is WIP +# on NVIDIA Blackwell GPUs, you can use cuBLAS or CUTLASS mxfp8 kernels +gemm_kernel_choice = MXGemmKernelChoice.CUBLAS # gemm_kernel_choice = MXGemmKernelChoice.CUTLASS -# gemm_kernel_choice = MXGemmKernelChoice.CUBLAS + +# on older NVIDIA gpus, you can run training with emulated MX gemm +# gemm_kernel_choice = MXGemmKernelChoice.EMULATED m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda() config = MXLinearConfig( @@ -58,7 +57,7 @@ config = MXLinearConfig( block_size=32, gemm_kernel_choice=gemm_kernel_choice, ) -swap_linear_with_mx_linear(m, config=config) +quantize_(m, config) # training loop (not shown) ``` @@ -68,6 +67,7 @@ swap_linear_with_mx_linear(m, config=config) This is a module to do MX inference, weights are in MX and matmul is in high precision. 
```python +import torch from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear from torchao.prototype.mx_formats.config import MXLinearConfig diff --git a/torchao/prototype/mx_formats/__init__.py b/torchao/prototype/mx_formats/__init__.py index e69de29bb2..e0eff6c85d 100644 --- a/torchao/prototype/mx_formats/__init__.py +++ b/torchao/prototype/mx_formats/__init__.py @@ -0,0 +1,15 @@ +from torchao.prototype.mx_formats.config import ( + MXGemmKernelChoice, + MXLinearConfig, + MXLinearRecipeName, +) + +# import mx_linear here to register the quantize_ transform logic +# ruff: noqa: I001 +import torchao.prototype.mx_formats.mx_linear # noqa: F401 + +__all__ = [ + "MXLinearConfig", + "MXGemmKernelChoice", + "MXLinearRecipeName", +] diff --git a/torchao/prototype/mx_formats/config.py b/torchao/prototype/mx_formats/config.py index 438a6a9293..d7767886a9 100644 --- a/torchao/prototype/mx_formats/config.py +++ b/torchao/prototype/mx_formats/config.py @@ -10,6 +10,7 @@ import torch +from torchao.core.config import AOBaseConfig from torchao.prototype.mx_formats.constants import ( DTYPE_FP4, DTYPE_TO_SHORT_STR, @@ -41,7 +42,7 @@ class MXLinearRecipeName(Enum): @dataclass -class MXLinearConfig: +class MXLinearConfig(AOBaseConfig): # block size for scaling, default is 32 to match # https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf, # section 5.2 diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index 888be1e436..af8adbbdec 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -16,6 +16,9 @@ from torchao.prototype.mx_formats.config import MXGemmKernelChoice, MXLinearConfig from torchao.prototype.mx_formats.custom_cast import triton_to_mxfp8_dim1 from torchao.prototype.mx_formats.mx_tensor import MXTensor +from torchao.quantization.transform_module import ( + register_quantize_module_handler, +) @torch._dynamo.allow_in_graph @@ -183,7 +186,7 @@ def from_float( mod, config: Optional[MXLinearConfig] = MXLinearConfig(), ): - # TODO(before land): remove this + assert isinstance(mod, torch.nn.Linear), f"unsupported type(mod) {type(mod)}" assert isinstance(config, MXLinearConfig) mod.__class__ = MXLinear mod.config = config @@ -290,25 +293,9 @@ def _is_linear(mod, fqn): return isinstance(mod, torch.nn.Linear) -def swap_linear_with_mx_linear( - model, - *, - config: Optional[MXLinearConfig] = None, - filter_fn=None, -): - if filter_fn is None: - combined_filter_fn = _is_linear - else: - - def __fn(mod, fqn): - return _is_linear(mod, fqn) and filter_fn(mod, fqn) - - combined_filter_fn = __fn - replace_with_custom_fn_if_matches_filter( - model, - lambda mod: MXLinear.from_float(mod, config=config), - combined_filter_fn, - ) +@register_quantize_module_handler(MXLinearConfig) +def _mx_linear_transform(module: torch.nn.Module, config: MXLinearConfig): + return MXLinear.from_float(module, config=config) def swap_linear_with_mx_inference_linear( From 45abedfde1123e159939d7b8322970569feb9e73 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Fri, 28 Mar 2025 07:14:48 -0700 Subject: [PATCH 20/20] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index a18adf5d64..9854b356f3 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -342,9 +342,7 @@ 
def test_filter_fn(): filter_fn = lambda mod, fqn: isinstance(mod, torch.nn.Linear) and fqn != "1" # noqa: E731 config = MXLinearConfig(block_size=32) - print("before", m1) quantize_(m1, config=config, filter_fn=filter_fn) - print("after", m1) assert type(m1[0]) == MXLinear assert type(m1[1]) == torch.nn.Linear
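Editor's note: with the last two patches in this stack, MX training modules are configured through torchao's generic `quantize_` entry point instead of a bespoke swap helper. Below is a minimal end-to-end sketch of the resulting flow; the model, shapes, and filter predicate are illustrative (the predicate mirrors the `test_filter_fn` case above).

```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats import MXGemmKernelChoice, MXLinearConfig
from torchao.quantization import quantize_

m = nn.Sequential(
    nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16),
    nn.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16),
)

config = MXLinearConfig(
    elem_dtype=torch.float8_e4m3fn,
    block_size=32,
    gemm_kernel_choice=MXGemmKernelChoice.EMULATED,
)

# convert only linears whose fully qualified name is not "1"
filter_fn = lambda mod, fqn: isinstance(mod, nn.Linear) and fqn != "1"
quantize_(m, config=config, filter_fn=filter_fn)

x = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
m(x).sum().backward()
```

Registering `MXLinearConfig` with `register_quantize_module_handler`, as done in patch 19, is what lets `quantize_` route this config to `MXLinear.from_float`.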