From a6c9f1c1de9f790a8f5be147994814197e5fad54 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 10:26:01 -0800 Subject: [PATCH 01/13] Update [ghstack-poisoned] --- benchmarks/float8/bench_matmul.py | 127 ++++++++++++++++++-- test/prototype/mx_formats/test_mx_linear.py | 85 +++++++++++++ 2 files changed, 200 insertions(+), 12 deletions(-) diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py index 3d48853754..52cfcfc481 100644 --- a/benchmarks/float8/bench_matmul.py +++ b/benchmarks/float8/bench_matmul.py @@ -3,6 +3,7 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +from enum import IntEnum import itertools from typing import Optional @@ -26,7 +27,16 @@ h100_peak_flops_fp16_tc = 989e12 h100_peak_tops_float8_tc = 1979e12 -dtype_to_peak_tops = { +# HGX B20 specs: https://www.nvidia.com/en-us/data-center/hgx/ +# note: divided numbers from ^ by 2 to undo the effects of sparsity +# TODO(this PR): I'm achieving 5% of peak TFLOPS with bf16 and float8, +# something seems funky +b200_peak_flops_float32 = 600e12 +b200_peak_flops_fp16_tc = 18e15 +b200_peak_tops_float8_tc = 36e15 +b200_peak_tops_float4_tc = 72e15 + +dtype_to_peak_tops_h100 = { torch.float32: h100_peak_flops_float32, torch.float16: h100_peak_flops_fp16_tc, torch.bfloat16: h100_peak_flops_fp16_tc, @@ -34,6 +44,27 @@ torch.float8_e5m2: h100_peak_tops_float8_tc, } +dtype_to_peak_tops_b200 = { + torch.float32: b200_peak_flops_float32, + torch.float16: b200_peak_flops_fp16_tc, + torch.bfloat16: b200_peak_flops_fp16_tc, + torch.float8_e4m3fn: b200_peak_tops_float8_tc, + torch.float8_e5m2: b200_peak_tops_float8_tc, + # TODO float4 +} + +# TODO(this PR): switch automatically by detected hardware type +# TODO(this PR): fp4 is currently using fp8's peak tops below, fix it +dtype_to_peak_tops = dtype_to_peak_tops_b200 + + +# not for land, matching https://www.internalfb.com/phabricator/paste/view/P1717686991 +class DataType(IntEnum): + DEFAULT = 0 + E8M0 = 1 + FP4 = 2 + UFP8 = 3 + def benchmark_fn_in_sec(f, *args, **kwargs): # Manual warmup @@ -75,6 +106,7 @@ def run( N: Optional[int] = None, use_gpu_kernel_time: bool = False, scaling_granularity: str = "tensorwise", + blockwise_dtype: Optional[str] = None, ): device = "cuda" @@ -85,15 +117,17 @@ def run( "K", "N", "ref_time_s", - "fp8_time_s", - "fp8_speedup", + "lowp_time_s", + "lowp_speedup", ) results = [] dtype = torch.bfloat16 name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N) fast_accum_vals = [True, False] - scaling_granularity = ScalingGranularity(scaling_granularity) + # Note: blockwise not in enum because blockwise is in prototype + if scaling_granularity != "blockwise": + scaling_granularity = ScalingGranularity(scaling_granularity) for idx, (fast_accum, (name, (M, K, N))) in enumerate( itertools.product(fast_accum_vals, name_to_shapes) @@ -119,28 +153,97 @@ def run( # raw float8 matmul (upper bound for what we can achive in eager mode) # TODO(future): add e5m2 d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype - A = torch.zeros(M, K, device=device, dtype=d1) - B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t() + A = torch.randn(M, K, device=device).to(d1) + B = torch.randn(K, N, device=device).to(d2).t().contiguous().t() if scaling_granularity == ScalingGranularity.TENSORWISE: scale_a = torch.tensor([1.0], device=device) scale_b = torch.tensor([1.0], device=device) - else: - assert scaling_granularity == 
ScalingGranularity.AXISWISE, "unsupported" + elif scaling_granularity == ScalingGranularity.AXISWISE: scale_a = torch.ones(M, 1, device=device) scale_b = torch.ones(1, N, device=device) + elif scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3": + # TODO(this PR): also block size 16 + BLOCK_SIZE = 32 + A = torch.randint(128, (M, K), device=device, dtype=torch.uint8).view( + torch.float8_e4m3fn + ) + B = ( + torch.randint(128, (N, K), device=device, dtype=torch.uint8) + .view(torch.float8_e4m3fn) + .t() + ) + scale_a = torch.randint( + 128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda" + ) + scale_b = torch.randint( + 128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda" + ).t() + elif scaling_granularity == "blockwise" and blockwise_dtype == "float4": + # TODO(this PR): also block size 16 + BLOCK_SIZE = 16 + A = torch.randint(128, (M, K // 2), device=device, dtype=torch.uint8).view( + torch.float8_e4m3fn + ) + B = ( + torch.randint(128, (N, K // 2), device=device, dtype=torch.uint8) + .view(torch.float8_e4m3fn) + .t() + ) + scale_a = torch.randint( + 128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda" + ) + scale_b = torch.randint( + 128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda" + ).t() + else: + raise AssertionError(f"unsupported granularity {scaling_granularity}") def do_matmul(A, B): nonlocal scale_a nonlocal scale_b - return torch._scaled_mm( - A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum - ) + + if scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3": + return torch._scaled_mm( + A, + B, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=d3, + use_fast_accum=fast_accum, + a_dtype=None, # inferred from A + b_dtype=None, # inferred from B + scale_dtype=DataType.E8M0, + ) + elif scaling_granularity == "blockwise" and blockwise_dtype == "float4": + return torch._scaled_mm( + A, + B, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=d3, + use_fast_accum=fast_accum, + a_dtype=DataType.FP4, + b_dtype=DataType.FP4, + scale_dtype=DataType.E8M0, + ) + + else: + return torch._scaled_mm( + A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum + ) + + # test + # res = do_matmul(A, B) fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks( tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B ) print( - f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}" + f"lowp time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}" ) del A, B, scale_a, scale_b diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 35afeb7959..25be1d9541 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. 
import copy +from enum import IntEnum import pytest import torch @@ -30,6 +31,14 @@ pytest.skip("Unsupported PyTorch version", allow_module_level=True) +# not for land, https://www.internalfb.com/phabricator/paste/view/P1717686991 +class DataType(IntEnum): + DEFAULT = 0 + E8M0 = 1 + FP4 = 2 + UFP8 = 3 + + # source: https://stackoverflow.com/a/22638709 @pytest.fixture(autouse=True) def run_around_tests(): @@ -234,3 +243,79 @@ def test_filter_fn(): swap_linear_with_mx_inference_linear(m2, torch.float8_e4m3fn, 32, filter_fn) # noqa: E501 assert type(m2[0]) == MXInferenceLinear assert type(m2[1]) == torch.nn.Linear + + +def test_scaled_mm_mxfp8(): + # hello world + # next: basic numerics + + M, K, N = 8192, 4096, 8192 + BLOCK_SIZE = 32 + a = torch.randint(128, (M, K), device="cuda", dtype=torch.uint8).view( + torch.float8_e4m3fn + ) + b = ( + torch.randint(128, (N, K), device="cuda", dtype=torch.uint8) + .view(torch.float8_e4m3fn) + .t() + ) + a_scales = torch.randint( + 128, ((M * K) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8 + ).view(M, K // BLOCK_SIZE) + b_scales = ( + torch.randint(128, ((K * N) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8) + .view(N, K // BLOCK_SIZE) + .t() + ) + out = torch._scaled_mm( + a, + b, + a_scales, + b_scales, + None, + None, + torch.bfloat16, + False, + None, + None, + DataType.E8M0, + ) + print(out) + + +def test_scaled_mm_nvfp4(): + # hello world + # next: basic numerics + + M, K, N = 8192, 4096, 8192 + BLOCK_SIZE = 16 + a = torch.randint(128, ((M * K) // 2,), device="cuda", dtype=torch.uint8).view( + M, K // 2 + ) + b = ( + torch.randint(128, ((K * N) // 2,), device="cuda", dtype=torch.uint8) + .view(N, K // 2) + .t() + ) + a_scales = torch.randint( + 128, ((M * K) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8 + ).view(M, K // BLOCK_SIZE) + b_scales = ( + torch.randint(128, ((K * N) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8) + .view(N, K // BLOCK_SIZE) + .t() + ) + out = torch._scaled_mm( + a, + b, + a_scales, + b_scales, + None, + None, + torch.bfloat16, + False, + DataType.FP4, + DataType.FP4, + DataType.UFP8, + ) + print(out) From 596651654e1e6ddaab10480f543b1493b51ffa6a Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 10:48:57 -0800 Subject: [PATCH 02/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 25be1d9541..68da09ecc6 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -245,6 +245,10 @@ def test_filter_fn(): assert type(m2[1]) == torch.nn.Linear +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher" +) def test_scaled_mm_mxfp8(): # hello world # next: basic numerics @@ -283,6 +287,10 @@ def test_scaled_mm_mxfp8(): print(out) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher" +) def test_scaled_mm_nvfp4(): # hello world # next: basic numerics From 759797b2ab5952f8d0354bb89ccf3c1ef1513616 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 10:50:28 -0800 Subject: [PATCH 03/13] Update [ghstack-poisoned] --- benchmarks/float8/bench_matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/benchmarks/float8/bench_matmul.py b/benchmarks/float8/bench_matmul.py index 52cfcfc481..bcf30ef5fc 100644 --- a/benchmarks/float8/bench_matmul.py +++ b/benchmarks/float8/bench_matmul.py @@ -3,8 +3,8 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -from enum import IntEnum import itertools +from enum import IntEnum from typing import Optional import fire From ee80fa7d5f070d597d0ceec1e36a24bdeb2106b3 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 10:56:49 -0800 Subject: [PATCH 04/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 36 ++++++++------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 68da09ecc6..4f4217115d 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -247,7 +247,8 @@ def test_filter_fn(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( - not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher" + not is_sm_at_least_100(), + reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher", ) def test_scaled_mm_mxfp8(): # hello world @@ -264,13 +265,11 @@ def test_scaled_mm_mxfp8(): .t() ) a_scales = torch.randint( - 128, ((M * K) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8 - ).view(M, K // BLOCK_SIZE) - b_scales = ( - torch.randint(128, ((K * N) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8) - .view(N, K // BLOCK_SIZE) - .t() + 128, (M, K // BLOCK_SIZE), device="cuda", dtype=torch.uint8 ) + b_scales = torch.randint( + 128, (K // BLOCK_SIZE, N), device="cuda", dtype=torch.uint8 + ).t() out = torch._scaled_mm( a, b, @@ -289,7 +288,8 @@ def test_scaled_mm_mxfp8(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( - not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher" + not is_sm_at_least_100(), + reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher", ) def test_scaled_mm_nvfp4(): # hello world @@ -297,22 +297,14 @@ def test_scaled_mm_nvfp4(): M, K, N = 8192, 4096, 8192 BLOCK_SIZE = 16 - a = torch.randint(128, ((M * K) // 2,), device="cuda", dtype=torch.uint8).view( - M, K // 2 - ) - b = ( - torch.randint(128, ((K * N) // 2,), device="cuda", dtype=torch.uint8) - .view(N, K // 2) - .t() - ) + a = torch.randint(128, (M, K // 2), device="cuda", dtype=torch.uint8) + b = torch.randint(128, (N, K // 2), device="cuda", dtype=torch.uint8).t() a_scales = torch.randint( - 128, ((M * K) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8 - ).view(M, K // BLOCK_SIZE) - b_scales = ( - torch.randint(128, ((K * N) // BLOCK_SIZE,), device="cuda", dtype=torch.uint8) - .view(N, K // BLOCK_SIZE) - .t() + 128, (M, K // BLOCK_SIZE), device="cuda", dtype=torch.uint8 ) + b_scales = torch.randint( + 128, (N, K // BLOCK_SIZE), device="cuda", dtype=torch.uint8 + ).t() out = torch._scaled_mm( a, b, From c04dbef25b7f9b9c2badc2b736befe820746fb78 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 14:28:55 -0800 Subject: [PATCH 05/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 31 +++++++++++---------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py 
b/test/prototype/mx_formats/test_mx_linear.py index 4f4217115d..52fe996f1b 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -251,24 +251,23 @@ def test_filter_fn(): reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher", ) def test_scaled_mm_mxfp8(): - # hello world - # next: basic numerics + # basic numerics with all scales 1.0 + # next: other scale values - M, K, N = 8192, 4096, 8192 + # M, K, N = 8192, 4096, 8192 + M, K, N = 128, 128, 128 BLOCK_SIZE = 32 - a = torch.randint(128, (M, K), device="cuda", dtype=torch.uint8).view( - torch.float8_e4m3fn - ) - b = ( - torch.randint(128, (N, K), device="cuda", dtype=torch.uint8) - .view(torch.float8_e4m3fn) - .t() - ) - a_scales = torch.randint( - 128, (M, K // BLOCK_SIZE), device="cuda", dtype=torch.uint8 + a = torch.eye(M, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.eye(M, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn).t() + + # 127 is 1.0 in e8m0 + scale_val = 127 + + a_scales = torch.full( + (M, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8 ) - b_scales = torch.randint( - 128, (K // BLOCK_SIZE, N), device="cuda", dtype=torch.uint8 + b_scales = torch.full( + (K // BLOCK_SIZE, N), scale_val, device="cuda", dtype=torch.uint8 ).t() out = torch._scaled_mm( a, @@ -283,6 +282,8 @@ def test_scaled_mm_mxfp8(): None, DataType.E8M0, ) + + # [[1, 0, ...], ..., [0, ..., 1]] - correct print(out) From 90c6e434bde2b390e095efaf276e34df486686e8 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 16:27:25 -0800 Subject: [PATCH 06/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 117 +++++++++++++++++++- 1 file changed, 115 insertions(+), 2 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 52fe996f1b..58d0be53c5 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -257,8 +257,8 @@ def test_scaled_mm_mxfp8(): # M, K, N = 8192, 4096, 8192 M, K, N = 128, 128, 128 BLOCK_SIZE = 32 - a = torch.eye(M, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn) - b = torch.eye(M, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn).t() + a = torch.ones(M, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn) + b = torch.ones(N, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn).t() # 127 is 1.0 in e8m0 scale_val = 127 @@ -284,7 +284,120 @@ def test_scaled_mm_mxfp8(): ) # [[1, 0, ...], ..., [0, ..., 1]] - correct + torch.set_printoptions(profile="full", linewidth=320) print(out) + print(torch.max(out)) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher", +) +def test_scaled_mm_mx_reconstruct_scale_layout(): + # brute force the expected layout format + # basic numerics with all scales 1.0 + # next: other scale values + + # M, K, N = 8192, 4096, 8192 + M, K, N = 128, 128, 128 + BLOCK_SIZE = 32 + a = torch.ones(M, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn) + + # 127 is 1.0 in e8m0 + scale_val = 127 + + print() + + # Probe torch._scaled_mm to deduce the actual layout used for the scale + # arguments. Specifically, here is what the code below would do if we had + # A and B as 4x4 matrices with MX block size 2. 
All matrices are shown in float32 + # format, not their actual storage format, to demonstrate the algorithm. + # + # A matrix - set to all-ones + # + # A = 1111 + # 1111 + # 1111 + # 1111 + # + # B matrix variants - all-zeros, except a single one for each mx block in the first column + # + # B_0 = 1000 B_2 = 0000 + # 0000 0000 + # 0000 1000 + # 0000 0000 + # + # A scale - starts as a matrix of all-ones + # + # A_s = 11 + # 11 + # 11 + # 11 + # + # for each row in rows of A: + # for each ol in cols of A: + # initialize A to all-ones + # set A[row][col] = 2.0 + # for each B in [Bs]: + # C = torch._scaled_mm(A, B, A_s, B_s, ...) + # if max(C) > 1.0: + # the scale incremented in A_s was corresponding to the current block + # TODO finish this, just need to reconstruct from printed data + + for scale_row in range(M): + for scale_col in range(K // BLOCK_SIZE): + + # We test every BLOCK_SIZE to deduce which of the blocks is + # responsible for the scale value. Note that this isn't the most + # efficient way to test, but I'm optimizing for dev time here. + for block_idx in range(K // BLOCK_SIZE): + + b = torch.zeros(N, K, device="cuda", dtype=torch.float32) + # set a single one inside the block + b[0][block_idx * BLOCK_SIZE] = 1 + b = b.to(torch.float8_e4m3fn).t() + + a_scales = torch.full( + (M, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8 + ) + b_scales = torch.full( + # (K // BLOCK_SIZE, N), scale_val, device="cuda", dtype=torch.uint8 + (N, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8 + ) + + # TODO: it looks like blockwise scales are switched in cuBLAS? + # a_scales[scale_row][scale_col] = scale_val + 1 + # incrementing scale of b looks like it's actually affecting scaling of a + b_scales[scale_row][scale_col] = scale_val + 1 + + out = torch._scaled_mm( + a, + b, + a_scales, + b_scales, + None, + None, + torch.bfloat16, + False, + None, + None, + DataType.E8M0, + ) + + # torch.set_printoptions(profile="full", linewidth=320) + # print(out) + # print(torch.max(out, keepdim=True)) + + max_val = torch.max(out).item() + if max_val > 1: + max_flat_index = torch.argmax(out).item() + max_row = max_flat_index // M + max_col = max_flat_index % M + print('scale_coords', scale_row, scale_col, 'block_idx', block_idx, 'max_coords', max_row, max_col, 'max_val', max_val) + + # break + # break @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") From 0a8ce19c9b2ce909e2dcc89669b6a23c3fb264cf Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 20:24:47 -0800 Subject: [PATCH 07/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 32 ++++++++++++--------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 58d0be53c5..a6c92d16e2 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -269,6 +269,7 @@ def test_scaled_mm_mxfp8(): b_scales = torch.full( (K // BLOCK_SIZE, N), scale_val, device="cuda", dtype=torch.uint8 ).t() + # b_scales[0][0] = 128 out = torch._scaled_mm( a, b, @@ -284,7 +285,7 @@ def test_scaled_mm_mxfp8(): ) # [[1, 0, ...], ..., [0, ..., 1]] - correct - torch.set_printoptions(profile="full", linewidth=320) + torch.set_printoptions(profile="full", linewidth=280) print(out) print(torch.max(out)) @@ -348,6 +349,19 @@ def test_scaled_mm_mx_reconstruct_scale_layout(): for scale_row in range(M): for scale_col in range(K // BLOCK_SIZE): + a_scales = 
torch.full( + (M, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8 + ) + b_scales = torch.full( + # (K // BLOCK_SIZE, N), scale_val, device="cuda", dtype=torch.uint8 + (N, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8 + ) + + # TODO: it looks like blockwise scales are switched in cuBLAS? + # a_scales[scale_row][scale_col] = scale_val + 1 + # incrementing scale of b looks like it's actually affecting scaling of a + b_scales[scale_row][scale_col] = scale_val + 1 + # We test every BLOCK_SIZE to deduce which of the blocks is # responsible for the scale value. Note that this isn't the most # efficient way to test, but I'm optimizing for dev time here. @@ -358,19 +372,6 @@ def test_scaled_mm_mx_reconstruct_scale_layout(): b[0][block_idx * BLOCK_SIZE] = 1 b = b.to(torch.float8_e4m3fn).t() - a_scales = torch.full( - (M, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8 - ) - b_scales = torch.full( - # (K // BLOCK_SIZE, N), scale_val, device="cuda", dtype=torch.uint8 - (N, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8 - ) - - # TODO: it looks like blockwise scales are switched in cuBLAS? - # a_scales[scale_row][scale_col] = scale_val + 1 - # incrementing scale of b looks like it's actually affecting scaling of a - b_scales[scale_row][scale_col] = scale_val + 1 - out = torch._scaled_mm( a, b, @@ -385,6 +386,7 @@ def test_scaled_mm_mx_reconstruct_scale_layout(): DataType.E8M0, ) + # print(scale_row, scale_col, block_idx) # torch.set_printoptions(profile="full", linewidth=320) # print(out) # print(torch.max(out, keepdim=True)) @@ -394,6 +396,8 @@ def test_scaled_mm_mx_reconstruct_scale_layout(): max_flat_index = torch.argmax(out).item() max_row = max_flat_index // M max_col = max_flat_index % M + assert max_col == 0 + assert max_val == 2.0 print('scale_coords', scale_row, scale_col, 'block_idx', block_idx, 'max_coords', max_row, max_col, 'max_val', max_val) # break From 686be3dac1d2ea47baf52d68c5c8e6bb6c8554a4 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 21:08:10 -0800 Subject: [PATCH 08/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 116 +++++++++++++++++++- 1 file changed, 112 insertions(+), 4 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index a6c92d16e2..87601b3a15 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -295,7 +295,7 @@ def test_scaled_mm_mxfp8(): not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher", ) -def test_scaled_mm_mx_reconstruct_scale_layout(): +def test_scaled_mm_mx_reconstruct_scale_a_layout(): # brute force the expected layout format # basic numerics with all scales 1.0 # next: other scale values @@ -324,7 +324,7 @@ def test_scaled_mm_mx_reconstruct_scale_layout(): # # B matrix variants - all-zeros, except a single one for each mx block in the first column # - # B_0 = 1000 B_2 = 0000 + # B_0 = 1000 B_1 = 0000 # 0000 0000 # 0000 1000 # 0000 0000 @@ -344,7 +344,6 @@ def test_scaled_mm_mx_reconstruct_scale_layout(): # C = torch._scaled_mm(A, B, A_s, B_s, ...) # if max(C) > 1.0: # the scale incremented in A_s was corresponding to the current block - # TODO finish this, just need to reconstruct from printed data for scale_row in range(M): for scale_col in range(K // BLOCK_SIZE): @@ -358,7 +357,6 @@ def test_scaled_mm_mx_reconstruct_scale_layout(): ) # TODO: it looks like blockwise scales are switched in cuBLAS? 
- # a_scales[scale_row][scale_col] = scale_val + 1 # incrementing scale of b looks like it's actually affecting scaling of a b_scales[scale_row][scale_col] = scale_val + 1 @@ -403,6 +401,116 @@ def test_scaled_mm_mx_reconstruct_scale_layout(): # break # break +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher", +) +def test_scaled_mm_mx_reconstruct_scale_b_layout(): + # brute force the expected layout format + # basic numerics with all scales 1.0 + # next: other scale values + + # M, K, N = 8192, 4096, 8192 + M, K, N = 128, 128, 128 + BLOCK_SIZE = 32 + b = torch.ones(M, K, device="cuda", dtype=torch.float32).to(torch.float8_e4m3fn).t() + + # 127 is 1.0 in e8m0 + scale_val = 127 + + print() + + # Probe torch._scaled_mm to deduce the actual layout used for the scale + # arguments. Specifically, here is what the code below would do if we had + # A and B as 4x4 matrices with MX block size 2. All matrices are shown in float32 + # format, not their actual storage format, to demonstrate the algorithm. + # + # A matrix variants - all-zeros, except a single one for each mx block in the first row + # + # A_0 = 1000 A_1 = 0010 + # 0000 0000 + # 0000 0000 + # 0000 0000 + # + # B matrix - set to all-ones + # + # B = 1111 + # 1111 + # 1111 + # 1111 + # + # B scale - starts as a matrix of all-ones + # + # B_s = 11 + # 11 + # 11 + # 11 + # + # for each row in rows of B: + # for each col in cols of B: + # initialize B to all-ones + # set B[row][col] = 2.0 + # for each A in [As]: + # C = torch._scaled_mm(A, B, A_s, B_s, ...) + # if max(C) > 1.0: + # the scale incremented in B_s was corresponding to the current block + + for scale_row in range(M): + for scale_col in range(K // BLOCK_SIZE): + + a_scales = torch.full( + (M, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8 + ) + b_scales = torch.full( + # (K // BLOCK_SIZE, N), scale_val, device="cuda", dtype=torch.uint8 + (N, K // BLOCK_SIZE), scale_val, device="cuda", dtype=torch.uint8 + ) + + # TODO: it looks like blockwise scales are switched in cuBLAS? + # incrementing scale of a looks like it's actually affecting scaling of b + a_scales[scale_row][scale_col] = scale_val + 1 + + # We test every BLOCK_SIZE to deduce which of the blocks is + # responsible for the scale value. Note that this isn't the most + # efficient way to test, but I'm optimizing for dev time here. 
+ for block_idx in range(K // BLOCK_SIZE): + + a = torch.zeros(M, K, device="cuda", dtype=torch.float32) + # set a single one inside the block + a[0][block_idx * BLOCK_SIZE] = 1 + a = a.to(torch.float8_e4m3fn) + + out = torch._scaled_mm( + a, + b, + a_scales, + b_scales, + None, + None, + torch.bfloat16, + False, + None, + None, + DataType.E8M0, + ) + + # print(scale_row, scale_col, block_idx) + # torch.set_printoptions(profile="full", linewidth=320) + # print(out) + # print(torch.max(out, keepdim=True)) + + max_val = torch.max(out).item() + if max_val > 1: + max_flat_index = torch.argmax(out).item() + max_row = max_flat_index // M + max_col = max_flat_index % M + assert max_row == 0 + assert max_val == 2.0 + print('scale_coords', scale_row, scale_col, 'block_idx', block_idx, 'max_coords', max_row, max_col, 'max_val', max_val) + + # break + # break @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( From 8ecba5f861df9eb7f0d86959f7606f6a3b417392 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 21:51:06 -0800 Subject: [PATCH 09/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 81 ++++++++++++++++++++- 1 file changed, 80 insertions(+), 1 deletion(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 87601b3a15..bee5454311 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -18,6 +18,7 @@ swap_linear_with_mx_inference_linear, swap_linear_with_mx_linear, ) +from torchao.prototype.mx_formats.mx_tensor import MXTensor from torchao.quantization.utils import compute_error from torchao.utils import ( TORCH_VERSION_AT_LEAST_2_4, @@ -244,13 +245,33 @@ def test_filter_fn(): assert type(m2[0]) == MXInferenceLinear assert type(m2[1]) == torch.nn.Linear +# copy-pasted from https://github.com/drisspg/transformer_nuggets/blob/12bf63d334900d57958f839f273f5bca78a8f4a1/transformer_nuggets/mx/to_blocked.py#L54C1-L62C76 +# and modified to return 128x4 instead of 32x16 +def _to_blocked_single(scales: torch.Tensor) -> torch.Tensor: + """Assume that we have a 128x4 block of scales in K Major order + + To see more information on the individual tile layout: + https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + assert scales.shape == (128, 4) + scales_tiled = scales.view(4, 32, 4) # view as 4 - (32, 4) tiles + return scales_tiled.transpose(0, 1).reshape(128, 4).contiguous() # Interleave tiles + +def test_to_blocked(): + scales = torch.arange(128 * 4).reshape(128, 4) / 4 + print('orig') + print(scales) + print('blocked') + print(_to_blocked_single(scales)) + # looks right! 
+ @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( not is_sm_at_least_100(), reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher", ) -def test_scaled_mm_mxfp8(): +def test_scaled_mm_mxfp8_scales_one(): # basic numerics with all scales 1.0 # next: other scale values @@ -289,6 +310,64 @@ def test_scaled_mm_mxfp8(): print(out) print(torch.max(out)) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif( + not is_sm_at_least_100(), + reason="blockwise torch._scaled_mm requires CUDA 10.0 or higher", +) +def test_scaled_mm_mxfp8_mxtensor(): + # baseline 1: fp32 + # experiment 1: emulated MX from MXTensor + # experiment 2: real MX gemm + + # results so far: + # * experiment 1 is very close to experiment 2 + # * experiments 1 and 2 are far from baseline (lol!) + + # M, K, N = 8192, 4096, 8192 + M, K, N = 128, 128, 128 + BLOCK_SIZE = 32 + a_fp32 = torch.randn(M, K, device="cuda", dtype=torch.float32) + b_fp32 = torch.randn(N, K, device="cuda", dtype=torch.float32).t().contiguous() + + a_mx = MXTensor.to_mx(a_fp32, torch.float8_e4m3fn, BLOCK_SIZE) + b_mx = MXTensor.to_mx(b_fp32, torch.float8_e4m3fn, BLOCK_SIZE).t() + a_s0 = a_mx._scale_e8m0.reshape(M, -1) + a_s1 = _to_blocked_single(a_s0) + b_s0 = b_mx._scale_e8m0.reshape(N, -1) + b_s1 = _to_blocked_single(b_s0) + + # ones_scale = torch.full((M, K // BLOCK_SIZE), 127, dtype=torch.uint8, device="cuda") + + out_ref = a_fp32 @ b_fp32 + print('baseline', out_ref) + + out_mx_emulated = a_mx @ b_mx + print('mx_emulated', out_mx_emulated) + + out_mx_real = torch._scaled_mm( + a_mx._data, + b_mx._data, + # a_scales is really b_scales, and vice versa. Probably switched in cuBLAS kernel? + _to_blocked_single(b_mx._scale_e8m0.reshape(N, -1)), + _to_blocked_single(a_mx._scale_e8m0.reshape(M, -1)), + None, + None, + torch.float32, + False, + None, + None, + DataType.E8M0, + ) + print('mx_real', out_mx_real) + + sqnr_baseline_to_emulated_mx = compute_error(out_ref, out_mx_emulated) + sqnr_baseline_to_real_mx = compute_error(out_ref, out_mx_real) + sqnr_emulated_mx_to_real_mx = compute_error(out_mx_emulated, out_mx_real) + print('sqnr baseline -> emulated_mx', sqnr_baseline_to_emulated_mx) + print('sqnr baseline -> real_mx', sqnr_baseline_to_real_mx) + print('sqnr emulated_mx -> real_mx', sqnr_emulated_mx_to_real_mx) + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( From c5df6d8b211518ed8a591860c2a3e44ac7b1ee89 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 21:55:48 -0800 Subject: [PATCH 10/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index bee5454311..50f4729f67 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -339,7 +339,7 @@ def test_scaled_mm_mxfp8_mxtensor(): # ones_scale = torch.full((M, K // BLOCK_SIZE), 127, dtype=torch.uint8, device="cuda") - out_ref = a_fp32 @ b_fp32 + out_ref = a_fp32 @ b_fp32.t() print('baseline', out_ref) out_mx_emulated = a_mx @ b_mx From 8daa0d08cde19d222e5356c29dc7fe3f6903c087 Mon Sep 17 00:00:00 2001 From: vasiliy Date: Mon, 27 Jan 2025 21:56:42 -0800 Subject: [PATCH 11/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff 
--git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 50f4729f67..5c1ae6bcd7 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -321,8 +321,8 @@ def test_scaled_mm_mxfp8_mxtensor(): # experiment 2: real MX gemm # results so far: - # * experiment 1 is very close to experiment 2 - # * experiments 1 and 2 are far from baseline (lol!) + # * baseline SQNR vs both experiments is ~27 + # * SQNR between experiment 1 and 2 is ~155 (near perfect match) # M, K, N = 8192, 4096, 8192 M, K, N = 128, 128, 128 From be7806c9c2fe04c011baf3c200bf8800553f98bf Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 28 Jan 2025 12:42:02 -0800 Subject: [PATCH 12/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 106 ++++++++++++-------- 1 file changed, 64 insertions(+), 42 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index 5c1ae6bcd7..c07ec42eb2 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -25,6 +25,9 @@ is_sm_at_least_89, is_sm_at_least_100, ) +from transformer_nuggets.mx.to_blocked import ( + to_blocked, +) torch.manual_seed(2) @@ -265,6 +268,16 @@ def test_to_blocked(): print(_to_blocked_single(scales)) # looks right! +def test_to_blocked_manual_v2(): + scales = torch.arange(128 * 4 * 2).reshape(128 * 2, 4) / 4 + torch.set_printoptions(profile="full", linewidth=280) + print('orig') + print(scales) + print('blocked') + print(to_blocked(scales)) + # looks right! + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.skipif( @@ -324,49 +337,58 @@ def test_scaled_mm_mxfp8_mxtensor(): # * baseline SQNR vs both experiments is ~27 # * SQNR between experiment 1 and 2 is ~155 (near perfect match) - # M, K, N = 8192, 4096, 8192 - M, K, N = 128, 128, 128 - BLOCK_SIZE = 32 - a_fp32 = torch.randn(M, K, device="cuda", dtype=torch.float32) - b_fp32 = torch.randn(N, K, device="cuda", dtype=torch.float32).t().contiguous() - - a_mx = MXTensor.to_mx(a_fp32, torch.float8_e4m3fn, BLOCK_SIZE) - b_mx = MXTensor.to_mx(b_fp32, torch.float8_e4m3fn, BLOCK_SIZE).t() - a_s0 = a_mx._scale_e8m0.reshape(M, -1) - a_s1 = _to_blocked_single(a_s0) - b_s0 = b_mx._scale_e8m0.reshape(N, -1) - b_s1 = _to_blocked_single(b_s0) - - # ones_scale = torch.full((M, K // BLOCK_SIZE), 127, dtype=torch.uint8, device="cuda") - - out_ref = a_fp32 @ b_fp32.t() - print('baseline', out_ref) - - out_mx_emulated = a_mx @ b_mx - print('mx_emulated', out_mx_emulated) - - out_mx_real = torch._scaled_mm( - a_mx._data, - b_mx._data, - # a_scales is really b_scales, and vice versa. Probably switched in cuBLAS kernel? 
- _to_blocked_single(b_mx._scale_e8m0.reshape(N, -1)), - _to_blocked_single(a_mx._scale_e8m0.reshape(M, -1)), - None, - None, - torch.float32, - False, - None, - None, - DataType.E8M0, + print() + shapes_to_try = ( + (128, 128, 128), + (128, 256, 512), + (256, 512, 128), + (512, 128, 256), + (4096, 4096, 4096), + (4096, 8192, 16384), + (8192, 16384, 4096), + (16384, 4096, 8192), ) - print('mx_real', out_mx_real) - - sqnr_baseline_to_emulated_mx = compute_error(out_ref, out_mx_emulated) - sqnr_baseline_to_real_mx = compute_error(out_ref, out_mx_real) - sqnr_emulated_mx_to_real_mx = compute_error(out_mx_emulated, out_mx_real) - print('sqnr baseline -> emulated_mx', sqnr_baseline_to_emulated_mx) - print('sqnr baseline -> real_mx', sqnr_baseline_to_real_mx) - print('sqnr emulated_mx -> real_mx', sqnr_emulated_mx_to_real_mx) + for M, K, N in shapes_to_try: + print('MKN', M, K, N) + BLOCK_SIZE = 32 + a_fp32 = torch.randn(M, K, device="cuda", dtype=torch.float32) + b_fp32 = torch.randn(N, K, device="cuda", dtype=torch.float32) + + a_mx = MXTensor.to_mx(a_fp32, torch.float8_e4m3fn, BLOCK_SIZE) + b_mx = MXTensor.to_mx(b_fp32, torch.float8_e4m3fn, BLOCK_SIZE).t() + a_s0 = a_mx._scale_e8m0.reshape(M, -1) + a_s1 = to_blocked(a_s0) + b_s0 = b_mx._scale_e8m0.reshape(N, -1) + b_s1 = to_blocked(b_s0) + + out_ref = a_fp32 @ b_fp32.t() + # print('baseline', out_ref) + + out_mx_emulated = a_mx @ b_mx + # print('mx_emulated', out_mx_emulated) + + out_mx_real = torch._scaled_mm( + a_mx._data, + b_mx._data, + # a_scales is really b_scales, and vice versa. Probably switched in cuBLAS kernel? + b_s1, + a_s1, + None, + None, + torch.float32, + False, + None, + None, + DataType.E8M0, + ) + # print('mx_real', out_mx_real) + + sqnr_baseline_to_emulated_mx = compute_error(out_ref, out_mx_emulated) + sqnr_baseline_to_real_mx = compute_error(out_ref, out_mx_real) + sqnr_emulated_mx_to_real_mx = compute_error(out_mx_emulated, out_mx_real) + print('sqnr baseline -> emulated_mx', sqnr_baseline_to_emulated_mx) + print('sqnr baseline -> real_mx', sqnr_baseline_to_real_mx) + print('sqnr emulated_mx -> real_mx', sqnr_emulated_mx_to_real_mx) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") From c8d768ea4899b9f241e68de5963cb03ac1b4013c Mon Sep 17 00:00:00 2001 From: vasiliy Date: Tue, 28 Jan 2025 15:18:47 -0800 Subject: [PATCH 13/13] Update [ghstack-poisoned] --- test/prototype/mx_formats/test_mx_linear.py | 8 ++-- torchao/prototype/mx_formats/mx_ops.py | 43 +++++++++++++++++---- 2 files changed, 40 insertions(+), 11 deletions(-) diff --git a/test/prototype/mx_formats/test_mx_linear.py b/test/prototype/mx_formats/test_mx_linear.py index c07ec42eb2..a03e4c1651 100644 --- a/test/prototype/mx_formats/test_mx_linear.py +++ b/test/prototype/mx_formats/test_mx_linear.py @@ -56,19 +56,19 @@ def run_around_tests(): @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES) @pytest.mark.parametrize("bias", [True, False]) -@pytest.mark.parametrize("input_shape", [(4, 8), (1, 4, 8), (1, 1, 4, 8)]) +@pytest.mark.parametrize("input_shape", [(128, 256), (1, 128, 256), (1, 1, 128, 256)]) def test_linear_eager(elem_dtype, bias, input_shape): """ Smoke test for training linear module with mx weight """ grad_shape = list(input_shape) - grad_shape[-1] = 6 + grad_shape[-1] = 128 m = nn.Sequential( - nn.Linear(8, 6, bias=bias, device="cuda"), + nn.Linear(256, 128, bias=bias, device="cuda"), ) m_mx = copy.deepcopy(m) - block_size = 2 + 
block_size = 32 swap_linear_with_mx_linear(m_mx, elem_dtype, block_size) x_ref = torch.randn(*input_shape, device="cuda").requires_grad_() diff --git a/torchao/prototype/mx_formats/mx_ops.py b/torchao/prototype/mx_formats/mx_ops.py index 57fb0d54b4..647a43ae58 100644 --- a/torchao/prototype/mx_formats/mx_ops.py +++ b/torchao/prototype/mx_formats/mx_ops.py @@ -27,6 +27,9 @@ MXTensor, tensor_size_hp_to_fp4x2, ) +from transformer_nuggets.mx.to_blocked import ( + to_blocked, +) aten = torch.ops.aten @@ -63,13 +66,39 @@ def mx_mm(aten_op, args, kwargs=None): a = args[0] b = args[1] assert isinstance(a, MXTensor) and isinstance(b, MXTensor) - a_hp = a.to_dtype(a._orig_dtype) - b_hp = b.to_dtype(b._orig_dtype) - # assert memory layout we expect to be required in hardware - assert a_hp.is_contiguous() - assert b_hp.t().is_contiguous() - res = aten_op(a_hp, b_hp) - return res + + if a._data.dtype is torch.float8_e4m3fn and b._data.dtype is torch.float8_e4m3fn: + + assert a._block_size == 32 and b._block_size == 32 + + a_s0 = a._scale_e8m0.reshape(a._data.shape[0], -1) + a_s1 = to_blocked(a_s0) + b_s0 = b._scale_e8m0.reshape(b._data.shape[1], -1) + b_s1 = to_blocked(b_s0) + out_mx_real = torch._scaled_mm( + a._data, + b._data, + # a_scales is really b_scales, and vice versa. Probably switched in cuBLAS kernel? + b_s1, + a_s1, + None, + None, + a._orig_dtype, + False, + None, + None, + 1, # DataType.E8M0 + ) + return out_mx_real + + else: + a_hp = a.to_dtype(a._orig_dtype) + b_hp = b.to_dtype(b._orig_dtype) + # assert memory layout we expect to be required in hardware + assert a_hp.is_contiguous() + assert b_hp.t().is_contiguous() + res = aten_op(a_hp, b_hp) + return res @implements([aten.t.default])
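
A minimal sketch of the E8M0 scale decoding assumed by the tests and benchmarks above ("127 is 1.0 in e8m0", and bumping a block scale from 127 to 128 doubling the matching output block): E8M0 stores only a biased power-of-two exponent with bias 127, so 127 decodes to 1.0 and 128 to 2.0. The helper below is a hypothetical illustration, not part of the patches; it assumes plain uint8 storage and ignores any special encodings.

import torch

def e8m0_to_float32(scale_e8m0_u8: torch.Tensor) -> torch.Tensor:
    # E8M0 has no sign or mantissa bits, only a biased exponent (bias 127),
    # so every representable scale is a power of two: 2 ** (stored - 127).
    return torch.exp2(scale_e8m0_u8.to(torch.float32) - 127.0)

print(e8m0_to_float32(torch.tensor([126, 127, 128], dtype=torch.uint8)))
# tensor([0.5000, 1.0000, 2.0000]) -- 127 decodes to 1.0, 128 to 2.0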