From 37707b4467e0350b48161cdae50915ba02ecb9a1 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Tue, 21 Jan 2025 11:16:53 -0800
Subject: [PATCH 1/2] add NJT/TD support for EBC and pipeline benchmark (#2581)

Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)

{F1949248817}

# Context
* As depicted above, we are extending the TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict).
* TensorDict is supported in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EBC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert the TD to a KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingBagCollection`, specifically in `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT is permuted (and partially duplicated in some cases) before the `KJTAllToAll` communication, while in the TD scenario the input TD is converted directly into the permuted KJT that feeds the subsequent `KJTAllToAll` communication.
* ref: D63436011

# Details
* `td_to_kjt` is implemented in Python, which introduces a CPU perf regression. However, it is not on the training critical path, so it has minimal impact on the overall training QPS (see the benchmark results in the test plan). An illustrative sketch of the conversion is given at the end of this section.
* Currently only the EBC use case is supported.

WARNING: `TensorDict` supports **neither** weighted jagged tensors **nor** variable batch sizes.

NOTE: All the following comparisons are between the **`KJT.permute`** in the KJT input scenario and the **`TD-KJT conversion`** in the TD input scenario.

* Both `KJT.permute` and `TD-KJT conversion` are correctly marked in the `TrainPipelineBase` traces. The `TD-KJT conversion` does more real execution on the CPU, but the heavy-lifting computation is on the GPU, where it is delayed/blocked by the backward pass of the previous batch. The difference in GPU runtime is small (~10%). {F1949366822}
* For the `Copy-Batch-To-GPU` part, TD has more fragmented `HtoD` comms while KJT has a single contiguous `HtoD` comm. Runtime-wise they are similar (~10%). {F1949374305}
* In the most commonly used `TrainPipelineSparseDist`, where `Copy-Batch-To-GPU` and the CPU runtime are not on the critical path, we observe very similar training QPS in the pipeline benchmark (~1%). {F1949390271}
```
TrainPipelineSparseDist | Runtime (P90): 26.737 s | Memory (P90): 34.801 GB (TD)
TrainPipelineSparseDist | Runtime (P90): 26.539 s | Memory (P90): 34.765 GB (KJT)
```
* With increased data size, the GPU runtime is 4x. {F1949386106}
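For reference, the sketch below illustrates what the eager-mode `TD-KJT conversion` boils down to. It is an illustrative re-implementation, not the actual `torchrec.sparse.tensor_dict.maybe_td_to_kjt`: it assumes each TensorDict entry is a jagged-layout `NestedTensor` (one per feature) and, per the WARNING above, ignores weights and variable batch sizes.
```
from typing import List

import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def td_to_kjt_sketch(td: TensorDict, keys: List[str]) -> KeyedJaggedTensor:
    values: List[torch.Tensor] = []
    lengths: List[torch.Tensor] = []
    for key in keys:
        njt = td[key]  # NestedTensor with layout=torch.jagged, shape (B, j_k)
        values.append(njt.values())  # flat 1-D tensor of ids for this feature
        lengths.append(torch.diff(njt.offsets()))  # per-sample lengths, shape (B,)
    # KJT layout: values and lengths are concatenated feature by feature.
    return KeyedJaggedTensor(
        keys=keys,
        values=torch.cat(values),
        lengths=torch.cat(lengths),
    )
```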
# Conclusion
1. [Enablement] With this approach (replacing the `KJT.permute` with the `TD-KJT conversion`), the EBC can now take a `TensorDict` as its module input in both single-GPU and multi-GPU (sharded) scenarios, tested with TrainPipelineBase, TrainPipelineSparseDist, TrainPipelineSemiSync, and TrainPipelinePrefetch (see the usage sketch below).
2. [Performance] The TD host-to-device data transfer might not be a concern/blocker for the most commonly used train pipeline (TrainPipelineSparseDist).
3. [Feature Support] To become production-ready, the TensorDict path needs to (1) integrate the `KJT.weights` data and (2) support variable batch sizes, both of which are used in almost all production models.
4. [Improvement] There are two major operations we can improve: (1) moving the TensorDict from host to device, and (2) converting the TD to a KJT. Both are currently in a vanilla, unoptimized state. Since we are not yet sure what the real traces will look like with production models, we cannot tell whether these improvements are needed/helpful.
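As a usage illustration of the [Enablement] point, the snippet below shows a hypothetical eager-mode call of an `EmbeddingBagCollection` with a `TensorDict` input. The table/feature names and id values are made up, and the TD is assumed to hold one jagged-layout `NestedTensor` per feature (no weights, no variable batch size).
```
import torch
from tensordict import TensorDict
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="table_0",
            embedding_dim=16,
            num_embeddings=100,
            feature_names=["feature_0"],
        )
    ]
)

# Batch of 3 samples with 2/1/3 ids for feature_0, as a jagged NestedTensor.
td = TensorDict(
    {
        "feature_0": torch.nested.nested_tensor(
            [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([4, 5, 6])],
            layout=torch.jagged,
        )
    },
    batch_size=[],
)

pooled = ebc(td)  # same code path as a KJT input, after maybe_td_to_kjt
print(pooled.to_dict()["feature_0"].shape)  # expected: torch.Size([3, 16])
```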
Reviewed By: dstaay-fb

Differential Revision: D65103519
---
 torchrec/distributed/embeddingbag.py            | 16 ++++++++++++----
 .../train_pipeline/tests/pipeline_benchmarks.py |  4 ++--
 torchrec/modules/embedding_modules.py           |  2 ++
 3 files changed, 16 insertions(+), 6 deletions(-)

diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 8cfd16ae9..de3d495f2 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -27,6 +27,7 @@

 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharded_tensor import TensorProperties
@@ -94,6 +95,7 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -656,9 +658,7 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            True
-            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
-            else False
+            PoolingType.MEAN.value in self._pooling_type_to_rs_features
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,8 +1189,16 @@ def _create_inverse_indices_permute_indices(

     # pyre-ignore [14]
     def input_dist(
-        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
+        self,
+        ctx: EmbeddingBagCollectionContext,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        if isinstance(features, TensorDict):
+            feature_keys = list(features.keys())  # pyre-ignore[6]
+            if len(self._features_order) > 0:
+                feature_keys = [feature_keys[i] for i in self._features_order]
+                self._has_features_permute = False  # feature_keys are in order
+            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         ctx.variable_batch_per_feature = features.variable_stride_per_key()
         ctx.inverse_indices = features.inverse_indices_or_none()
         if self._has_uninitialized_input_dist:
diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index e8dc5eccb..fdb900fe0 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -160,7 +160,7 @@ def main(

     tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@
     ]
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="weighted_table_" + str(i),
             feature_names=["weighted_feature_" + str(i)],
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 307d66639..4ade3df2f 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -19,6 +19,7 @@
     pooling_type_to_str,
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt


 @torch.fx.wrap
@@ -229,6 +230,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
             KeyedTensor
         """
         flat_feature_names: List[str] = []
+        features = maybe_td_to_kjt(features, None)
         for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(

From 4428885d8a6a18c9f80cb027886cc329894dad91 Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Tue, 21 Jan 2025 11:16:53 -0800
Subject: [PATCH 2/2] add NJT/TD support for EC (#2596)

Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)

{F1949248817}

# Context
* Continuing from the previous D66465376, which added NJT/TD support for EBC, this diff does the same for EC.
* As depicted above, we are extending the TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict).
* TensorDict is supported in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (Dict[str, JT])`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert the TD to a KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication (see the sketch at the end of this summary).
* In the KJT scenario, the input KJT is permuted (and partially duplicated in some cases) before the `KJTAllToAll` communication, while in the TD scenario the input TD is converted directly into the permuted KJT that feeds the subsequent `KJTAllToAll` communication.

NOTE: This diff reuses a number of existing test frameworks/cases with minimal critical changes in `EmbeddingCollection` and `ShardedEmbeddingCollection`. Please see the following verifications for NJT/TD correctness.
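To make the distributed-mode handling concrete: for a TD input, `ShardedEmbeddingCollection.input_dist` reorders the TD keys to `self._features_order` *before* the conversion, so the resulting KJT already arrives in the order expected by `KJTAllToAll` and the usual `KJT.permute` can be skipped. The sanity sketch below (toy, made-up data; `maybe_td_to_kjt` is the helper used in this stack) spells out the equivalence this relies on.
```
import torch
from tensordict import TensorDict
from torchrec.sparse.tensor_dict import maybe_td_to_kjt

td = TensorDict(
    {
        "feature_0": torch.nested.nested_tensor(
            [torch.tensor([1]), torch.tensor([2, 3])], layout=torch.jagged
        ),
        "feature_1": torch.nested.nested_tensor(
            [torch.tensor([4, 5]), torch.tensor([6])], layout=torch.jagged
        ),
    },
    batch_size=[],
)
features_order = [1, 0]  # e.g. the sharded module expects feature_1 first

keys = list(td.keys())
ordered_keys = [keys[i] for i in features_order]

# TD path: convert with pre-ordered keys, no KJT.permute afterwards.
kjt_direct = maybe_td_to_kjt(td, ordered_keys)
# KJT path: convert in the original order, then permute.
kjt_permuted = maybe_td_to_kjt(td, keys).permute(features_order)

assert kjt_direct.keys() == kjt_permuted.keys()
torch.testing.assert_close(kjt_direct.values(), kjt_permuted.values())
torch.testing.assert_close(kjt_direct.lengths(), kjt_permuted.lengths())
```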
# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): GroupedEmbeddingsLookup(
            (_emb_modules): ModuleList(
              (0): BatchedDenseEmbedding(
                (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
              )
            )
          )
          (_input_dists): RwSparseFeaturesDist(
            (_dist): KJTAllToAll()
          )
          (_output_dists): RwSequenceEmbeddingDist(
            (_dist): SequenceEmbeddingsAllToAll()
          )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input (see the construction sketch after this section)
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'),
    idlist_features=TensorDict(
        fields={
            feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
        batch_size=torch.Size([]),
        device=cuda:0,
        is_shared=True),
    idscore_features=None,
    label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'),
    idlist_features=TensorDict(
        fields={
            feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
            feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
        batch_size=torch.Size([]),
        device=cuda:0,
        is_shared=True),
    idscore_features=None,
    label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0'))
```
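For reference, the sketch below shows one way such an `idlist_features` TensorDict can be constructed by hand and fed to an unsharded `EmbeddingCollection`. The table sizes mirror the fixture above, while the two-table subset, batch size, and id values are made up for illustration.
```
import torch
from tensordict import TensorDict
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection

ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(name="table_0", embedding_dim=16, num_embeddings=11, feature_names=["feature_0"]),
        EmbeddingConfig(name="table_1", embedding_dim=16, num_embeddings=22, feature_names=["feature_1"]),
    ]
)

# One jagged-layout NestedTensor per feature, batch_size=torch.Size([]) as in the pdb output.
idlist_features = TensorDict(
    {
        "feature_0": torch.nested.nested_tensor(
            [torch.tensor([0, 1]), torch.tensor([2]), torch.tensor([3, 4, 5]), torch.tensor([6])],
            layout=torch.jagged,
        ),
        "feature_1": torch.nested.nested_tensor(
            [torch.tensor([7]), torch.tensor([8, 9]), torch.tensor([10]), torch.tensor([11, 12])],
            layout=torch.jagged,
        ),
    },
    batch_size=[],
)

out = ec(idlist_features)  # Dict[str, JaggedTensor], same result as the KJT path
print(out["feature_0"].values().shape)  # expected: torch.Size([7, 16])
```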
Reviewed By: dstaay-fb

Differential Revision: D66521351
---
 torchrec/distributed/embedding.py             | 13 +++++-
 .../distributed/test_utils/test_sharding.py   | 32 ++++++++++++---
 .../tests/test_sequence_model_parallel.py     | 41 +++++++++++++++++++
 torchrec/modules/embedding_modules.py         |  8 +++-
 4 files changed, 84 insertions(+), 10 deletions(-)

diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 93773cc1f..feb77a72a 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -26,6 +26,7 @@
 )

 import torch
+from tensordict import TensorDict
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
@@ -90,6 +91,7 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1198,8 +1200,15 @@ def _compute_sequence_vbe_context(
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
-        features: KeyedJaggedTensor,
+        features: TypeUnion[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        need_permute: bool = True
+        if isinstance(features, TensorDict):
+            feature_keys = list(features.keys())  # pyre-ignore[6]
+            if self._features_order:
+                feature_keys = [feature_keys[i] for i in self._features_order]
+                need_permute = False
+            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
             self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
@@ -1209,7 +1218,7 @@
             unpadded_features = features
             features = pad_vbe_kjt_lengths(unpadded_features)

-        if self._features_order:
+        if need_permute and self._features_order:
             features = features.permute(
                 self._features_order,
                 # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py
index f2b65a833..48b9a90ab 100644
--- a/torchrec/distributed/test_utils/test_sharding.py
+++ b/torchrec/distributed/test_utils/test_sharding.py
@@ -147,6 +147,7 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -177,9 +178,9 @@ def gen_model_and_input(
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    for _ in range(num_inputs):
-        inputs.append(
-            (
+    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
+        for _ in range(num_inputs):
+            inputs.append(
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -188,8 +189,26 @@
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-                if generate == ModelInput.generate_variable_batch_input
-                else cast(ModelInputCallable, generate)(
+            )
+    elif generate == ModelInput.generate:
+        for _ in range(num_inputs):
+            inputs.append(
+                ModelInput.generate(
+                    world_size=world_size,
+                    tables=tables,
+                    dedup_tables=dedup_tables,
+                    weighted_tables=weighted_tables or [],
+                    num_float_features=num_float_features,
+                    variable_batch_size=variable_batch_size,
+                    batch_size=batch_size,
+                    long_indices=long_indices,
+                    input_type=input_type,
+                )
+            )
+    else:
+        for _ in range(num_inputs):
+            inputs.append(
+                cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -200,7 +219,6 @@
                     long_indices=long_indices,
                 )
             )
-            )
     return (model, inputs)


@@ -297,6 +315,7 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         # Generate model & inputs.
@@ -319,6 +338,7 @@
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
+            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py
index aec092354..d13d819c3 100644
--- a/torchrec/distributed/tests/test_sequence_model_parallel.py
+++ b/torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -376,3 +376,44 @@
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
         )
+
+
+@skip_if_asan_class
+class TDSequenceModelParallelTest(SequenceModelParallelTest):
+
+    def test_sharding_variable_batch(self) -> None:
+        pass
+
+    def _test_sharding(
+        self,
+        sharders: List[TestEmbeddingCollectionSharder],
+        backend: str = "gloo",
+        world_size: int = 2,
+        local_size: Optional[int] = None,
+        constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
+        qcomms_config: Optional[QCommsConfig] = None,
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ] = None,
+        variable_batch_size: bool = False,
+        variable_batch_per_feature: bool = False,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=sharding_single_rank_test,
+            world_size=world_size,
+            local_size=local_size,
+            model_class=model_class,
+            tables=self.tables,
+            embedding_groups=self.embedding_groups,
+            sharders=sharders,
+            optim=EmbOptimType.EXACT_SGD,
+            backend=backend,
+            constraints=constraints,
+            qcomms_config=qcomms_config,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
+            variable_batch_per_feature=variable_batch_per_feature,
+            global_constant_batch=True,
+            input_type="td",
+        )
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 4ade3df2f..d110fd57f 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -219,7 +219,10 @@ def __init__(
         self._feature_names: List[List[str]] = [table.feature_names for table in tables]
         self.reset_parameters()

-    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
+    def forward(
+        self,
+        features: KeyedJaggedTensor,  # can also take TensorDict as input
+    ) -> KeyedTensor:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
         and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
@@ -450,7 +453,7 @@ def __init__(  # noqa C901

     def forward(
         self,
-        features: KeyedJaggedTensor,
+        features: KeyedJaggedTensor,  # can also take TensorDict as input
     ) -> Dict[str, JaggedTensor]:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -463,6 +466,7 @@ def forward(

             Dict[str, JaggedTensor]
         """
+        features = maybe_td_to_kjt(features, None)
         feature_embeddings: Dict[str, JaggedTensor] = {}
         jt_dict: Dict[str, JaggedTensor] = features.to_dict()
         for i, emb_module in enumerate(self.embeddings.values()):