@@ -512,8 +512,8 @@ class ShardedQuantEbcInputDist(torch.nn.Module):
     This module implements distributed inputs of a ShardedQuantEmbeddingBagCollection.
 
     Args:
-        sharding_type_to_sharding (Dict[
-            str,
+        sharding_type_device_group_to_sharding (Dict[
+            Tuple[str, str],
             EmbeddingSharding[
                 NullShardingContext,
                 KJTList,
@@ -526,8 +526,8 @@ class ShardedQuantEbcInputDist(torch.nn.Module):
     Example::
 
         sqebc_input_dist = ShardedQuantEbcInputDist(
-            sharding_type_to_sharding={
-                ShardingType.TABLE_WISE: InferTwSequenceEmbeddingSharding(
+            sharding_type_device_group_to_sharding={
+                (ShardingType.TABLE_WISE, "cpu"): InferTwSequenceEmbeddingSharding(
                     [],
                     ShardingEnv(
                         world_size=2,
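
Taken together, these two docstring hunks change the constructor's key type from a bare sharding type to a (sharding type, device group) pair. A minimal sketch of what that buys, using hypothetical stand-in strings rather than real `InferTwSequenceEmbeddingSharding` instances:

```python
# Hedged sketch, not library code: the mapping is keyed by a
# (sharding_type, device_group) tuple, so one sharding type can resolve
# to a different sharding implementation per device group.
from typing import Dict, Tuple

# Stand-in strings in place of EmbeddingSharding instances.
sharding_type_device_group_to_sharding: Dict[Tuple[str, str], str] = {
    ("table_wise", "cpu"): "tw-sharding-for-cpu",
    ("table_wise", "cuda"): "tw-sharding-for-cuda",
}

# A lookup now needs both halves of the key.
assert (
    sharding_type_device_group_to_sharding[("table_wise", "cpu")]
    == "tw-sharding-for-cpu"
)
```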
@@ -568,10 +568,24 @@ def __init__(
         )
         self._device = device
 
+        self._shardings: List[
+            EmbeddingSharding[
+                NullShardingContext,
+                InputDistOutputs,
+                List[torch.Tensor],
+                torch.Tensor,
+            ]
+        ] = list(sharding_type_device_group_to_sharding.values())
+
         self._input_dists: List[nn.Module] = []
 
-        self._feature_splits: List[int] = []
         self._features_order: List[int] = []
+        self._feature_names: List[List[str]] = [
+            sharding.feature_names() for sharding in self._shardings
+        ]
+        self._feature_splits: List[int] = [
+            len(sharding) for sharding in self._feature_names
+        ]
 
         # forward pass flow control
         self._has_uninitialized_input_dist: bool = True
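
This hunk moves the derivation of `_feature_names` and `_feature_splits` from the first `_create_input_dist` call into `__init__`, computing both once from the sharding list. A self-contained sketch of that computation, with a hypothetical stub standing in for `EmbeddingSharding`:

```python
# Hedged sketch: _StubSharding is a stand-in for EmbeddingSharding, which
# exposes feature_names() per the diff above.
from typing import List


class _StubSharding:
    def __init__(self, names: List[str]) -> None:
        self._names = names

    def feature_names(self) -> List[str]:
        return self._names


shardings = [_StubSharding(["f1", "f2"]), _StubSharding(["f3"])]

# Same shape as the new __init__ code: one name list per sharding, then
# the per-sharding counts used later to split the input features.
feature_names: List[List[str]] = [s.feature_names() for s in shardings]
feature_splits: List[int] = [len(names) for names in feature_names]

assert feature_names == [["f1", "f2"], ["f3"]]
assert feature_splits == [2, 1]
```

One visible effect of computing these eagerly: `_feature_splits` no longer grows as a side effect of each `_create_input_dist` call.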
@@ -583,18 +597,20 @@ def _create_input_dist(
         features_device: torch.device,
         input_dist_device: Optional[torch.device] = None,
     ) -> None:
-        feature_names: List[str] = []
-        for sharding in self._sharding_type_device_group_to_sharding.values():
-            self._input_dists.append(
-                sharding.create_input_dist(device=input_dist_device)
-            )
-            feature_names.extend(sharding.feature_names())
-            self._feature_splits.append(len(sharding.feature_names()))
+        flat_feature_names: List[str] = [
+            feature_name
+            for sharding_feature_name in self._feature_names
+            for feature_name in sharding_feature_name
+        ]
+        self._input_dists = [
+            sharding.create_input_dist(device=input_dist_device)
+            for sharding in self._shardings
+        ]
 
-        if feature_names == input_feature_names:
+        if flat_feature_names == input_feature_names:
             self._has_features_permute = False
         else:
-            for f in feature_names:
+            for f in flat_feature_names:
                 self._features_order.append(input_feature_names.index(f))
             self.register_buffer(
                 "_features_order_tensor",