|
27 | 27 |
|
28 | 28 | import torch
|
29 | 29 | from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
|
30 |
| -from tensordict import TensorDict |
31 | 30 | from torch import distributed as dist, nn, Tensor
|
32 | 31 | from torch.autograd.profiler import record_function
|
33 | 32 | from torch.distributed._shard.sharded_tensor import TensorProperties
|
|
95 | 94 | from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
|
96 | 95 | from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
|
97 | 96 | from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
|
98 |
| -from torchrec.sparse.tensor_dict import maybe_td_to_kjt |
99 | 97 |
|
100 | 98 | try:
|
101 | 99 | torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
|
@@ -658,7 +656,9 @@ def __init__(
|
658 | 656 | self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
|
659 | 657 | # to support mean pooling callback hook
|
660 | 658 | self._has_mean_pooling_callback: bool = (
|
661 |
| - PoolingType.MEAN.value in self._pooling_type_to_rs_features |
| 659 | + True |
| 660 | + if PoolingType.MEAN.value in self._pooling_type_to_rs_features |
| 661 | + else False |
662 | 662 | )
|
663 | 663 | self._dim_per_key: Optional[torch.Tensor] = None
|
664 | 664 | self._kjt_key_indices: Dict[str, int] = {}
|
@@ -1189,16 +1189,8 @@ def _create_inverse_indices_permute_indices(
|
1189 | 1189 |
|
1190 | 1190 | # pyre-ignore [14]
|
1191 | 1191 | def input_dist(
|
1192 |
| - self, |
1193 |
| - ctx: EmbeddingBagCollectionContext, |
1194 |
| - features: Union[KeyedJaggedTensor, TensorDict], |
| 1192 | + self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor |
1195 | 1193 | ) -> Awaitable[Awaitable[KJTList]]:
|
1196 |
| - if isinstance(features, TensorDict): |
1197 |
| - feature_keys = list(features.keys()) # pyre-ignore[6] |
1198 |
| - if len(self._features_order) > 0: |
1199 |
| - feature_keys = [feature_keys[i] for i in self._features_order] |
1200 |
| - self._has_features_permute = False # feature_keys are in order |
1201 |
| - features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6] |
1202 | 1194 | ctx.variable_batch_per_feature = features.variable_stride_per_key()
|
1203 | 1195 | ctx.inverse_indices = features.inverse_indices_or_none()
|
1204 | 1196 | if self._has_uninitialized_input_dist:
|
|
0 commit comments