import re
import unittest
import warnings
+ from itertools import product

import pytest

is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)


+ def filtered_parametrize(param_list, filter_func=None):
+     """
+     A decorator that works like pytest.mark.parametrize but filters out
+     unwanted parameter combinations.
+
+     :param param_list: A list of tuples, each containing (arg_name, [arg_values])
+     :param filter_func: A function that takes a dictionary of parameter names and values,
+         and returns True for valid combinations, False otherwise
+     """
+
+     def decorator(func):
+         arg_names = [param[0] for param in param_list]
+         arg_values = [param[1] for param in param_list]
+
+         all_combinations = product(*arg_values)
+         if filter_func:
+             valid_combinations = [
+                 combo
+                 for combo in all_combinations
+                 if filter_func(dict(zip(arg_names, combo)))
+             ]
+         else:
+             valid_combinations = list(all_combinations)
+
+         return pytest.mark.parametrize(
+             argnames=arg_names, argvalues=valid_combinations
+         )(func)
+
+     return decorator
+
+
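As a usage sketch (hypothetical parameter names and test, for illustration only and not part of this diff), the decorator expands to a single pytest.mark.parametrize call that only receives the combinations accepted by filter_func:

# Hypothetical example: of the 9 (a, b) pairs, only the 4 with an even sum
# are passed to pytest.mark.parametrize and become test cases.
@filtered_parametrize(
    [
        ("a", [1, 2, 3]),
        ("b", [4, 5, 6]),
    ],
    filter_func=lambda params: (params["a"] + params["b"]) % 2 == 0,
)
def test_even_sum(a, b):
    assert (a + b) % 2 == 0
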
def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
    assert torch.all(a._scale == b._scale).item(), "scales are not identical"
    assert torch.all(a._data == b._data).item(), "data is not identical"
@@ -230,17 +262,35 @@ def _test_linear_impl(
        # verify initialization flags got updated
        assert m_fp8.is_amax_initialized, "Amax was not properly initialized"

-     @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-     @pytest.mark.parametrize(
-         "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-     )
-     @pytest.mark.parametrize(
-         "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-     )
-     @pytest.mark.parametrize(
-         "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+     @staticmethod
+     def is_valid_combination(params):
+         if not params["emulate"]:
+             if not torch.cuda.is_available():
+                 return False
+             if torch.cuda.get_device_capability() < (9, 0):
+                 return False
+
+         if params["linear_type"] == LinearType.DYNAMIC:
+             return all(
+                 params[key] == TensorScalingType.DYNAMIC
+                 for key in ["scaling_type_x", "scaling_type_w", "scaling_type_dL_dY"]
+             )
+
+         return True
+
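A hedged spot-check of the filter above (the enclosing test class name, TestFloat8Linear here, is assumed for illustration and is outside this hunk; the enums are the ones already imported by this test file):

# With emulate=True, a DYNAMIC linear_type is kept only when every scaling
# type is DYNAMIC, so this combination survives filtering:
params = {
    "emulate": True,
    "linear_type": LinearType.DYNAMIC,
    "scaling_type_x": TensorScalingType.DYNAMIC,
    "scaling_type_w": TensorScalingType.DYNAMIC,
    "scaling_type_dL_dY": TensorScalingType.DYNAMIC,
}
assert TestFloat8Linear.is_valid_combination(params)

# ...while mixing a DELAYED scaling type into a DYNAMIC linear_type is pruned:
params["scaling_type_w"] = TensorScalingType.DELAYED
assert not TestFloat8Linear.is_valid_combination(params)
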
+     @filtered_parametrize(
+         [
+             ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+             ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+             ("emulate", [True, False] if is_H100 else [True]),
+             ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+             ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+             (
+                 "scaling_type_dL_dY",
+                 [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+             ),
+         ],
+         filter_func=is_valid_combination,
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_linear_nobias(
@@ -252,28 +302,6 @@ def test_linear_nobias(
        scaling_type_w: TensorScalingType,
        scaling_type_dL_dY: TensorScalingType,
    ):
-         if not emulate:
-             if not torch.cuda.is_available():
-                 warnings.warn("CUDA not available")
-                 pytest.skip()
-             elif torch.cuda.get_device_capability() < (9, 0):
-                 warnings.warn(
-                     f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                 )
-                 pytest.skip()
-         if linear_type is LinearType.DYNAMIC:
-             # Only test one combination of scaling types, as they are a no-op
-             # for Float8DynamicLinear. It would be cleaner to split into two
-             # tests, but IMO not worth it since Float8DynamicLinear will be
-             # deleted soon
-             is_all_dynamic = (
-                 scaling_type_x is TensorScalingType.DYNAMIC
-                 and scaling_type_w is TensorScalingType.DYNAMIC
-                 and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-             )
-             if not is_all_dynamic:
-                 pytest.skip()
-
        x = torch.randn(*x_shape, device="cuda")
        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
        self._test_linear_impl(
@@ -286,20 +314,20 @@ def test_linear_nobias(
            scaling_type_dL_dY,
        )

-     @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-     @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-     @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-     @pytest.mark.parametrize(
-         "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-     )
-     @pytest.mark.parametrize(
-         "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-     )
-     @pytest.mark.parametrize(
-         "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-     )
-     @pytest.mark.parametrize(
-         "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+     @filtered_parametrize(
+         [
+             ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+             ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+             ("emulate", [True, False] if is_H100 else [True]),
+             ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+             ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+             (
+                 "scaling_type_dL_dY",
+                 [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+             ),
+             ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
+         ],
+         filter_func=is_valid_combination,
    )
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
    def test_linear_bias(
@@ -312,28 +340,6 @@ def test_linear_bias(
        emulate: bool,
        linear_dtype: torch.dtype,
    ):
-         if not emulate:
-             if not torch.cuda.is_available():
-                 warnings.warn("CUDA not available")
-                 pytest.skip()
-             elif torch.cuda.get_device_capability() < (9, 0):
-                 warnings.warn(
-                     f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                 )
-                 pytest.skip()
-         if linear_type is LinearType.DYNAMIC:
-             # Only test one combination of scaling types, as they are a no-op
-             # for Float8DynamicLinear. It would be cleaner to split into two
-             # tests, but IMO not worth it since Float8DynamicLinear will be
-             # deleted soon
-             is_all_dynamic = (
-                 scaling_type_x is TensorScalingType.DYNAMIC
-                 and scaling_type_w is TensorScalingType.DYNAMIC
-                 and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-             )
-             if not is_all_dynamic:
-                 pytest.skip()
-
        x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
        m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
        self._test_linear_impl(