From 64c4ca7e4af0ad4c8e79e14cad21df16d0d5e635 Mon Sep 17 00:00:00 2001
From: Kevin Newton
Date: Fri, 16 May 2025 09:08:55 -0700
Subject: [PATCH] Allow a context manager to be called around apply_jit (#2927)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2927

When running torch.jit.script on the various forward functions, you can
run into issues if there are any other utilities interacting with the
function definitions. As an example, if you have another JIT running,
you need to disable it throughout this process.

This commit adds the ability to pass an apply_jit_context context
manager wherever apply_jit is currently passed; the context manager is
entered around the application of the torch JIT.

Reviewed By: SonicField

Differential Revision: D73781040
---
 .../tests/test_train_pipelines_utils.py     | 26 +++++++++++++++-
 .../train_pipeline/train_pipelines.py       | 31 ++++++++++++++++++-
 torchrec/distributed/train_pipeline/utils.py | 16 +++++++---
 3 files changed, 66 insertions(+), 7 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
index 358da4d33..08b66d3f5 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_utils.py
@@ -10,7 +10,8 @@
 import copy
 import enum
 import unittest
-from typing import List
+from contextlib import contextmanager
+from typing import Generator, List
 from unittest.mock import MagicMock
 
 import torch
@@ -43,6 +44,29 @@ class ModelType(enum.Enum):
 
 
 class TrainPipelineUtilsTest(TrainPipelineSparseDistTestBase):
+    # pyre-fixme[56]: Pyre was not able to infer the type of argument
+    @unittest.skipIf(
+        not torch.cuda.is_available(),
+        "Not enough GPUs, this test requires at least one GPU",
+    )
+    def test_rewrite_model_apply_jit(self) -> None:
+        @contextmanager
+        def apply_jit_context(events: list[str]) -> Generator[None, None, None]:
+            events.append("__enter__")
+            yield
+            events.append("__exit__")
+
+        events = []
+        _rewrite_model(
+            model=self._setup_model(),
+            context=TrainPipelineContext(),
+            dist_stream=None,
+            apply_jit=True,
+            apply_jit_context=apply_jit_context(events),
+        )
+
+        self.assertEqual(events, ["__enter__", "__exit__"])
+
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     @unittest.skipIf(
         not torch.cuda.is_available(),
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index b8be13994..5026f9008 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -393,6 +393,8 @@ class TrainPipelineSparseDist(TrainPipeline[In, Out]):
             (applicable to 2D sharding only) if set and DMP collection is enabled
             for 2D sharding, sync DMPs every N batches (default to 1, i.e.
             every batch, None to disable)
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -413,6 +415,7 @@ def __init__(
         ] = None,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
         enqueue_batch_after_forward: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         self._model = model
         self._optimizer = optimizer
@@ -420,6 +423,7 @@ def __init__(
         self._execute_all_batches = execute_all_batches
         self._apply_jit = apply_jit
         self._enqueue_batch_after_forward = enqueue_batch_after_forward
+        self._apply_jit_context = apply_jit_context
 
         if device.type == "cuda":
             # use two data streams to support two concurrent batches
@@ -716,6 +720,7 @@ def _pipeline_model(
             apply_jit=self._apply_jit,
             pipelined_forward=pipelined_forward,
             pipeline_postproc=self._pipeline_postproc,
+            apply_jit_context=self._apply_jit_context,
         )
         # initializes input dist, so we can override input dist forwards
         self.start_sparse_data_dist(batch, context)
@@ -904,6 +909,8 @@ class TrainPipelineFusedSparseDist(TrainPipelineSparseDist[In, Out]):
         TODO: pipeline_postproc, custom_model_fwd, strict
         use_emb_lookuo_stream (bool): if true invoke the compute_and_output_dist
             (for batch i+1) using a new stream, else re-using the data_dist stream
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -922,6 +929,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         emb_lookup_stream: str = "data_dist",  # new, current, data_dist (default)
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -932,6 +940,7 @@ def __init__(
             context_type=EmbeddingTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         if emb_lookup_stream == "new":
             self._emb_lookup_stream: Optional[torch.Stream] = (
@@ -1066,6 +1075,8 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
             (applicable to 2D sharding only) if set and DMP collection is enabled
             for 2D sharding, sync DMPs every N batches (default to 1, i.e.
             every batch, None to disable)
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1086,6 +1097,7 @@ def __init__(
         ] = None,
         strict: bool = False,
         dmp_collection_sync_interval_batches: Optional[int] = 1,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1097,6 +1109,7 @@ def __init__(
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
             dmp_collection_sync_interval_batches=dmp_collection_sync_interval_batches,
+            apply_jit_context=apply_jit_context,
         )
         self._start_batch = start_batch
         self._stash_gradients = stash_gradients
@@ -1378,6 +1391,8 @@ class PrefetchTrainPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (ContextManager): a context manager that will surround the
+            application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1394,6 +1409,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model=model,
@@ -1404,6 +1420,7 @@ def __init__(
             context_type=PrefetchTrainPipelineContext,
             pipeline_postproc=pipeline_postproc,
             custom_model_fwd=custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
         self._context = PrefetchTrainPipelineContext(version=0)
         self._prefetch_stream: Optional[torch.Stream] = (
@@ -1535,6 +1552,8 @@ class EvalPipelineSparseDist(TrainPipelineSparseDist[In, Out]):
         device (torch.device): device where device transfer, sparse data dist, and
             forward/backward pass will happen.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
+        apply_jit_context (Optional[ContextManager]): a context manager that
+            will surround the application of the JIT
     """
 
     # The PipelinedForward class that is used in _rewrite_model
@@ -1546,8 +1565,16 @@ def __init__(
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         apply_jit: bool = False,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
-        super().__init__(model, optimizer, device, True, apply_jit)
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            True,
+            apply_jit,
+            apply_jit_context=apply_jit_context,
+        )
         self._batch_loader: Optional[DataLoadingThread[In]] = None
 
     def __del__(self) -> None:
@@ -1909,6 +1936,7 @@ def __init__(
         custom_model_fwd: Optional[
             Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
         ] = None,
+        apply_jit_context: Optional[ContextManager[None]] = None,
     ) -> None:
         super().__init__(
             model,
@@ -1919,6 +1947,7 @@ def __init__(
             context_type,
             pipeline_postproc,
             custom_model_fwd,
+            apply_jit_context=apply_jit_context,
         )
 
         torch._logging.set_logs(compiled_autograd_verbose=True)
diff --git a/torchrec/distributed/train_pipeline/utils.py b/torchrec/distributed/train_pipeline/utils.py
index 2a561f80a..953b8832e 100644
--- a/torchrec/distributed/train_pipeline/utils.py
+++ b/torchrec/distributed/train_pipeline/utils.py
@@ -13,7 +13,7 @@
 import itertools
 import logging
 from collections import defaultdict, OrderedDict
-from contextlib import AbstractContextManager
+from contextlib import AbstractContextManager, nullcontext
 from dataclasses import dataclass, field
 
 from itertools import chain
@@ -22,6 +22,7 @@
     Any,
     Callable,
     cast,
+    ContextManager,
     Dict,
     Generator,
     Generic,
@@ -1540,6 +1541,7 @@ def _rewrite_model(  # noqa C901
     pipelined_forward: Type[BaseForward[TrainPipelineContext]] = PipelinedForward,
     pipeline_postproc: bool = False,
     default_stream: Optional[torch.Stream] = None,
+    apply_jit_context: Optional[ContextManager[None]] = None,
 ) -> Tuple[
     List[ShardedModule],
     torch.nn.Module,
@@ -1643,10 +1645,14 @@ def _rewrite_model(  # noqa C901
 
     # JIT script unsharded modules if applicable.
     if apply_jit:
-        graph_model = torch.fx.GraphModule(model, graph)
-        _jit_modules(graph_model, "")
-        if isinstance(input_model, DistributedModelParallel):
-            input_model.module = graph_model
+        if apply_jit_context is None:
+            apply_jit_context = nullcontext()
+
+        with apply_jit_context:
+            graph_model = torch.fx.GraphModule(model, graph)
+            _jit_modules(graph_model, "")
+            if isinstance(input_model, DistributedModelParallel):
+                input_model.module = graph_model
 
     if non_pipelined_sharded_modules:
         logger.warn(
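
Note (usage sketch, not part of the diff): the snippet below illustrates how the new
apply_jit_context argument might be supplied when constructing a pipeline. The
disable_other_jit() helper and the build_pipeline() wrapper are illustrative
assumptions, as is the already-sharded model/optimizer/device setup; only the
apply_jit_context keyword itself comes from this patch.

    from contextlib import contextmanager
    from typing import Generator

    import torch

    from torchrec.distributed.train_pipeline.train_pipelines import (
        TrainPipelineSparseDist,
    )


    @contextmanager
    def disable_other_jit() -> Generator[None, None, None]:
        # Hypothetical helper: pause any other JIT utility here before
        # torch.jit.script runs over the unsharded forwards, and restore
        # it after the yield.
        yield


    def build_pipeline(
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        device: torch.device,
    ) -> TrainPipelineSparseDist:
        # The context manager instance is entered once inside _rewrite_model,
        # around torch.jit.script of the non-pipelined (unsharded) modules.
        return TrainPipelineSparseDist(
            model=model,
            optimizer=optimizer,
            device=device,
            apply_jit=True,
            apply_jit_context=disable_other_jit(),
        )

Passing an instantiated context manager (rather than a callable) matches the
Optional[ContextManager[None]] signature introduced above: it is entered exactly
once, when _rewrite_model applies the JIT.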