
Commit 673bc73

zhaojuanmao authored and facebook-github-bot committed
fix attach
Summary: Allow "attach" to skip sparse data distribution. In some cases "attach" is called outside the training loop; skipping sparse data distribution there avoids interfering with the sparse data distribution that runs inside the training loop.
Differential Revision: D68908008
1 parent 9269e73 commit 673bc73

1 file changed: +4 -2 lines changed

Diff for: torchrec/distributed/train_pipeline/train_pipelines.py

@@ -418,14 +418,16 @@ def detach(self) -> torch.nn.Module:
         self._model_attached = False
         return self._model
 
-    def attach(self, model: Optional[torch.nn.Module] = None) -> None:
+    def attach(
+        self, model: Optional[torch.nn.Module] = None, sparse_dist: bool = True
+    ) -> None:
         if model:
             self._model = model
 
         self._model_attached = True
         if self.contexts:
             self._pipeline_model(
-                batch=self.batches[0],
+                batch=self.batches[0] if sparse_dist else None,
                 context=self.contexts[0],
                 pipelined_forward=PipelinedForward,
             )
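
A minimal, self-contained sketch of how the new `sparse_dist` flag changes what `attach` passes along. `ToyPipeline`, `pipeline_calls`, and `demo` are hypothetical stand-ins written for illustration; they are not the TorchRec classes, and the real `_pipeline_model` does far more than record its argument. The sketch also assumes that a `None` batch is how the callee learns to skip sparse data distribution, which is what the summary suggests but the diff itself does not show.

```python
from typing import Any, List, Optional


class ToyPipeline:
    """Hypothetical stand-in for the pipeline class; not the TorchRec implementation."""

    def __init__(self, model: Any) -> None:
        self._model = model
        self._model_attached = True
        self.batches: List[Any] = []
        self.contexts: List[Any] = []
        # Records the `batch` argument of every _pipeline_model call.
        self.pipeline_calls: List[Optional[Any]] = []

    def _pipeline_model(self, batch: Optional[Any], context: Any) -> None:
        # Assumption for this sketch: batch=None means "no sparse data distribution".
        self.pipeline_calls.append(batch)

    def detach(self) -> Any:
        self._model_attached = False
        return self._model

    def attach(self, model: Optional[Any] = None, sparse_dist: bool = True) -> None:
        if model:
            self._model = model
        self._model_attached = True
        if self.contexts:
            self._pipeline_model(
                # Mirrors the changed line in the diff.
                batch=self.batches[0] if sparse_dist else None,
                context=self.contexts[0],
            )


def demo() -> List[Optional[Any]]:
    pipeline = ToyPipeline(model="model-v1")
    pipeline.batches.append("batch-0")
    pipeline.contexts.append("context-0")

    # Outside the training loop: re-attach without touching the queued batch.
    pipeline.detach()
    pipeline.attach(sparse_dist=False)

    # Default behaviour is unchanged: the first queued batch is passed through.
    pipeline.detach()
    pipeline.attach()
    return pipeline.pipeline_calls


if __name__ == "__main__":
    print(demo())
```

Because `sparse_dist` defaults to `True`, existing callers of `attach` keep the previous behaviour; only callers that explicitly pass `sparse_dist=False` opt out.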
