@@ -512,8 +512,8 @@ class ShardedQuantEbcInputDist(torch.nn.Module):
     This module implements distributed inputs of a ShardedQuantEmbeddingBagCollection.
 
     Args:
-        sharding_type_to_sharding (Dict[
-            str,
+        sharding_type_device_group_to_sharding (Dict[
+            Tuple[str, str],
             EmbeddingSharding[
                 NullShardingContext,
                 KJTList,
@@ -526,8 +526,8 @@ class ShardedQuantEbcInputDist(torch.nn.Module):
     Example::
 
         sqebc_input_dist = ShardedQuantEbcInputDist(
-            sharding_type_to_sharding={
-                ShardingType.TABLE_WISE: InferTwSequenceEmbeddingSharding(
+            sharding_type_device_group_to_sharding={
+                (ShardingType.TABLE_WISE, "cpu"): InferTwSequenceEmbeddingSharding(
                     [],
                     ShardingEnv(
                         world_size=2,
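
The dict is now keyed by `(sharding_type, device_group)` pairs rather than by sharding type alone, so the same sharding type can carry distinct shardings per device group. A minimal, self-contained sketch of that keying scheme; `_StubSharding`, the `"table_wise"`/`"cpu"`/`"cuda"` strings, and the feature names are hypothetical stand-ins (in torchrec the values would be `EmbeddingSharding` instances such as `InferTwSequenceEmbeddingSharding`):

```python
from typing import Dict, List, Tuple

# Hypothetical stand-in for an EmbeddingSharding so the sketch runs without torchrec.
class _StubSharding:
    def __init__(self, names: List[str]) -> None:
        self._names = names

    def feature_names(self) -> List[str]:
        return self._names

# Keyed by (sharding_type, device_group): the same sharding type may appear
# once per device group.
sharding_type_device_group_to_sharding: Dict[Tuple[str, str], _StubSharding] = {
    ("table_wise", "cpu"): _StubSharding(["f1", "f2"]),
    ("table_wise", "cuda"): _StubSharding(["f3"]),
}

# Mirrors what __init__ now precomputes from the dict's values (next hunk):
# per-sharding feature names and their split sizes.
shardings = list(sharding_type_device_group_to_sharding.values())
feature_names = [s.feature_names() for s in shardings]  # [["f1", "f2"], ["f3"]]
feature_splits = [len(n) for n in feature_names]        # [2, 1]
```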
@@ -568,10 +568,24 @@ def __init__(
         )
         self._device = device
 
+        self._shardings: List[
+            EmbeddingSharding[
+                NullShardingContext,
+                InputDistOutputs,
+                List[torch.Tensor],
+                torch.Tensor,
+            ]
+        ] = list(sharding_type_device_group_to_sharding.values())
+
         self._input_dists: List[nn.Module] = []
 
-        self._feature_splits: List[int] = []
         self._features_order: List[int] = []
+        self._feature_names: List[List[str]] = [
+            sharding.feature_names() for sharding in self._shardings
+        ]
+        self._feature_splits: List[int] = [
+            len(sharding) for sharding in self._feature_names
+        ]
 
         # forward pass flow control
         self._has_uninitialized_input_dist: bool = True
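
With `_feature_names` and `_feature_splits` built eagerly in `__init__`, they no longer depend on `_create_input_dist` having run. A sketch of how the splits are typically consumed, assuming the usual torchrec pattern of splitting the incoming `KeyedJaggedTensor` into one KJT per sharding (the forward pass is not part of this diff, and the keys/values below are made up):

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Incoming batch with three features at stride 2 (two samples per key).
features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([10, 20, 30, 40, 50, 60]),
    lengths=torch.tensor([1, 1, 1, 1, 1, 1]),
)

# e.g. the first sharding owns f1/f2, the second owns f3.
feature_splits = [2, 1]
per_sharding_kjts = features.split(feature_splits)
assert [kjt.keys() for kjt in per_sharding_kjts] == [["f1", "f2"], ["f3"]]
```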
@@ -583,18 +597,20 @@ def _create_input_dist(
         features_device: torch.device,
         input_dist_device: Optional[torch.device] = None,
     ) -> None:
-        feature_names: List[str] = []
-        for sharding in self._sharding_type_device_group_to_sharding.values():
-            self._input_dists.append(
-                sharding.create_input_dist(device=input_dist_device)
-            )
-            feature_names.extend(sharding.feature_names())
-            self._feature_splits.append(len(sharding.feature_names()))
+        flat_feature_names: List[str] = [
+            feature_name
+            for sharding_feature_name in self._feature_names
+            for feature_name in sharding_feature_name
+        ]
+        self._input_dists = [
+            sharding.create_input_dist(device=input_dist_device)
+            for sharding in self._shardings
+        ]
 
-        if feature_names == input_feature_names:
+        if flat_feature_names == input_feature_names:
             self._has_features_permute = False
         else:
-            for f in feature_names:
+            for f in flat_feature_names:
                 self._features_order.append(input_feature_names.index(f))
             self.register_buffer(
                 "_features_order_tensor",
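
The `else` branch records, for each feature name the shardings expect, its position in the incoming feature list, and registers that order as a buffer. A sketch of how such a permutation is typically applied with `KeyedJaggedTensor.permute`; the input names and the int32 dtype for the order tensor are assumptions here, not taken from this hunk:

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Incoming features arrive as ["f2", "f1"] but the shardings expect ["f1", "f2"].
input_feature_names = ["f2", "f1"]
expected = ["f1", "f2"]
features_order = [input_feature_names.index(f) for f in expected]  # [1, 0]

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=input_feature_names,
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([2, 1]),
)

# permute reorders keys (and their values/lengths) to the expected order.
permuted = kjt.permute(
    features_order,
    torch.tensor(features_order, dtype=torch.int32),
)
assert permuted.keys() == ["f1", "f2"]
```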