diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index de3d495f2..8cfd16ae9 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -27,7 +27,6 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
-from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharded_tensor import TensorProperties
@@ -95,7 +94,6 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -658,7 +656,9 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            True
+            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
+            else False
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,16 +1189,8 @@
 
     # pyre-ignore [14]
     def input_dist(
-        self,
-        ctx: EmbeddingBagCollectionContext,
-        features: Union[KeyedJaggedTensor, TensorDict],
+        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
     ) -> Awaitable[Awaitable[KJTList]]:
-        if isinstance(features, TensorDict):
-            feature_keys = list(features.keys())  # pyre-ignore[6]
-            if len(self._features_order) > 0:
-                feature_keys = [feature_keys[i] for i in self._features_order]
-            self._has_features_permute = False  # feature_keys are in order
-            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         ctx.variable_batch_per_feature = features.variable_stride_per_key()
         ctx.inverse_indices = features.inverse_indices_or_none()
         if self._has_uninitialized_input_dist:
diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index fdb900fe0..e8dc5eccb 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -160,7 +160,7 @@ def main(
 
     tables = [
         EmbeddingBagConfig(
-            num_embeddings=max(i + 1, 100) * 1000,
+            num_embeddings=(i + 1) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@
     ]
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=max(i + 1, 100) * 1000,
+            num_embeddings=(i + 1) * 1000,
             embedding_dim=dim_emb,
             name="weighted_table_" + str(i),
             feature_names=["weighted_feature_" + str(i)],
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 4ade3df2f..307d66639 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -19,7 +19,6 @@
     pooling_type_to_str,
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 
 @torch.fx.wrap
@@ -230,7 +229,6 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
             KeyedTensor
         """
         flat_feature_names: List[str] = []
-        features = maybe_td_to_kjt(features, None)
         for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(
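
With this change, EmbeddingBagCollection.forward and ShardedEmbeddingBagCollection.input_dist once again accept only a KeyedJaggedTensor, so callers that hold per-feature id tensors (for example, unpacked from a TensorDict) must build the KJT themselves before calling in. Below is a minimal sketch of such a conversion, assuming one 1-D jagged id tensor per feature key; the dict layout and the helper name values_dict_to_kjt are illustrative and are not the removed maybe_td_to_kjt API.

from typing import Dict, List

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def values_dict_to_kjt(values_per_key: Dict[str, torch.Tensor]) -> KeyedJaggedTensor:
    # Concatenate each feature's 1-D id tensor and record its length so the
    # KeyedJaggedTensor can recover the per-key jagged boundaries.
    keys: List[str] = list(values_per_key.keys())
    values = torch.cat([values_per_key[k] for k in keys])
    lengths = torch.tensor(
        [values_per_key[k].numel() for k in keys], dtype=torch.int64
    )
    return KeyedJaggedTensor(keys=keys, values=values, lengths=lengths)


# Example: two features with jagged id lists of lengths 3 and 2.
kjt = values_dict_to_kjt(
    {
        "feature_0": torch.tensor([1, 2, 3]),
        "feature_1": torch.tensor([4, 5]),
    }
)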