@@ -5,6 +5,9 @@
 import triton
 from torch._inductor.utils import do_bench_using_profiling
 
+from torchao.prototype.mx_formats.custom_cast import (
+    triton_to_mxfp8_dim1,
+)
 from torchao.prototype.mx_formats.mx_tensor import to_mx
 
 torch.manual_seed(0)
@@ -49,6 +52,12 @@ def to_mx_dim0_reference(x_hp, block_size):
     return data_d0, scale_d0
 
 
+def to_mx_dim1_reference(x_hp, block_size):
+    x_hp = x_hp.t().contiguous()
+    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
+    return data_d1.t(), scale_d1
+
+
 def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
     """Thin wrapper around do_bench_using_profiling"""
     no_args = lambda: func(*args, **kwargs)
@@ -67,7 +76,7 @@ def run(
     print(f"torch version: {torch.__version__}")
     print(f"triton version: {triton.__version__}")
     print(f"mode: {mode}")
-    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx")
+    assert mode in ("dim0", "dim1", "dim0_dim1", "dim0_mx", "dim1_mx", "dim1_mx_triton")
 
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") * 1000
 
@@ -144,6 +153,41 @@ def run(
         bytes_w = (y_d0.numel() + s_d0.numel()) * bytes_per_el_fp8
         bps = (bytes_r + bytes_w) / (time_us / 1e6)
 
+    elif mode == "dim1_mx":
+        to_mx_dim1_reference_c = torch.compile(to_mx_dim1_reference)
+        y_d1, s_d1 = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = to_mx_dim1_reference_c(x, BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: to_mx_dim1_reference_c(x, BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.uint8
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
+    elif mode == "dim1_mx_triton":
+        y_d1, s_d1 = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+
+        for _ in range(2):
+            __ = triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE)
+        time_us = benchmark_cuda_function_in_microseconds(
+            lambda x, b: triton_to_mxfp8_dim1(x, inner_block_size=BLOCK_SIZE),
+            x,
+            BLOCK_SIZE,
+        )
+
+        assert y_d1.dtype == torch.float8_e4m3fn
+        assert s_d1.dtype == torch.float8_e8m0fnu
+        bytes_r = x.numel() * bytes_per_el_bf16
+        bytes_w = (y_d1.numel() + s_d1.numel()) * bytes_per_el_fp8
+        bps = (bytes_r + bytes_w) / (time_us / 1e6)
+
     else:
         raise AssertionError(f"unknown mode {mode}")
 
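For context, a minimal standalone sketch of what the new `dim1_mx` path computes, assuming only the `to_mx` import this file already uses. `BLOCK_SIZE = 32` and the toy tensor shape are assumptions for illustration, not values taken from this diff.

```python
# Sketch only, not the benchmark itself: cast a bf16 tensor to MX fp8 along
# dim1. to_mx scales along the last (contiguous) dim, so the helper in the
# diff transposes first, casts, then transposes the fp8 data back to (M, K).
import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx

BLOCK_SIZE = 32  # assumed MX block size, not taken from this diff

def to_mx_dim1_reference(x_hp, block_size):
    x_hp = x_hp.t().contiguous()
    scale_d1, data_d1 = to_mx(x_hp, torch.float8_e4m3fn, block_size)
    return data_d1.t(), scale_d1

x = torch.randn(256, 256, dtype=torch.bfloat16, device="cuda") * 1000
y_d1, s_d1 = to_mx_dim1_reference(x, BLOCK_SIZE)
assert y_d1.dtype == torch.float8_e4m3fn  # same dtype check as the dim1_mx hunk
assert s_d1.dtype == torch.uint8          # reference path returns raw e8m0 bits
```

Note the scale-dtype difference between the two new modes: the compiled reference path asserts `torch.uint8` scales (raw e8m0 bits from `to_mx`), while `triton_to_mxfp8_dim1` returns `torch.float8_e8m0fnu`; both modes count the same bytes read and written when computing bandwidth.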