
Commit ea1cc27

zzthg authored and facebook-github-bot committed
Add feature names onto ShardedQuantEbcInputDist (#2725)
Summary:
Pull Request resolved: #2725

Add `_feature_names` on `ShardedQuantEbcInputDist` so that we can check feature order on input_dist side.

Reviewed By: 842974287

Differential Revision: D69186049

fbshipit-source-id: 127e4a4e61fdd0ea39fcc5f007edf37a718c3fef
1 parent 1d9541b commit ea1cc27
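
For context, a minimal sketch of the check this change enables: with `_feature_names` stored on the module, the input_dist side can compare the feature order the shardings expect against the order of the incoming features. The helper name and toy data below are hypothetical, not part of this commit:

from typing import List

def feature_order_matches(
    feature_names: List[List[str]],   # as stashed in ShardedQuantEbcInputDist._feature_names
    input_feature_names: List[str],   # feature order of the incoming input
) -> bool:
    # Flatten the per-sharding name lists (mirroring _create_input_dist)
    # and compare against the incoming order.
    flat = [name for per_sharding in feature_names for name in per_sharding]
    return flat == input_feature_names

assert feature_order_matches([["f1", "f2"], ["f3"]], ["f1", "f2", "f3"])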

File tree

1 file changed: +30, -14 lines


torchrec/distributed/quant_embeddingbag.py (+30, -14)
@@ -512,8 +512,8 @@ class ShardedQuantEbcInputDist(torch.nn.Module):
     This module implements distributed inputs of a ShardedQuantEmbeddingBagCollection.
 
     Args:
-        sharding_type_to_sharding (Dict[
-            str,
+        sharding_type_device_group_to_sharding (Dict[
+            Tuple[str, str],
             EmbeddingSharding[
                 NullShardingContext,
                 KJTList,
@@ -526,8 +526,8 @@ class ShardedQuantEbcInputDist(torch.nn.Module):
     Example::
 
         sqebc_input_dist = ShardedQuantEbcInputDist(
-            sharding_type_to_sharding={
-                ShardingType.TABLE_WISE: InferTwSequenceEmbeddingSharding(
+            sharding_type_device_group_to_sharding={
+                (ShardingType.TABLE_WISE, "cpu"): InferTwSequenceEmbeddingSharding(
                     [],
                     ShardingEnv(
                         world_size=2,
@@ -568,10 +568,24 @@ def __init__(
         )
         self._device = device
 
+        self._shardings: List[
+            EmbeddingSharding[
+                NullShardingContext,
+                InputDistOutputs,
+                List[torch.Tensor],
+                torch.Tensor,
+            ]
+        ] = list(sharding_type_device_group_to_sharding.values())
+
         self._input_dists: List[nn.Module] = []
 
-        self._feature_splits: List[int] = []
         self._features_order: List[int] = []
+        self._feature_names: List[List[str]] = [
+            sharding.feature_names() for sharding in self._shardings
+        ]
+        self._feature_splits: List[int] = [
+            len(sharding) for sharding in self._feature_names
+        ]
 
         # forward pass flow control
         self._has_uninitialized_input_dist: bool = True
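
To make the two derived attributes concrete: `_feature_names` holds one name list per sharding, and `_feature_splits` is just the length of each list. A toy illustration (the feature layout is invented):

# One list of feature names per sharding, e.g. two shardings:
feature_names = [["f1", "f2"], ["f3"]]

# The splits are the per-sharding list lengths; they are used later to
# split the (possibly permuted) input back into per-sharding chunks.
feature_splits = [len(names) for names in feature_names]
assert feature_splits == [2, 1]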
@@ -583,18 +597,20 @@ def _create_input_dist(
         features_device: torch.device,
         input_dist_device: Optional[torch.device] = None,
     ) -> None:
-        feature_names: List[str] = []
-        for sharding in self._sharding_type_device_group_to_sharding.values():
-            self._input_dists.append(
-                sharding.create_input_dist(device=input_dist_device)
-            )
-            feature_names.extend(sharding.feature_names())
-            self._feature_splits.append(len(sharding.feature_names()))
+        flat_feature_names: List[str] = [
+            feature_name
+            for sharding_feature_name in self._feature_names
+            for feature_name in sharding_feature_name
+        ]
+        self._input_dists = [
+            sharding.create_input_dist(device=input_dist_device)
+            for sharding in self._shardings
+        ]
 
-        if feature_names == input_feature_names:
+        if flat_feature_names == input_feature_names:
             self._has_features_permute = False
         else:
-            for f in feature_names:
+            for f in flat_feature_names:
                 self._features_order.append(input_feature_names.index(f))
             self.register_buffer(
                 "_features_order_tensor",
