We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7568e05 commit 69acf48Copy full SHA for 69acf48
torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
@@ -13,8 +13,9 @@
13
from typing import List
14
from unittest.mock import MagicMock
15
16
+import parameterized
17
+
18
import torch
-from parameterized import parameterized
19
20
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
21
from torchrec.distributed.test_utils.test_model import ModelInput, TestNegSamplingModule
@@ -260,7 +261,7 @@ def test_restore_from_snapshot(self) -> None:
260
261
for source_model_type, recipient_model_type in variants:
262
self._test_restore_from_snapshot(source_model_type, recipient_model_type)
263
- @parameterized.expand(
264
+ @parameterized.parameterized.expand(
265
[
266
(
267
CallArgs(
0 commit comments