import triton
from torch._inductor.utils import do_bench_using_profiling

+from torchao.prototype.mx_formats.custom_cast import (
+    to_mxfp8_dim1,
+)
from torchao.prototype.mx_formats.mx_tensor import to_mx

torch.manual_seed(0)
@@ -49,6 +52,12 @@ def to_mx_dim0_reference(x_hp, block_size):
    return data_d0, scale_d0


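+# Reference for the dim1 cast: transpose so the target dim is contiguous,
+# reuse the dim0 `to_mx` cast, then transpose the data back into the
+# original layout.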
+def to_mx_dim1_reference(x_hp, block_size):
+    x_hp = x_hp.t().contiguous()
+    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    return data_d1.t(), scale_d1
+
+
def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
@@ -67,7 +76,7 @@ def run(
6776 print (f"torch version: { torch .__version__ } " )
6877 print (f"triton version: { triton .__version__ } " )
6978 print (f"mode: { mode } " )
70- assert mode in ("dim0" , "dim1" , "dim0_dim1" , "dim0_mx" )
79+ assert mode in ("dim0" , "dim1" , "dim0_dim1" , "dim0_mx" , "dim1_mx" , "dim1_mx_triton" )
7180
7281 x = torch .randn (M , K , dtype = torch .bfloat16 , device = "cuda" ) * 1000
7382
@@ -144,6 +153,41 @@ def run(
        bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
        bps = (bytes_r + bytes_w) / (time_us / 1e6)

156+ elif mode == "dim1_mx" :
157+ to_mx_dim1_reference_c = torch .compile (to_mx_dim1_reference )
158+ y_d1 , s_d1 = to_mx_dim1_reference_c (x , BLOCK_SIZE )
159+
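+        # warm up the compiled function before timing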
+        for _ in range(2):
+            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim1_reference_c(x, b),
+            x,
+            BLOCK_SIZE,
+        )
+
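+        # traffic model: one bf16 read of x, one fp8-width write of data + scales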
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.uint8
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
174+ elif mode == "dim1_mx_triton" :
175+ y_d1 , s_d1 = to_mxfp8_dim1 (x , inner_block_size = BLOCK_SIZE )
176+
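+        # warm up the triton kernel before timing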
+        for _ in range(2):
+            __ = to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mxfp8_dim1(x, inner_block_size=b),
+            x,
+            BLOCK_SIZE,
+        )
+
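+        # e8m0 scales are 1 byte each, so the fp8-width traffic model still holds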
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
    else:
        raise AssertionError(f"unknown mode {mode}")

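For a quick sanity check of the two new modes, run can be called directly; the snippet below is a hypothetical sketch (the module path is assumed, and run's full signature is outside this diff; M, K, and mode are the parameters visible above):

    # hypothetical usage sketch, not part of the commit
    from benchmarks.mx_formats.cast_bench import run  # module path assumed
    run(M=4096, K=4096, mode="dim1_mx")          # compiled reference cast
    run(M=4096, K=4096, mode="dim1_mx_triton")   # handwritten triton cast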