Commit 434e5dc

Dark Knight authored and facebook-github-bot committed
Revert D66521351
Summary: This diff reverts D66521351. Need to revert this to fix lowering import error breaking aps tests.

Reviewed By: PoojaAg18

Differential Revision: D68528333
1 parent f3d34fc commit 434e5dc

4 files changed: +10 -84 lines changed

Diff for: torchrec/distributed/embedding.py

+2 -11
@@ -26,7 +26,6 @@
 )
 
 import torch
-from tensordict import TensorDict
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
@@ -91,7 +90,6 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
-from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -1200,15 +1198,8 @@ def _compute_sequence_vbe_context(
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
-        features: TypeUnion[KeyedJaggedTensor, TensorDict],
+        features: KeyedJaggedTensor,
     ) -> Awaitable[Awaitable[KJTList]]:
-        need_permute: bool = True
-        if isinstance(features, TensorDict):
-            feature_keys = list(features.keys())  # pyre-ignore[6]
-            if self._features_order:
-                feature_keys = [feature_keys[i] for i in self._features_order]
-            need_permute = False
-            features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
             self._create_input_dist(input_feature_names=features.keys())
             self._has_uninitialized_input_dist = False
@@ -1218,7 +1209,7 @@ def input_dist(
             unpadded_features = features
             features = pad_vbe_kjt_lengths(unpadded_features)
 
-        if need_permute and self._features_order:
+        if self._features_order:
             features = features.permute(
                 self._features_order,
                 # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
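
With the TensorDict branch reverted, input_dist accepts only a KeyedJaggedTensor again, so any caller that was passing a TensorDict has to hand in a KJT itself. A minimal sketch of building such an input, assuming two illustrative feature keys and toy id values (none of these names come from the diff):

import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features ("f1", "f2") over a batch of 3 samples; `lengths` holds the
# number of ids for each (feature, sample) pair, so it has 2 * 3 entries.
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6, 7]),
    lengths=torch.tensor([1, 2, 1, 1, 1, 1]),
)

# `features` now matches the reverted signature:
#     input_dist(self, ctx, features: KeyedJaggedTensor)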

Diff for: torchrec/distributed/test_utils/test_sharding.py

+6 -26
@@ -147,7 +147,6 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
-    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -178,9 +177,9 @@ def gen_model_and_input(
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
-        for _ in range(num_inputs):
-            inputs.append(
+    for _ in range(num_inputs):
+        inputs.append(
+            (
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -189,26 +188,8 @@ def gen_model_and_input(
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-            )
-    elif generate == ModelInput.generate:
-        for _ in range(num_inputs):
-            inputs.append(
-                ModelInput.generate(
-                    world_size=world_size,
-                    tables=tables,
-                    dedup_tables=dedup_tables,
-                    weighted_tables=weighted_tables or [],
-                    num_float_features=num_float_features,
-                    variable_batch_size=variable_batch_size,
-                    batch_size=batch_size,
-                    long_indices=long_indices,
-                    input_type=input_type,
-                )
-            )
-    else:
-        for _ in range(num_inputs):
-            inputs.append(
-                cast(ModelInputCallable, generate)(
+                if generate == ModelInput.generate_variable_batch_input
+                else cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -219,6 +200,7 @@ def gen_model_and_input(
                     long_indices=long_indices,
                 )
             )
+        )
     return (model, inputs)
 
 
@@ -315,7 +297,6 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
-    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
         # Generate model & inputs.
@@ -338,7 +319,6 @@ def sharding_single_rank_test(
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
-            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
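
Read together, the hunks above restore gen_model_and_input to a single generation loop that picks the generator with a conditional expression. Assembled for readability (the keyword arguments hidden between hunks are elided and marked with comments), the reverted body reads roughly as:

    inputs = []
    for _ in range(num_inputs):
        inputs.append(
            (
                # Variable-batch generator path.
                cast(VariableBatchModelInputCallable, generate)(
                    average_batch_size=batch_size,
                    world_size=world_size,
                    # ... remaining keyword arguments as in the hunks above ...
                    weighted_tables=weighted_tables or [],
                    global_constant_batch=global_constant_batch,
                )
                if generate == ModelInput.generate_variable_batch_input
                # Plain generator path for every other callable.
                else cast(ModelInputCallable, generate)(
                    world_size=world_size,
                    tables=tables,
                    dedup_tables=dedup_tables,
                    # ... remaining keyword arguments as in the hunks above ...
                    long_indices=long_indices,
                )
            )
        )
    return (model, inputs)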

Diff for: torchrec/distributed/tests/test_sequence_model_parallel.py

-41
@@ -376,44 +376,3 @@ def _test_sharding(
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
         )
-
-
-@skip_if_asan_class
-class TDSequenceModelParallelTest(SequenceModelParallelTest):
-
-    def test_sharding_variable_batch(self) -> None:
-        pass
-
-    def _test_sharding(
-        self,
-        sharders: List[TestEmbeddingCollectionSharder],
-        backend: str = "gloo",
-        world_size: int = 2,
-        local_size: Optional[int] = None,
-        constraints: Optional[Dict[str, ParameterConstraints]] = None,
-        model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
-        qcomms_config: Optional[QCommsConfig] = None,
-        apply_optimizer_in_backward_config: Optional[
-            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
-        ] = None,
-        variable_batch_size: bool = False,
-        variable_batch_per_feature: bool = False,
-    ) -> None:
-        self._run_multi_process_test(
-            callable=sharding_single_rank_test,
-            world_size=world_size,
-            local_size=local_size,
-            model_class=model_class,
-            tables=self.tables,
-            embedding_groups=self.embedding_groups,
-            sharders=sharders,
-            optim=EmbOptimType.EXACT_SGD,
-            backend=backend,
-            constraints=constraints,
-            qcomms_config=qcomms_config,
-            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
-            variable_batch_size=variable_batch_size,
-            variable_batch_per_feature=variable_batch_per_feature,
-            global_constant_batch=True,
-            input_type="td",
-        )

Diff for: torchrec/modules/embedding_modules.py

+2 -6
@@ -219,10 +219,7 @@ def __init__(
         self._feature_names: List[List[str]] = [table.feature_names for table in tables]
         self.reset_parameters()
 
-    def forward(
-        self,
-        features: KeyedJaggedTensor,  # can also take TensorDict as input
-    ) -> KeyedTensor:
+    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
         and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
@@ -453,7 +450,7 @@ def __init__(  # noqa C901
 
     def forward(
         self,
-        features: KeyedJaggedTensor,  # can also take TensorDict as input
+        features: KeyedJaggedTensor,
    ) -> Dict[str, JaggedTensor]:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -466,7 +463,6 @@ def forward(
             Dict[str, JaggedTensor]
         """
 
-        features = maybe_td_to_kjt(features, None)
         feature_embeddings: Dict[str, JaggedTensor] = {}
         jt_dict: Dict[str, JaggedTensor] = features.to_dict()
         for i, emb_module in enumerate(self.embeddings.values()):
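
With maybe_td_to_kjt no longer called inside forward, the unsharded modules again expect a KeyedJaggedTensor directly; a caller holding per-feature JaggedTensors can pack them itself. A minimal sketch, assuming illustrative feature names and an already-constructed EmbeddingCollection named ec (neither is part of the diff):

import torch

from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

# Per-feature jagged id lists for a batch of 2 samples.
jt_dict = {
    "f1": JaggedTensor(values=torch.tensor([10, 11, 12]), lengths=torch.tensor([1, 2])),
    "f2": JaggedTensor(values=torch.tensor([20, 21]), lengths=torch.tensor([1, 1])),
}

# Pack into a single KJT, which is what the reverted forward() signature expects.
features = KeyedJaggedTensor.from_jt_dict(jt_dict)

# embeddings: Dict[str, JaggedTensor] = ec(features)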
