Skip to content

Commit 771f142

Browse files
Dark Knight and facebook-github-bot
Dark Knight
authored and committed
Revert D65103519
Summary: This diff reverts D65103519. Depends on D68528333. Need to revert this to fix a lowering import error breaking APS tests. Reviewed By: PoojaAg18. Differential Revision: D68528363.
1 parent 434e5dc commit 771f142

File tree

3 files changed

+6
-16
lines changed

3 files changed

+6
-16
lines changed

torchrec/distributed/embeddingbag.py

+4-12
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828
import torch
2929
from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
30-
from tensordict import TensorDict
3130
from torch import distributed as dist, nn, Tensor
3231
from torch.autograd.profiler import record_function
3332
from torch.distributed._shard.sharded_tensor import TensorProperties
@@ -95,7 +94,6 @@
9594
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
9695
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
9796
from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
98-
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
9997

10098
try:
10199
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -658,7 +656,9 @@ def __init__(
658656
self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
659657
# to support mean pooling callback hook
660658
self._has_mean_pooling_callback: bool = (
661-
PoolingType.MEAN.value in self._pooling_type_to_rs_features
659+
True
660+
if PoolingType.MEAN.value in self._pooling_type_to_rs_features
661+
else False
662662
)
663663
self._dim_per_key: Optional[torch.Tensor] = None
664664
self._kjt_key_indices: Dict[str, int] = {}
@@ -1189,16 +1189,8 @@ def _create_inverse_indices_permute_indices(
11891189

11901190
# pyre-ignore [14]
11911191
def input_dist(
1192-
self,
1193-
ctx: EmbeddingBagCollectionContext,
1194-
features: Union[KeyedJaggedTensor, TensorDict],
1192+
self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
11951193
) -> Awaitable[Awaitable[KJTList]]:
1196-
if isinstance(features, TensorDict):
1197-
feature_keys = list(features.keys()) # pyre-ignore[6]
1198-
if len(self._features_order) > 0:
1199-
feature_keys = [feature_keys[i] for i in self._features_order]
1200-
self._has_features_permute = False # feature_keys are in order
1201-
features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6]
12021194
ctx.variable_batch_per_feature = features.variable_stride_per_key()
12031195
ctx.inverse_indices = features.inverse_indices_or_none()
12041196
if self._has_uninitialized_input_dist:

torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def main(
160160

161161
tables = [
162162
EmbeddingBagConfig(
163-
num_embeddings=max(i + 1, 100) * 1000,
163+
num_embeddings=(i + 1) * 1000,
164164
embedding_dim=dim_emb,
165165
name="table_" + str(i),
166166
feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@ def main(
169169
]
170170
weighted_tables = [
171171
EmbeddingBagConfig(
172-
num_embeddings=max(i + 1, 100) * 1000,
172+
num_embeddings=(i + 1) * 1000,
173173
embedding_dim=dim_emb,
174174
name="weighted_table_" + str(i),
175175
feature_names=["weighted_feature_" + str(i)],

torchrec/modules/embedding_modules.py

-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
pooling_type_to_str,
2020
)
2121
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
22-
from torchrec.sparse.tensor_dict import maybe_td_to_kjt
2322

2423

2524
@torch.fx.wrap
@@ -230,7 +229,6 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
230229
KeyedTensor
231230
"""
232231
flat_feature_names: List[str] = []
233-
features = maybe_td_to_kjt(features, None)
234232
for names in self._feature_names:
235233
flat_feature_names.extend(names)
236234
inverse_indices = reorder_inverse_indices(

0 commit comments

Comments (0)