This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit b1c0258

Add utility for filtering out skipped tests in large parametrization groups
ghstack-source-id: 275f276
Pull Request resolved: #303
1 parent: c57aa9e


test/test_base.py

Lines changed: 75 additions & 69 deletions
@@ -9,6 +9,7 @@
 import re
 import unittest
 import warnings
+from itertools import product
 
 import pytest
 
@@ -52,6 +53,37 @@
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
+def filtered_parametrize(param_list, filter_func=None):
+    """
+    A decorator that works like pytest.mark.parametrize but filters out
+    unwanted parameter combinations.
+
+    :param param_list: A list of tuples, each containing (arg_name, [arg_values])
+    :param filter_func: A function that takes a dictionary of parameter names and values,
+        and returns True for valid combinations, False otherwise
+    """
+
+    def decorator(func):
+        arg_names = [param[0] for param in param_list]
+        arg_values = [param[1] for param in param_list]
+
+        all_combinations = product(*arg_values)
+        if filter_func:
+            valid_combinations = [
+                combo
+                for combo in all_combinations
+                if filter_func(dict(zip(arg_names, combo)))
+            ]
+        else:
+            valid_combinations = list(all_combinations)
+
+        return pytest.mark.parametrize(
+            argnames=arg_names, argvalues=valid_combinations
+        )(func)
+
+    return decorator
+
+
 def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool:
     assert torch.all(a._scale == b._scale).item(), "scales are not identical"
     assert torch.all(a._data == b._data).item(), "data is not identical"
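
For orientation, here is a minimal usage sketch of the new decorator on its own (not part of the commit; the test name, parameter names, filter function, and import path below are hypothetical):

# Hypothetical usage of filtered_parametrize; assumes it is importable from
# test/test_base.py and that pytest is installed.
from test_base import filtered_parametrize  # assumed import path


def _no_large_batch_on_cpu(params):
    # Reject combinations that pair the CPU device with the largest batch size.
    return not (params["device"] == "cpu" and params["batch"] == 64)


@filtered_parametrize(
    [
        ("device", ["cpu", "cuda"]),
        ("batch", [1, 8, 64]),
    ],
    filter_func=_no_large_batch_on_cpu,
)
def test_example(device, batch):
    # Only 5 of the 6 combinations are collected; ("cpu", 64) is never
    # generated, so it does not show up as a skipped test.
    assert batch > 0

Because the filter runs while the parametrize mark is being built, rejected combinations never become test items at all.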
@@ -230,17 +262,35 @@ def _test_linear_impl(
         # verify initialization flags got updated
         assert m_fp8.is_amax_initialized, "Amax was not properly initialized"
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
+    @staticmethod
+    def is_valid_combination(params):
+        if not params["emulate"]:
+            if not torch.cuda.is_available():
+                return False
+            if torch.cuda.get_device_capability() < (9, 0):
+                return False
+
+        if params["linear_type"] == LinearType.DYNAMIC:
+            return all(
+                params[key] == TensorScalingType.DYNAMIC
+                for key in ["scaling_type_x", "scaling_type_w", "scaling_type_dL_dY"]
+            )
+
+        return True
+
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+        ],
+        filter_func=is_valid_combination,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_nobias(
@@ -252,28 +302,6 @@ def test_linear_nobias(
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        if linear_type is LinearType.DYNAMIC:
-            # Only test one combination of scaling types, as they are a no-op
-            # for Float8DynamicLinear. It would be cleaner to split into two
-            # tests, but IMO not worth it since Float8DynamicLinear will be
-            # deleted soon
-            is_all_dynamic = (
-                scaling_type_x is TensorScalingType.DYNAMIC
-                and scaling_type_w is TensorScalingType.DYNAMIC
-                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            )
-            if not is_all_dynamic:
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda")
         m_ref = nn.Linear(16, 32, bias=False, device="cuda")
         self._test_linear_impl(
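
The runtime-skip block removed just above is what is_valid_combination now expresses as a filter applied while the parametrization is built. As a rough standalone illustration (not part of the commit; plain strings stand in for LinearType and TensorScalingType, and the emulate/capability branch is omitted), this is how the dynamic-linear rule prunes the scaling-type grid:

# Standalone sketch of the pruning effect; "delayed"/"dynamic" strings are
# stand-ins for the real enums used in test_base.py.
from itertools import product

linear_types = ["delayed", "dynamic"]   # stands in for LinearType
scaling_types = ["delayed", "dynamic"]  # stands in for TensorScalingType


def is_valid(linear_type, sx, sw, sdy):
    # Mirrors the dynamic-linear rule of is_valid_combination: a dynamic
    # linear is only paired with all-dynamic scaling types.
    if linear_type == "dynamic":
        return sx == sw == sdy == "dynamic"
    return True


combos = list(product(linear_types, scaling_types, scaling_types, scaling_types))
kept = [c for c in combos if is_valid(*c)]
print(len(combos), len(kept))  # 16 total -> 9 kept (8 delayed + 1 all-dynamic)

The seven combinations that previously produced a runtime pytest.skip() are simply never generated, which is what the commit message means by filtering out skipped tests in large parametrization groups.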
@@ -286,20 +314,20 @@ def test_linear_nobias(
             scaling_type_dL_dY,
         )
 
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
+    @filtered_parametrize(
+        [
+            ("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)]),
+            ("linear_type", [LinearType.DELAYED, LinearType.DYNAMIC]),
+            ("emulate", [True, False] if is_H100 else [True]),
+            ("scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            ("scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]),
+            (
+                "scaling_type_dL_dY",
+                [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC],
+            ),
+            ("linear_dtype", [torch.float16, torch.bfloat16, torch.float32]),
+        ],
+        filter_func=is_valid_combination,
     )
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
     def test_linear_bias(
@@ -312,28 +340,6 @@ def test_linear_bias(
         emulate: bool,
         linear_dtype: torch.dtype,
     ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        if linear_type is LinearType.DYNAMIC:
-            # Only test one combination of scaling types, as they are a no-op
-            # for Float8DynamicLinear. It would be cleaner to split into two
-            # tests, but IMO not worth it since Float8DynamicLinear will be
-            # deleted soon
-            is_all_dynamic = (
-                scaling_type_x is TensorScalingType.DYNAMIC
-                and scaling_type_w is TensorScalingType.DYNAMIC
-                and scaling_type_dL_dY is TensorScalingType.DYNAMIC
-            )
-            if not is_all_dynamic:
-                pytest.skip()
-
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
         m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
         self._test_linear_impl(
