     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
 from torch import nn
+from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec
 from torchrec.distributed.embedding import (
     create_sharding_infos_by_sharding_device_group,
     EmbeddingShardingInfo,
@@ -83,14 +84,32 @@ def record_stream(self, stream: torch.Stream) -> None:
         ctx.record_stream(stream)


-def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
-    # pyre-ignore
-    return ps.sharding_spec.shards[0].placement.device().type
+def get_device_from_parameter_sharding(
+    ps: ParameterSharding,
+) -> Union[str, Tuple[str, ...]]:
+    """
+    Returns a tuple of device types (one per shard) if the table is sharded
+    across different device types, otherwise returns the single device type
+    for the table parameter.
+    """
+    if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
+        raise ValueError("Expected EnumerableShardingSpec as input to the function")
+
+    device_type_list: Tuple[str, ...] = tuple(
+        # pyre-fixme[16]: `Optional` has no attribute `device`
+        [shard.placement.device().type for shard in ps.sharding_spec.shards]
+    )
+    if len(set(device_type_list)) == 1:
+        return device_type_list[0]
+    else:
+        assert (
+            ps.sharding_type == "row_wise"
+        ), "Only row_wise sharding supports sharding across multiple device types for a table"
+        return device_type_list


 def get_device_from_sharding_infos(
     emb_shard_infos: List[EmbeddingShardingInfo],
-) -> str:
+) -> Union[str, Tuple[str, ...]]:
     res = list(
         {
             get_device_from_parameter_sharding(ps.param_sharding)
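For context (not part of the diff): a minimal sketch of how the widened return type behaves. It assumes torchrec is installed, that the function above is importable from `torchrec.distributed.quant_embedding`, and that `ParameterSharding` accepts the dataclass fields shown; the table sizes, ranks, and `"quant"` compute kernel are made-up example values.

```python
# Illustrative sketch only (not part of the diff); see assumptions in the text above.
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec, ShardMetadata
from torchrec.distributed.quant_embedding import get_device_from_parameter_sharding
from torchrec.distributed.types import ParameterSharding

# A 100x16 table split row-wise into two shards placed on different device types.
spec = EnumerableShardingSpec(
    shards=[
        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[50, 16], placement="rank:0/cuda:0"),
        ShardMetadata(shard_offsets=[50, 0], shard_sizes=[50, 16], placement="rank:1/cpu"),
    ]
)
ps = ParameterSharding(
    sharding_type="row_wise",   # only row_wise may span device types
    compute_kernel="quant",     # example kernel name
    ranks=[0, 1],
    sharding_spec=spec,
)

print(get_device_from_parameter_sharding(ps))  # ("cuda", "cpu") -- one entry per shard
# With both shards on "cuda" placements it would instead return the plain string "cuda".
```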
@@ -101,6 +120,13 @@ def get_device_from_sharding_infos(
     return res[0]


+def get_device_for_first_shard_from_sharding_infos(
+    emb_shard_infos: List[EmbeddingShardingInfo],
+) -> str:
+    device_type = get_device_from_sharding_infos(emb_shard_infos)
+    return device_type[0] if isinstance(device_type, tuple) else device_type
+
+
 def create_infer_embedding_sharding(
     sharding_type: str,
     sharding_infos: List[EmbeddingShardingInfo],
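A small standalone sketch of the first-shard fallback this helper provides (used later when indexing a per-device-type `ShardingEnv` dict). `_first_device` is a hypothetical stand-in that mirrors the logic of `get_device_for_first_shard_from_sharding_infos` without requiring torchrec objects.

```python
from typing import Tuple, Union

def _first_device(device_type: Union[str, Tuple[str, ...]]) -> str:
    # Mirrors get_device_for_first_shard_from_sharding_infos: heterogeneous
    # tables report a tuple of per-shard device types; take the first entry.
    return device_type[0] if isinstance(device_type, tuple) else device_type

assert _first_device(("cuda", "cpu")) == "cuda"  # row-wise table spanning device types
assert _first_device("cpu") == "cpu"             # table placed on a single device type
```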
@@ -112,8 +138,8 @@ def create_infer_embedding_sharding(
     List[torch.Tensor],
     List[torch.Tensor],
 ]:
-    device_type_from_sharding_infos: str = get_device_from_sharding_infos(
-        sharding_infos
+    device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
+        get_device_from_sharding_infos(sharding_infos)
     )

     if device_type_from_sharding_infos in ["cuda", "mtia"]:
@@ -132,7 +158,9 @@ def create_infer_embedding_sharding(
             raise ValueError(
                 f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
             )
-    elif device_type_from_sharding_infos == "cpu":
+    elif device_type_from_sharding_infos == "cpu" or isinstance(
+        device_type_from_sharding_infos, tuple
+    ):
         if sharding_type == ShardingType.ROW_WISE.value:
             return InferRwSequenceEmbeddingSharding(
                 sharding_infos=sharding_infos,
@@ -437,13 +465,13 @@ def __init__(
         self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs()

         self._sharding_type_device_group_to_sharding_infos: Dict[
-            Tuple[str, str], List[EmbeddingShardingInfo]
+            Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
         ] = create_sharding_infos_by_sharding_device_group(
             module, table_name_to_parameter_sharding, fused_params
         )

         self._sharding_type_device_group_to_sharding: Dict[
-            Tuple[str, str],
+            Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
                 InferSequenceShardingContext,
                 InputDistOutputs,
@@ -457,7 +485,11 @@ def __init__(
                 (
                     env
                     if not isinstance(env, Dict)
-                    else env[get_device_from_sharding_infos(embedding_configs)]
+                    else env[
+                        get_device_for_first_shard_from_sharding_infos(
+                            embedding_configs
+                        )
+                    ]
                 ),
                 device if get_propogate_device() else None,
             )
@@ -580,7 +612,7 @@ def tbes_configs(

     def sharding_type_device_group_to_sharding_infos(
         self,
-    ) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
+    ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:
         return self._sharding_type_device_group_to_sharding_infos

     def embedding_configs(self) -> List[EmbeddingConfig]:
@@ -872,7 +904,9 @@ def create_context(self) -> EmbeddingCollectionContext:
         return EmbeddingCollectionContext(sharding_contexts=[])

     @property
-    def shardings(self) -> Dict[Tuple[str, str], FeatureShardingMixIn]:
+    def shardings(
+        self,
+    ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], FeatureShardingMixIn]:
         # pyre-ignore [7]
         return self._sharding_type_device_group_to_sharding

@@ -965,7 +999,7 @@ def __init__(
         self,
         input_feature_names: List[str],
         sharding_type_device_group_to_sharding: Dict[
-            Tuple[str, str],
+            Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
                 InferSequenceShardingContext,
                 InputDistOutputs,