@@ -14,8 +14,10 @@
 
 import torch
 import torch.utils.benchmark as benchmark
+from float8_experimental.float8_linear import TensorScalingType
 from float8_experimental.float8_linear_utils import (
     get_float8_linear,
+    linear_requires_sync,
     LinearType,
     sync_float8_amax_and_scale_history,
 )
@@ -68,6 +70,7 @@ class Experiment:
     compiled: bool
     use_fast_accum: bool
     linear_type: str
+    scaling_repr: str
 
     # 3 Times since we are calculating forward backward
     @property
@@ -96,10 +99,17 @@ def main(
     fast_accum_filter: Optional[bool] = None,
     shape_name_filter: Optional[str] = None,
     linear_type_filter: Optional[str] = None,
+    scaling_type_x: str = "delayed",
+    scaling_type_w: str = "delayed",
+    scaling_type_dL_dY: str = "delayed",
 ):
     device = "cuda"
     print(f"Compile is set to | {compile}")
 
+    scaling_type_x = TensorScalingType(scaling_type_x)
+    scaling_type_w = TensorScalingType(scaling_type_w)
+    scaling_type_dL_dY = TensorScalingType(scaling_type_dL_dY)
+
     # LLaMa 2 70B single-node weight shapes
     # assumes fused attn.wqkv and ffn.w13
     name_to_shapes_70b = {
@@ -134,9 +144,24 @@ def main(
             LinearType.DELAYED if linear_type == "delayed" else LinearType.DYNAMIC
         )
 
-        linear_float8 = get_float8_linear(
-            linear_type_enum, copy.deepcopy(linear_ref), emulate=False
-        )
+        if linear_type == "delayed":
+            linear_float8 = get_float8_linear(
+                linear_type_enum,
+                copy.deepcopy(linear_ref),
+                emulate=False,
+                scaling_type_x=scaling_type_x,
+                scaling_type_w=scaling_type_w,
+                scaling_type_dL_dY=scaling_type_dL_dY,
+            )
+            scaling_repr = linear_float8.scaling_repr()
+        else:
+            linear_float8 = get_float8_linear(
+                linear_type_enum,
+                copy.deepcopy(linear_ref),
+                emulate=False,
+            )
+            scaling_repr = None
+
         if fast_accum:
             linear_float8.forward_config = ScaledMMConfig(False, True, False)
         else:
@@ -150,7 +175,10 @@ def main(
         if linear_type_enum == LinearType.DELAYED:
 
             def float8_forw_backward():
-                sync_float8_amax_and_scale_history(linear_float8)
+                if linear_requires_sync(
+                    linear_type_enum, scaling_type_x, scaling_type_w, scaling_type_dL_dY
+                ):
+                    sync_float8_amax_and_scale_history(linear_float8)
                 linear_float8(input_tensor).sum().backward()
 
         else:
@@ -197,6 +225,7 @@ def wrapper(*args, **kwargs):
             compile,
             use_fast_accum=fast_accum,
             linear_type=linear_type,
+            scaling_repr=scaling_repr,
         )
         print(experiment)
         print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
@@ -209,6 +238,7 @@ def wrapper(*args, **kwargs):
         "K",
         "N",
         "linear_type",
+        "scaling_repr",
         "ref_dtype",
         "compiled",
         "use_fast_accum",
@@ -228,6 +258,7 @@ def wrapper(*args, **kwargs):
                 experiment.shape[1],
                 experiment.shape[2],
                 experiment.linear_type,
+                experiment.scaling_repr,
                 experiment.dtype,
                 experiment.compiled,
                 experiment.use_fast_accum,
@@ -257,6 +288,7 @@ def wrapper(*args, **kwargs):
            "name",
            "shape",
            "linear_type",
+           "scaling_repr",
            "compiled",
            "use_fast_accum",
            "ref_time_sec",
@@ -280,15 +312,26 @@ def invoke_main() -> None:
     parser.add_argument("--fast_accum_filter", type=bool, required=False)
     parser.add_argument("--shape_name_filter", type=str, required=False)
     parser.add_argument("--linear_type_filter", type=str, required=False)
+    parser.add_argument("--scaling_type_x", type=str, required=False)
+    parser.add_argument("--scaling_type_w", type=str, required=False)
+    parser.add_argument("--scaling_type_dL_dY", type=str, required=False)
     args = parser.parse_args()
     output_path = Path(args.output_path) if args.output_path is not None else None
+    kwargs = {}
+    if args.scaling_type_x is not None:
+        kwargs["scaling_type_x"] = args.scaling_type_x
+    if args.scaling_type_w is not None:
+        kwargs["scaling_type_w"] = args.scaling_type_w
+    if args.scaling_type_dL_dY is not None:
+        kwargs["scaling_type_dL_dY"] = args.scaling_type_dL_dY
     main(
         output_path,
         args.compile,
         args.n_limit,
         args.fast_accum_filter,
         args.shape_name_filter,
         args.linear_type_filter,
+        **kwargs,
     )
 
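For context on the new sync guard above: the benchmark now calls sync_float8_amax_and_scale_history() only when linear_requires_sync(...) returns True for the chosen scaling types. Below is a minimal sketch of the kind of check this presumably performs; it is an illustration of the gating logic, not the library's actual implementation, and it assumes TensorScalingType is an enum with DELAYED/DYNAMIC members matching the "delayed" string used in the diff.

# Illustrative sketch only -- not the real float8_experimental implementation.
# Assumption: only delayed scaling keeps an amax history that needs syncing.
from float8_experimental.float8_linear import TensorScalingType
from float8_experimental.float8_linear_utils import LinearType


def linear_requires_sync_sketch(
    linear_type: LinearType,
    scaling_type_x: TensorScalingType,
    scaling_type_w: TensorScalingType,
    scaling_type_dL_dY: TensorScalingType,
) -> bool:
    # Dynamic linears compute scales on the fly, so there is nothing to sync.
    if linear_type != LinearType.DELAYED:
        return False
    # A delayed linear needs the amax/scale history sync only if at least one
    # of x / w / dL_dY actually uses delayed scaling.
    return any(
        scaling == TensorScalingType.DELAYED
        for scaling in (scaling_type_x, scaling_type_w, scaling_type_dL_dY)
    )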
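Usage note: the three new flags are optional. Only the flags that are actually passed are forwarded to main() through kwargs, so any omitted flag falls back to the "delayed" default in the main() signature. A hypothetical invocation (the script path and the "dynamic" value are assumptions, not shown in this diff) would look like: python benchmarks/bench_linear_float8.py --scaling_type_x dynamic --scaling_type_w delayed --scaling_type_dL_dY dynamic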