
Commit c4c9332

sarckk authored and facebook-github-bot committed
Remove embedding streams from semi-sync (#2731)
Summary:
Pull Request resolved: #2731

Do embedding lookup on the default stream instead of an extra stream, as using extra streams runs into subtle data races.

Reviewed By: dstaay-fb

Differential Revision: D69270806

fbshipit-source-id: c26ddb1886f3b7151d5048bcdc47180e5ee9f67b
1 parent 9269e73 commit c4c9332
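
Note: the sketch below is not TorchRec code; `embedding_lookup` and `indices` are placeholder names. It illustrates the two scheduling patterns the summary contrasts: a side-stream lookup, which needs explicit cross-stream synchronization (a missed wait is the kind of subtle data race cited), and a default-stream lookup, which relies on ordinary single-stream ordering.

import torch

def lookup_on_side_stream(embedding_lookup, indices, device):
    # Pre-commit pattern: run the lookup on a dedicated stream. Every producer
    # and consumer must then synchronize with that stream explicitly.
    side_stream = torch.cuda.Stream(device=device)
    default_stream = torch.cuda.current_stream(device)
    side_stream.wait_stream(default_stream)   # lookup must see `indices` fully written
    with torch.cuda.stream(side_stream):
        out = embedding_lookup(indices)
    default_stream.wait_stream(side_stream)   # consumers must see `out` fully written
    indices.record_stream(side_stream)        # keep the allocator from reusing `indices` early
    return out

def lookup_on_default_stream(embedding_lookup, indices):
    # Post-commit pattern: ordinary stream ordering, no extra bookkeeping.
    return embedding_lookup(indices)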

File tree

1 file changed: +37 −57 lines changed


torchrec/distributed/train_pipeline/train_pipelines.py

+37 −57
@@ -768,8 +768,6 @@ def __init__(
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
         logger.debug(f"Starting semi-sync run at batch: {self._start_batch}")
-
-        self._embedding_streams: List[Optional[torch.Stream]] = []
         self._gradients: Dict[str, torch.Tensor] = {}
 
     def _grad_swap(self) -> None:
@@ -779,14 +777,6 @@ def _grad_swap(self) -> None:
                 self._gradients[name] = param.grad.clone()
             param.grad = grad
 
-    def _init_embedding_streams(self) -> None:
-        for _ in self._pipelined_modules:
-            self._embedding_streams.append(
-                (torch.get_device_module(self._device).Stream(priority=0))
-                if self._device.type in ["cuda", "mtia"]
-                else None
-            )
-
     def _validate_optimizer(self) -> None:
         for pipelined_module in self._pipelined_modules:
             pipelined_params = set(pipelined_module.parameters())
@@ -815,7 +805,6 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
             # pyre-ignore [6]
             EmbeddingPipelinedForward,
         )
-        self._init_embedding_streams()
         self.wait_sparse_data_dist(self.contexts[0])
         self._validate_optimizer()
         # pyre-ignore [6]
@@ -916,43 +905,36 @@ def _mlp_forward(
             return self._model_fwd(batch)
 
     def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
-        default_stream = torch.get_device_module(self._device).current_stream()
         assert len(context.embedding_features) == len(context.embedding_tensors)
-        for stream, emb_tensors, embedding_features, detached_emb_tensors in zip(
-            self._embedding_streams,
+        for emb_tensors, embedding_features, detached_emb_tensors in zip(
             context.embedding_tensors,
             context.embedding_features,
             context.detached_embedding_tensors,
         ):
-            with self._stream_context(stream):
-                grads = [tensor.grad for tensor in detached_emb_tensors]
-                if stream:
-                    stream.wait_stream(default_stream)
-                # Some embeddings may never get used in the final loss computation,
-                # so the grads will be `None`. If we don't exclude these, it will fail
-                # with error: "grad can be implicitly created only for scalar outputs"
-                # Alternatively, if the tensor has only 1 element, pytorch can still
-                # figure out how to do autograd
-                embs_to_backprop, grads_to_use, invalid_features = [], [], []
-                assert len(embedding_features) == len(emb_tensors)
-                for features, tensor, grad in zip(
-                    embedding_features, emb_tensors, grads
-                ):
-                    if tensor.numel() == 1 or grad is not None:
-                        embs_to_backprop.append(tensor)
-                        grads_to_use.append(grad)
+            grads = [tensor.grad for tensor in detached_emb_tensors]
+            # Some embeddings may never get used in the final loss computation,
+            # so the grads will be `None`. If we don't exclude these, it will fail
+            # with error: "grad can be implicitly created only for scalar outputs"
+            # Alternatively, if the tensor has only 1 element, pytorch can still
+            # figure out how to do autograd
+            embs_to_backprop, grads_to_use, invalid_features = [], [], []
+            assert len(embedding_features) == len(emb_tensors)
+            for features, tensor, grad in zip(embedding_features, emb_tensors, grads):
+                if tensor.numel() == 1 or grad is not None:
+                    embs_to_backprop.append(tensor)
+                    grads_to_use.append(grad)
+                else:
+                    if isinstance(features, str):
+                        invalid_features.append(features)
+                    elif isinstance(features, Iterable):
+                        invalid_features.extend(features)
                     else:
-                        if isinstance(features, str):
-                            invalid_features.append(features)
-                        elif isinstance(features, Iterable):
-                            invalid_features.extend(features)
-                        else:
-                            invalid_features.append(features)
-                if invalid_features and context.index == 0:
-                    logger.warning(
-                        f"SemiSync, the following features have no gradients: {invalid_features}"
-                    )
-                torch.autograd.backward(embs_to_backprop, grads_to_use)
+                        invalid_features.append(features)
+            if invalid_features and context.index == 0:
+                logger.warning(
+                    f"SemiSync, the following features have no gradients: {invalid_features}"
+                )
+            torch.autograd.backward(embs_to_backprop, grads_to_use)
 
     def copy_batch_to_gpu(
         self,
@@ -1012,23 +994,21 @@ def start_embedding_lookup(
         """
         if batch is None:
             return
+
         with record_function(f"## start_embedding_lookup {context.index} ##"):
-            _wait_for_events(
-                batch, context, torch.get_device_module(self._device).current_stream()
-            )
+            current_stream = torch.get_device_module(self._device).current_stream()
+            _wait_for_events(batch, context, current_stream)
             for i, module in enumerate(self._pipelined_modules):
-                stream = self._embedding_streams[i]
-                with self._stream_context(stream):
-                    _start_embedding_lookup(
-                        module,
-                        context,
-                        source_stream=self._data_dist_stream,
-                        target_stream=stream,
-                        stream_context=self._stream_context,
-                    )
-                    event = torch.get_device_module(self._device).Event()
-                    event.record()
-                    context.events.append(event)
+                _start_embedding_lookup(
+                    module,
+                    context,
+                    source_stream=self._data_dist_stream,
+                    target_stream=current_stream,
+                    stream_context=self._stream_context,
+                )
+                event = torch.get_device_module(self._device).Event()
+                event.record()
+                context.events.append(event)
 
 
 class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
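
Note: independent of the stream change, embedding_backward keeps the pattern of re-applying gradients from the detached embedding tensors via torch.autograd.backward with explicit grads, skipping tensors whose grad is `None` (autograd can only create an implicit grad for scalar outputs). The following is a standalone sketch of that pattern, not the pipeline code; plain tensors stand in for the pipeline's embedding tensors.

import torch

def backward_with_explicit_grads(tensors, detached_tensors):
    # Gradients accumulated on the detached copies are pushed back through the
    # original graph tensors. Entries whose grad is None are skipped unless the
    # tensor is a scalar, for which autograd can create the grad implicitly.
    grads = [t.grad for t in detached_tensors]
    outputs, grad_tensors = [], []
    for tensor, grad in zip(tensors, grads):
        if tensor.numel() == 1 or grad is not None:
            outputs.append(tensor)
            grad_tensors.append(grad)
    if outputs:
        torch.autograd.backward(outputs, grad_tensors)

# Usage: gradients flow from the detached copy back into the embedding weights.
weight = torch.randn(10, 8, requires_grad=True)
emb = weight[torch.tensor([0, 2, 5])]          # stand-in for an embedding lookup
detached = emb.detach().requires_grad_()
(detached * 3.0).sum().backward()              # populates detached.grad only
backward_with_explicit_grads([emb], [detached])
assert weight.grad is not None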
