Skip to content

Commit 0bd591a

Browse files
aporialiaofacebook-github-bot
authored andcommitted
Add option to skip certain unit tests (#2878)
Summary: Pull Request resolved: #2878 Add option to skip certain tests - useful when a few tests are broken and we need to skip them to gain visibility in other unit tests. This usually happens on our CPU unit tests, so only modifying this script. To use, simply add the name of the tests to skip in this txt file. such as: ``` test_sharding_fused_ebc_as_top_level ``` You can also use the class name for the test: e.g. ``` ModelParallelSparseOnlyTestGloo ``` NOTE: If you want to exclude some test from running in OSS workflow (usually expecting failure due to environment discrepancy), you can add `_disabled_in_oss_compatibility` at the end of your unit test function name. Reviewed By: TroyGarden Differential Revision: D72815908 fbshipit-source-id: 05267e6ffd1d3e30c80c8f291c57bfa25b6d7223
1 parent 0a78345 commit 0bd591a

File tree

5 files changed

+25
-27
lines changed

5 files changed

+25
-27
lines changed

.github/scripts/tests_to_skip.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
_disabled_in_oss_compatibility

.github/workflows/unittest_ci.yml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ on:
77
push:
88
branches:
99
- nightly
10+
- main
1011
workflow_dispatch:
1112

1213
jobs:
@@ -100,6 +101,14 @@ jobs:
100101
python -c "import numpy"
101102
echo "numpy succeeded"
102103
conda install -n build_binary -y pytest
104+
# Read the list of tests to skip from a file, ignoring empty lines and comments
105+
skip_expression=$(awk '!/^($|#)/ {printf " and not %s", $0}' ./.github/scripts/tests_to_skip.txt)
106+
# Check if skip_expression is effectively empty
107+
if [ -z "$skip_expression" ]; then
108+
skip_expression=""
109+
else
110+
skip_expression=${skip_expression:5} # Remove the leading " and "
111+
fi
103112
conda run -n build_binary \
104113
python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors \
105114
--ignore=torchrec/distributed/tests/test_comm.py --ignore=torchrec/distributed/tests/test_infer_shardings.py \
@@ -110,4 +119,5 @@ jobs:
110119
--ignore-glob='torchrec/inference/inference_legacy/tests*' --ignore-glob='*test_model_parallel_nccl*' \
111120
--ignore=torchrec/distributed/tests/test_cache_prefetch.py --ignore=torchrec/distributed/tests/test_fp_embeddingbag_single_rank.py \
112121
--ignore=torchrec/distributed/tests/test_infer_utils.py --ignore=torchrec/distributed/tests/test_fx_jit.py --ignore-glob=**/test_utils/ \
113-
--ignore-glob='*test_train_pipeline*' --ignore=torchrec/distributed/tests/test_model_parallel_hierarchical.py
122+
--ignore-glob='*test_train_pipeline*' --ignore=torchrec/distributed/tests/test_model_parallel_hierarchical.py \
123+
-k "$skip_expression"

.github/workflows/unittest_ci_cpu.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,14 @@ jobs:
7272
python -c "import numpy"
7373
echo "numpy succeeded"
7474
conda install -n build_binary -y pytest
75+
# Read the list of tests to skip from a file, ignoring empty lines and comments
76+
skip_expression=$(awk '!/^($|#)/ {printf " and not %s", $0}' ./.github/scripts/tests_to_skip.txt)
77+
# Check if skip_expression is effectively empty
78+
if [ -z "$skip_expression" ]; then
79+
skip_expression=""
80+
else
81+
skip_expression=${skip_expression:5} # Remove the leading " and "
82+
fi
7583
conda run -n build_binary \
7684
python -m pytest torchrec -v -s -W ignore::pytest.PytestCollectionWarning --continue-on-collection-errors \
77-
--ignore-glob=**/test_utils/
85+
--ignore-glob=**/test_utils/ -k "$skip_expression"

torchrec/ir/tests/test_serializer.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,7 @@ def test_serialize_deserialize_ebc(self) -> None:
253253
self.assertEqual(deserialized.shape, orginal.shape)
254254
self.assertTrue(torch.allclose(deserialized, orginal))
255255

256-
# pyre-ignore[56]: Pyre was not able to infer the type of argument
257-
@unittest.skipIf(
258-
torch.cuda.device_count() == 0,
259-
"skip this test in OSS (no GPU available) because torch.export uses training ir in OSS",
260-
)
261-
def test_dynamic_shape_ebc(self) -> None:
262-
# TODO: https://fb.workplace.com/groups/1028545332188949/permalink/1138699244506890/
256+
def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
263257
model = self.generate_model()
264258
feature1 = KeyedJaggedTensor.from_offsets_sync(
265259
keys=["f1", "f2", "f3"],

torchrec/models/experimental/test_transformerdlrm.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,7 @@ def test_larger(self) -> None:
6161
concat_dense = inter_arch(dense_features, sparse_features)
6262
self.assertEqual(concat_dense.size(), (B, D * (F + 1)))
6363

64-
# pyre-ignore[56]: Pyre was not able to infer the type of argument
65-
@unittest.skipIf(
66-
torch.cuda.device_count() == 0,
67-
"skip this test in OSS (no GPU available) because seed might be different in OSS",
68-
)
69-
def test_correctness(self) -> None:
64+
def test_correctness_disabled_in_oss_compatibility(self) -> None:
7065
D = 4
7166
B = 3
7267
# multi-head attentions
@@ -170,12 +165,7 @@ def test_correctness(self) -> None:
170165
)
171166
)
172167

173-
# pyre-ignore[56]: Pyre was not able to infer the type of argument
174-
@unittest.skipIf(
175-
torch.cuda.device_count() == 0,
176-
"skip this test in OSS (no GPU available) because seed might be different in OSS",
177-
)
178-
def test_numerical_stability(self) -> None:
168+
def test_numerical_stability_disabled_in_oss_compatibility(self) -> None:
179169
D = 4
180170
B = 3
181171
# multi-head attentions
@@ -204,12 +194,7 @@ def test_numerical_stability(self) -> None:
204194

205195

206196
class DLRMTransformerTest(unittest.TestCase):
207-
# pyre-ignore[56]: Pyre was not able to infer the type of argument
208-
@unittest.skipIf(
209-
torch.cuda.device_count() == 0,
210-
"skip this test in OSS (no GPU available) because seed might be different in OSS",
211-
)
212-
def test_basic(self) -> None:
197+
def test_basic_disabled_in_oss_compatibility(self) -> None:
213198
torch.manual_seed(0)
214199
B = 2
215200
D = 8

0 commit comments

Comments
 (0)