# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import itertools
+from enum import IntEnum
from typing import Optional

import fire
@@ -26,14 +27,44 @@
h100_peak_flops_fp16_tc = 989e12
h100_peak_tops_float8_tc = 1979e12

-dtype_to_peak_tops = {
+# HGX B200 specs: https://www.nvidia.com/en-us/data-center/hgx/
+# note: the numbers from ^ are divided by 2 to undo the effects of sparsity
+# TODO(this PR): I'm achieving 5% of peak TFLOPS with bf16 and float8,
+# something seems funky
+b200_peak_flops_float32 = 600e12
+b200_peak_flops_fp16_tc = 18e15
+b200_peak_tops_float8_tc = 36e15
+b200_peak_tops_float4_tc = 72e15
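+# (assumption) the HGX page lists aggregate throughput for an 8-GPU board, so the
+# values above may be ~8x a single GPU's peak, which could explain the low
+# observed pct_peak noted in the TODO above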
+
+dtype_to_peak_tops_h100 = {
    torch.float32: h100_peak_flops_float32,
    torch.float16: h100_peak_flops_fp16_tc,
    torch.bfloat16: h100_peak_flops_fp16_tc,
    torch.float8_e4m3fn: h100_peak_tops_float8_tc,
    torch.float8_e5m2: h100_peak_tops_float8_tc,
}

+dtype_to_peak_tops_b200 = {
+    torch.float32: b200_peak_flops_float32,
+    torch.float16: b200_peak_flops_fp16_tc,
+    torch.bfloat16: b200_peak_flops_fp16_tc,
+    torch.float8_e4m3fn: b200_peak_tops_float8_tc,
+    torch.float8_e5m2: b200_peak_tops_float8_tc,
+    # TODO float4
+}
+
+# TODO(this PR): switch automatically by detected hardware type
+# TODO(this PR): fp4 is currently using fp8's peak tops below, fix it
+dtype_to_peak_tops = dtype_to_peak_tops_b200
+
+
+# not for land, matching https://www.internalfb.com/phabricator/paste/view/P1717686991
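+# These codes are passed to the prototype a_dtype/b_dtype/scale_dtype arguments of
+# torch._scaled_mm in do_matmul below; E8M0 is a power-of-two scale format
+# (8 exponent bits, no mantissa) used for the blockwise scales.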
+class DataType(IntEnum):
+    DEFAULT = 0
+    E8M0 = 1
+    FP4 = 2
+    UFP8 = 3
+

def benchmark_fn_in_sec(f, *args, **kwargs):
    # Manual warmup
@@ -75,6 +106,7 @@ def run(
    N: Optional[int] = None,
    use_gpu_kernel_time: bool = False,
    scaling_granularity: str = "tensorwise",
+    blockwise_dtype: Optional[str] = None,
):
    device = "cuda"
@@ -85,15 +117,17 @@ def run(
        "K",
        "N",
        "ref_time_s",
-        "fp8_time_s",
-        "fp8_speedup",
+        "lowp_time_s",
+        "lowp_speedup",
    )
    results = []

    dtype = torch.bfloat16
    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
    fast_accum_vals = [True, False]
-    scaling_granularity = ScalingGranularity(scaling_granularity)
+    # Note: "blockwise" is not in the enum because blockwise scaling is still a prototype
+    if scaling_granularity != "blockwise":
+        scaling_granularity = ScalingGranularity(scaling_granularity)
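+    # for "blockwise", scaling_granularity stays a string and is handled by the
+    # string comparisons in the branches below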

    for idx, (fast_accum, (name, (M, K, N))) in enumerate(
        itertools.product(fast_accum_vals, name_to_shapes)
@@ -119,28 +153,97 @@ def run(
        # raw float8 matmul (upper bound for what we can achieve in eager mode)
        # TODO(future): add e5m2
        d1, d2, d3 = torch.float8_e4m3fn, torch.float8_e4m3fn, dtype
-        A = torch.zeros(M, K, device=device, dtype=d1)
-        B = torch.zeros(K, N, device=device, dtype=d2).t().contiguous().t()
+        A = torch.randn(M, K, device=device).to(d1)
+        B = torch.randn(K, N, device=device).to(d2).t().contiguous().t()
        if scaling_granularity == ScalingGranularity.TENSORWISE:
            scale_a = torch.tensor([1.0], device=device)
            scale_b = torch.tensor([1.0], device=device)
-        else:
-            assert scaling_granularity == ScalingGranularity.AXISWISE, "unsupported"
+        elif scaling_granularity == ScalingGranularity.AXISWISE:
            scale_a = torch.ones(M, 1, device=device)
            scale_b = torch.ones(1, N, device=device)
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
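+            # A and B are filled with random uint8 bit patterns and reinterpreted as
+            # float8_e4m3fn; scale_a/scale_b hold one uint8 scale per 32-element
+            # block along K and are interpreted as E8M0 in the _scaled_mm call below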
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 32
+            A = torch.randint(128, (M, K), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
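+            # (assumption) two fp4 values are packed per byte, so A and B carry
+            # K // 2 bytes along the contraction dim; since there is no fp4 dtype
+            # here, the packed bytes are viewed as float8_e4m3fn and tagged as FP4
+            # via a_dtype/b_dtype in the _scaled_mm call below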
+            # TODO(this PR): also block size 16
+            BLOCK_SIZE = 16
+            A = torch.randint(128, (M, K // 2), device=device, dtype=torch.uint8).view(
+                torch.float8_e4m3fn
+            )
+            B = (
+                torch.randint(128, (N, K // 2), device=device, dtype=torch.uint8)
+                .view(torch.float8_e4m3fn)
+                .t()
+            )
+            scale_a = torch.randint(
+                128, (M, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            )
+            scale_b = torch.randint(
+                128, (N, K // BLOCK_SIZE), dtype=torch.uint8, device="cuda"
+            ).t()
+        else:
+            raise AssertionError(f"unsupported granularity {scaling_granularity}")

        def do_matmul(A, B):
            nonlocal scale_a
            nonlocal scale_b
-            return torch._scaled_mm(
-                A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
-            )
+
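+            # note: the a_dtype/b_dtype/scale_dtype keyword arguments used below are
+            # not part of the stock torch._scaled_mm signature; they presumably come
+            # from the patched build referenced in the paste above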
+            if scaling_granularity == "blockwise" and blockwise_dtype == "float8_e4m3":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=None,  # inferred from A
+                    b_dtype=None,  # inferred from B
+                    scale_dtype=DataType.E8M0,
+                )
+            elif scaling_granularity == "blockwise" and blockwise_dtype == "float4":
+                return torch._scaled_mm(
+                    A,
+                    B,
+                    scale_a,
+                    scale_b,
+                    bias=None,
+                    scale_result=None,
+                    out_dtype=d3,
+                    use_fast_accum=fast_accum,
+                    a_dtype=DataType.FP4,
+                    b_dtype=DataType.FP4,
+                    scale_dtype=DataType.E8M0,
+                )
+
+            else:
+                return torch._scaled_mm(
+                    A, B, scale_a, scale_b, out_dtype=d3, use_fast_accum=fast_accum
+                )
+
+        # test
+        # res = do_matmul(A, B)
        fp8_time_sec, fp8_tops_sec, fp8_pct_top_peak = do_benchmarks(
            tops, dtype_to_peak_tops[d1], use_gpu_kernel_time, do_matmul, A, B
        )
        print(
-            f"fp8 time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
+            f"lowp time_sec {fp8_time_sec:.2E}, tops/sec {fp8_tops_sec:.2E}, pct_peak {fp8_pct_top_peak:.3f}"
        )

        del A, B, scale_a, scale_b