
Commit 92b903f

sarckk authored and facebook-github-bot committed
Make streams device-agnostic (#2644)
Summary:
Pull Request resolved: #2644

#2598 (D64220706) causes failures on accelerators that do not support CUDA. Make the stream contexts hardware-agnostic.

Reviewed By: hpnhxxwn, iamzainhuda

Differential Revision: D67363141

fbshipit-source-id: fc2c6fec1dcbbe15f0385e299b666207d2d9a8f5
1 parent a43cef8 commit 92b903f
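
In short: instead of always entering torch.cuda.stream(...), the code now picks the stream context manager that matches the stream's device, and falls back to a no-op context when there is no stream at all. A minimal sketch of the idea, assuming recent PyTorch (get_stream_context is a hypothetical free-function version of the selection the diff adds to PipelinedPreproc.__init__; the diff's own no-stream fallback is a NoOpStream class, approximated here with contextlib.nullcontext):

import torch
from contextlib import AbstractContextManager, nullcontext
from typing import Callable, Optional


def get_stream_context(
    stream: Optional[torch.Stream],
) -> Callable[..., AbstractContextManager]:
    if stream is None:
        # No side stream: do nothing on enter/exit.
        return lambda _stream: nullcontext()
    if stream.device.type in ["cuda", "mtia"]:
        # Resolve the backend module for this device (torch.cuda, torch.mtia, ...)
        # and use that backend's stream context manager.
        return torch.get_device_module(stream.device).stream
    return torch.cuda.stream


# Call sites keep one spelling on every backend:
#     with get_stream_context(s)(s):
#         ...  # work issued on stream s, when the backend has streams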

File tree

2 files changed (+33, -6 lines)

torchrec/distributed/train_pipeline/train_pipelines.py (+1)

@@ -1002,6 +1002,7 @@ def start_embedding_lookup(
                 context,
                 source_stream=self._data_dist_stream,
                 target_stream=stream,
+                stream_context=self._stream_context,
             )
             event = torch.get_device_module(self._device).Event()
             event.record()
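
The pipeline-side change is just this one argument: start_embedding_lookup now hands its existing self._stream_context down to the helper in utils.py, rather than the helper hard-coding torch.cuda.stream (see the utils.py diff below).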

torchrec/distributed/train_pipeline/utils.py (+32, -6)
@@ -11,6 +11,7 @@
 import itertools
 import logging
 from collections import defaultdict, OrderedDict
+from contextlib import AbstractContextManager
 from dataclasses import dataclass, field
 
 from itertools import chain
@@ -248,6 +249,21 @@ def recursive_record_stream(
         recursive_record_stream(v, stream)
 
 
+class NoOpStream:
+    """No-Op Context manager that takes in a stream"""
+
+    def __init__(self, stream: Optional[torch.Stream]) -> None:
+        self._stream = stream
+
+    def __enter__(self) -> "NoOpStream":
+        """Return `self` upon entering the runtime context."""
+        return self
+
+    # pyre-ignore
+    def __exit__(self, exc_type, exc_value, traceback) -> None:
+        return None
+
+
 class PipelinedPreproc(torch.nn.Module):
     """
     Wrapper around preproc module found during model graph traversal for sparse data dist
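
NoOpStream deliberately accepts (and ignores) a stream argument so that call sites stay uniform: the same with-statement works whether or not a dist stream exists. A small sketch (dist_stream, module, and batch are hypothetical names; the real selection logic is in the next hunk):

# Pick a context: a real stream context when a dist stream exists, else the no-op.
ctx = torch.cuda.stream if dist_stream is not None else NoOpStream
with ctx(dist_stream):
    out = module(batch)  # on dist_stream under CUDA; runs inline under NoOpStream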
@@ -297,6 +313,17 @@ def __init__(
                 f"Preproc module {fqn} has no dist stream. This may cause race conditions and NaNs during training!"
             )
 
+        if self._dist_stream:
+            device: torch.device = self._dist_stream.device
+            # pyre-ignore
+            self._stream_context = (
+                torch.get_device_module(device).stream
+                if device.type in ["cuda", "mtia"]
+                else torch.cuda.stream
+            )
+        else:
+            self._stream_context = NoOpStream
+
     @property
     def preproc_module(self) -> torch.nn.Module:
         return self._preproc_module
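
torch.get_device_module(device) returns the backend module matching device.type (torch.cuda for "cuda", torch.mtia for "mtia"), so .stream above resolves to that backend's stream context manager. A quick illustration, assuming a CUDA build (tensor values are arbitrary):

dev = torch.device("cuda", 0)
backend = torch.get_device_module(dev)  # -> the torch.cuda module
s = backend.Stream()                    # a stream on that backend
with backend.stream(s):                 # equivalent to torch.cuda.stream(s) here
    y = torch.ones(4, device=dev) * 2   # kernels issued on stream s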
@@ -341,8 +368,7 @@ def forward(self, *input, **kwargs) -> Any:
 
         with record_function(f"## sdd_input_preproc {self._context.index} ##"):
             # should be no-op as we call this in dist stream
-            # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-            with torch.cuda.stream(self._dist_stream):
+            with self._stream_context(self._dist_stream):
                 res = self._preproc_module(*args, **kwargs)
 
             # Ensure preproc modules output is safe to use from default stream later
@@ -364,8 +390,7 @@ def forward(self, *input, **kwargs) -> Any:
                     f"Result of preproc module {self._fqn} is of type {type(res)}. We currently expect it to be a Tensor, Pipelineable, Iterable, or Dict to handle memory safety. If your output is not of this type, please add support for it above. Otherwise you might run into NaNs or CUDA Illegal Memory issues during training!"
                 )
 
-            # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-            with torch.cuda.stream(self._default_stream):
+            with self._stream_context(self._default_stream):
                 # Cache results, only during _start_data_dist
                 self._context.preproc_fwd_results[self._fqn] = res
 
@@ -760,10 +785,11 @@ def _start_embedding_lookup(
     context: EmbeddingTrainPipelineContext,
     source_stream: Optional[torch.Stream],
     target_stream: Optional[torch.Stream],
+    # pyre-ignore[2]
+    stream_context: Callable[..., AbstractContextManager[Any, Any]],
 ) -> None:
     module_context = context.module_contexts[module.forward.name]
-    # pyre-ignore[6]: torch.cuda.Stream is a wrapper around torch.Stream
-    with torch.cuda.stream(source_stream):
+    with stream_context(source_stream):
         kjt = context.input_dist_tensors_requests[module.forward.name].wait()
 
     if target_stream is not None:
