Skip to content

Commit 8f55f3a

Browse files
aporialiao authored and facebook-github-bot committed
Fix Pyre (pytorch#2960)
Summary: Pull Request resolved: pytorch#2960 Reviewed By: aliafzal Differential Revision: D74418088 fbshipit-source-id: 6ce96e7833d162f65598354edcab5c71c1f1edca
1 parent 949278c commit 8f55f3a

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

torchrec/distributed/tests/test_pt2_multiprocess.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,6 @@ def _test_compile_rank_fn(
251251
num_float_features: int = 8
252252
num_weighted_features: int = 1
253253

254-
# pyre-ignore
255254
device: torch.Device = torch.device("cuda")
256255
pg: Optional[dist.ProcessGroup] = ctx.pg
257256
assert pg is not None
@@ -336,7 +335,7 @@ def _dmp(m: torch.nn.Module) -> DistributedModelParallel:
336335
env=ShardingEnv.from_process_group(pg),
337336
plan=plan,
338337
sharders=sharders,
339-
device=device,
338+
device=device, # pyre-ignore
340339
init_data_parallel=False,
341340
)
342341

@@ -601,7 +600,7 @@ def _dmp(m: torch.nn.Module) -> DistributedModelParallel: # pyre-ignore
601600
env=ShardingEnv(world_size, rank, pg),
602601
plan=plan,
603602
sharders=sharders,
604-
device=device,
603+
device=device, # pyre-ignore
605604
init_data_parallel=False,
606605
)
607606

0 commit comments

Comments (0)