    get_name_to_shapes_iter,
)

+from torchao.ops import mx_fp4_bf16
+from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.testing.training.roofline_utils import get_specs

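For orientation before the next hunk: a minimal sketch of how the two new imports combine, mirroring the calls the benchmark makes below. The shapes are arbitrary, the all-ones e8m0 scales follow the benchmark's synthetic setup, and a CUDA GPU with mxfp4 kernel support is assumed.

import torch
from torchao.ops import mx_fp4_bf16
from torchao.prototype.mx_formats.mx_tensor import to_mx

M, K, N = 128, 256, 512
A_hp = torch.randn(M, K, device="cuda")
B_hp_t = torch.randn(N, K, device="cuda")
# to_mx returns (e8m0 scales, packed fp4 data); block size 32 along the reduction dim
_, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
_, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
scale_a = torch.ones(M, K // 32, device="cuda", dtype=torch.float8_e8m0fnu)
scale_b = torch.ones(N, K // 32, device="cuda", dtype=torch.float8_e8m0fnu)
out = mx_fp4_bf16(A, Bt.contiguous().T, scale_a, scale_b)  # bf16 output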
@@ -62,29 +64,38 @@ def run(
):
    device = "cuda"
    # TODO(future PR): this is ugly
-    assert recipe in ("tensorwise", "rowwise", "mxfp8_cublas"), "unsupported"
+    assert recipe in (
+        "tensorwise",
+        "rowwise",
+        "mxfp8_cublas",
+        "mxfp4_cutlass",
+        "nvfp4",
+    ), "unsupported"
+    use_fp4 = recipe in ("mxfp4_cutlass", "nvfp4")

    specs = get_specs()
    bf16_peak_tops = specs["bf16_peak_tops"]
    fp8_peak_tops = specs["fp8_peak_tops"]
+    fp4_peak_tops = specs["fp4_peak_tops"]
    print(f"gpu_name: {torch.cuda.get_device_name(0)}")
-    print(f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}")
-
+    print(
+        f"peak tops: bf16 {bf16_peak_tops:.2e}, fp8 {fp8_peak_tops:.2e}, fp4 {fp4_peak_tops:.2e}"
+    )
    headers = (
        "fast_accum",
        "name",
        "M",
        "K",
        "N",
-        "ref_time_s",
-        "fp8_time_s",
+        "time_s",
+        "speedup",
        "fp8_speedup",
    )
    results = []

    dtype = torch.bfloat16
    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
-    fast_accum_vals = [True, False]
+    fast_accum_vals = [False] if use_fp4 else [True, False]

    for idx, (fast_accum, (name, (M, K, N))) in enumerate(
        itertools.product(fast_accum_vals, name_to_shapes)
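A note on the fast_accum gating above: use_fast_accum is a knob on torch._scaled_mm's fp8 path, and the fp4 matmul wrappers defined in the next hunk do not take it, so the fp4 recipes are benchmarked with fast_accum=False only, while the fp8 recipes still sweep both values.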
@@ -107,38 +118,82 @@ def run(
        del A

-        # raw float8 matmul (upper bound for what we can achieve in eager mode)
-        # TODO(future): add e5m2
-        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-        A = torch.zeros(M, K, device=device, dtype=d1)
-        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        A_hp = torch.randn(M, K, device=device)
+        B_hp_t = torch.randn(N, K, device=device)
+
+        if recipe == "mxfp4_cutlass":
+            _, A = to_mx(A_hp, torch.float4_e2m1fn_x2, 32)
+            _, Bt = to_mx(B_hp_t, torch.float4_e2m1fn_x2, 32)
+            B = Bt.contiguous().T
+            peak_tops = fp4_peak_tops
+        elif recipe == "nvfp4":
+            from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize
+
+            A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)
+            B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)
+            A = A_data.view(torch.float4_e2m1fn_x2)
+            B = B_data.view(torch.float4_e2m1fn_x2).T
+            peak_tops = fp4_peak_tops
+        else:
+            # raw float8 matmul (upper bound for what we can achieve in eager mode)
+            # TODO(future): add e5m2
+            d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
+            A = A_hp.to(d1)
+            B = B_hp_t.to(d2).contiguous().T
+            peak_tops = fp8_peak_tops
+
        if recipe == "tensorwise":
            scale_a = torch.tensor([1.0], device=device)
            scale_b = torch.tensor([1.0], device=device)
        elif recipe == "rowwise":
            scale_a = torch.ones(M, 1, device=device)
            scale_b = torch.ones(1, N, device=device)
-        elif recipe == "mxfp8_cublas":
+        elif recipe in ("mxfp8_cublas", "mxfp4_cutlass"):
            scale_a = torch.ones(M, K // 32, device=device, dtype=torch.float8_e8m0fnu)
            scale_b = torch.ones(N, K // 32, device=device, dtype=torch.float8_e8m0fnu)
+        elif recipe == "nvfp4":
+            # Use the blockwise scales from nvfp4_quantize
+            scale_a = A_scales.view(torch.float8_e4m3fn)
+            scale_b = B_scales.view(torch.float8_e4m3fn)
        else:
            assert False, f"unknown recipe {recipe}"

-        def do_matmul(A, B):
+        def do_matmul_fp8(A, B):
            nonlocal scale_a
            nonlocal scale_b
            return torch._scaled_mm(
                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
            )

-        fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
-            tops, fp8_peak_tops, use_gpu_kernel_time, do_matmul, A, B
+        def do_matmul_mxfp4(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return mx_fp4_bf16(A, B, scale_a, scale_b)
+
+        def do_matmul_nvfp4(A, B):
+            nonlocal scale_a
+            nonlocal scale_b
+            return torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=dtype)
+
+        if recipe == "mxfp4_cutlass":
+            do_matmul = do_matmul_mxfp4
+        elif recipe == "nvfp4":
+            do_matmul = do_matmul_nvfp4
+        else:
+            do_matmul = do_matmul_fp8
+
+        time_sec, tops_sec, pct_top_peak = do_benchmarks(
+            tops, peak_tops, use_gpu_kernel_time, do_matmul, A, B
        )
        print(
-            f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+            f"time_sec {time_sec:.2E}, tops/sec {tops_sec:.2E}, pct_peak {pct_top_peak:.3f}"
        )

-        del A, B, scale_a, scale_b
+        del A, B
+        if scale_a is not None:
+            del scale_a
+        if scale_b is not None:
+            del scale_b

        results.append(
            [
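The nvfp4 path above, pulled out into a self-contained sketch that mirrors the diff's calls one-to-one; the shapes are arbitrary, and a PyTorch build whose torch._scaled_mm accepts packed fp4 inputs with e4m3 blockwise scales is assumed.

import torch
from torchao.prototype.mx_formats.nvfp4_tensor import nvfp4_quantize

M, K, N = 128, 256, 512
A_hp = torch.randn(M, K, device="cuda")
B_hp_t = torch.randn(N, K, device="cuda")  # B is created (N, K) and transposed below
A_scales, A_data = nvfp4_quantize(A_hp, block_size=16)  # per-16-element scales + packed fp4 data
B_scales, B_data = nvfp4_quantize(B_hp_t, block_size=16)
A = A_data.view(torch.float4_e2m1fn_x2)  # reinterpret packed payload as two fp4 values per byte
B = B_data.view(torch.float4_e2m1fn_x2).T
out = torch._scaled_mm(
    A,
    B,
    A_scales.view(torch.float8_e4m3fn),
    B_scales.view(torch.float8_e4m3fn),
    out_dtype=torch.bfloat16,
)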
@@ -148,8 +203,8 @@ def do_matmul(A, B):
                K,
                N,
                ref_time_sec,
-                fp8_time_sec,
-                ref_time_sec / fp8_time_sec,
+                time_sec,
+                ref_time_sec / time_sec,
            ]
        )
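Assuming this is torchao's benchmarks/float8/bench_matmul.py and that it keeps its fire-based CLI (both are assumptions, not confirmed by the diff), the new recipes would be exercised along the lines of:

python benchmarks/float8/bench_matmul.py --recipe mxfp4_cutlass
python benchmarks/float8/bench_matmul.py --recipe nvfp4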