Commit 26e0732

zhaojuanmao authored and facebook-github-bot committed
fix attach (#2726)
Summary:
Pull Request resolved: #2726

Allow "attach" to skip sparse data distribution. In some cases "attach" is called outside the training loop; skipping sparse data distribution there avoids interfering with the sparse data distribution that runs inside the training loop.

Reviewed By: hlin09, ge0405

Differential Revision: D68908008

fbshipit-source-id: bb1b7c362ef0a6de78e3e0bd771d8d22e3e4e985
1 parent c4c9332 commit 26e0732

1 file changed

+4 -2 lines changed

torchrec/distributed/train_pipeline/train_pipelines.py

@@ -418,14 +418,16 @@ def detach(self) -> torch.nn.Module:
         self._model_attached = False
         return self._model
 
-    def attach(self, model: Optional[torch.nn.Module] = None) -> None:
+    def attach(
+        self, model: Optional[torch.nn.Module] = None, sparse_dist: bool = True
+    ) -> None:
         if model:
             self._model = model
 
         self._model_attached = True
         if self.contexts:
             self._pipeline_model(
-                batch=self.batches[0],
+                batch=self.batches[0] if sparse_dist else None,
                 context=self.contexts[0],
                 pipelined_forward=PipelinedForward,
             )
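
To illustrate what the new flag changes, here is a minimal sketch that uses a toy stand-in rather than the real TorchRec pipeline. `ToyPipeline` and `sparse_dist_calls` are hypothetical names invented for this example, and the behavior of `_pipeline_model` when `batch` is `None` (rewrite the model, start no distribution) is an assumption based on the commit summary, not a copy of the library code.

```python
from typing import Any, List, Optional


class ToyPipeline:
    """Hypothetical stand-in for the pipeline class; not TorchRec code."""

    def __init__(self, model: Any) -> None:
        self._model = model
        self._model_attached = True
        self.batches: List[Any] = []
        self.contexts: List[Any] = []
        # Records every batch for which sparse data distribution was started.
        self.sparse_dist_calls: List[Any] = []

    def _pipeline_model(self, batch: Optional[Any], context: Any) -> None:
        # Assumption: a None batch means "pipeline the model only".
        if batch is not None:
            self.sparse_dist_calls.append(batch)

    def detach(self) -> Any:
        self._model_attached = False
        return self._model

    def attach(self, model: Optional[Any] = None, sparse_dist: bool = True) -> None:
        if model:
            self._model = model
        self._model_attached = True
        if self.contexts:
            self._pipeline_model(
                batch=self.batches[0] if sparse_dist else None,
                context=self.contexts[0],
            )


pipeline = ToyPipeline(model="model")
pipeline.batches.append("batch0")
pipeline.contexts.append("ctx0")

# Outside the training loop: re-attach without starting a distribution.
pipeline.detach()
pipeline.attach(sparse_dist=False)
calls_after_quiet_attach = list(pipeline.sparse_dist_calls)

# Default behavior is unchanged: attach starts a distribution for batches[0].
pipeline.detach()
pipeline.attach()
calls_after_default_attach = list(pipeline.sparse_dist_calls)
```

Because `sparse_dist` defaults to `True`, existing callers keep the previous behavior; only callers that attach outside the training loop need to pass `sparse_dist=False`.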
