Commit d580841
faran928 authored and facebook-github-bot committed
Tuple as device_type input to support Heterogeneous Sharding of tables across different device_types (#2600)
Summary:
Pull Request resolved: #2600

As we plan to support heterogeneous sharding across different device types (cuda / cpu, etc.), we will pass the device type per shard as a tuple for device_type_from_sharding_info, where each index represents the device type of that particular shard.

Reviewed By: jiayisuse

Differential Revision: D65933148

fbshipit-source-id: 9f97405f65fe8b69228277945886ad61a0e18b3e
1 parent e42a768 commit d580841
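
For orientation, the sketch below is illustrative only and not TorchRec API: resolve_device_type and its plain string inputs are hypothetical stand-ins for ParameterSharding / EnumerableShardingSpec shard placements. It mirrors the per-shard device-type resolution this commit introduces: a table whose shards all live on one device type still resolves to a single string, while a heterogeneous row-wise table resolves to a tuple with one device type per shard.

# Illustrative sketch only: mirrors the resolution logic added in this commit,
# but takes plain per-shard device-type strings instead of real sharding objects.
from typing import Tuple, Union


def resolve_device_type(
    shard_device_types: Tuple[str, ...], sharding_type: str
) -> Union[str, Tuple[str, ...]]:
    # Homogeneous table: collapse to the single device type, as before.
    if len(set(shard_device_types)) == 1:
        return shard_device_types[0]
    # Heterogeneous table: only row_wise sharding is supported; return the
    # full tuple so index i gives the device type of shard i.
    assert sharding_type == "row_wise", (
        "Only row_wise sharding supports sharding across multiple device types"
    )
    return shard_device_types


print(resolve_device_type(("cuda", "cuda"), "table_wise"))  # "cuda"
print(resolve_device_type(("cpu", "cuda"), "row_wise"))     # ("cpu", "cuda")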

File tree

4 files changed (+89 -31 lines)


torchrec/distributed/embedding.py

+27 -6

@@ -28,6 +28,7 @@
 import torch
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
+from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec
 from torch.distributed._tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.embedding_sharding import (
@@ -102,9 +103,27 @@
 EC_INDEX_DEDUP: bool = False


-def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
-    # pyre-ignore
-    return ps.sharding_spec.shards[0].placement.device().type
+def get_device_from_parameter_sharding(
+    ps: ParameterSharding,
+) -> TypeUnion[str, Tuple[str, ...]]:
+    """
+    Returns a tuple with one device type per shard if the table is sharded
+    across different device types, otherwise returns the single device type
+    for the table parameter.
+    """
+    if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
+        raise ValueError("Expected EnumerableShardingSpec as input to the function")
+
+    device_type_list: Tuple[str, ...] = tuple(
+        # pyre-fixme[16]: `Optional` has no attribute `device`
+        [shard.placement.device().type for shard in ps.sharding_spec.shards]
+    )
+    if len(set(device_type_list)) == 1:
+        return device_type_list[0]
+    else:
+        assert (
+            ps.sharding_type == "row_wise"
+        ), "Only row_wise sharding supports sharding across multiple device types for a table"
+        return device_type_list


 def set_ec_index_dedup(val: bool) -> None:
@@ -248,13 +267,13 @@ def create_sharding_infos_by_sharding_device_group(
     module: EmbeddingCollectionInterface,
     table_name_to_parameter_sharding: Dict[str, ParameterSharding],
     fused_params: Optional[Dict[str, Any]],
-) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
+) -> Dict[Tuple[str, TypeUnion[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:

     if fused_params is None:
         fused_params = {}

     sharding_type_device_group_to_sharding_infos: Dict[
-        Tuple[str, str], List[EmbeddingShardingInfo]
+        Tuple[str, TypeUnion[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
     ] = {}
     # state_dict returns parameter.Tensor, which loses parameter level attributes
     parameter_by_name = dict(module.named_parameters())
@@ -280,7 +299,9 @@ def create_sharding_infos_by_sharding_device_group(
         assert param_name in parameter_by_name or param_name in state_dict
         param = parameter_by_name.get(param_name, state_dict[param_name])

-        device_group = get_device_from_parameter_sharding(parameter_sharding)
+        device_group: TypeUnion[str, Tuple[str, ...]] = (
+            get_device_from_parameter_sharding(parameter_sharding)
+        )
         if (
             parameter_sharding.sharding_type,
             device_group,
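
As a result of the change above, the grouping keys in sharding_type_device_group_to_sharding_infos can carry either a single device type or a per-shard tuple. A minimal sketch of what such keys look like (hypothetical table names stand in for the real List[EmbeddingShardingInfo] values):

from typing import Dict, List, Tuple, Union

ShardingKey = Tuple[str, Union[str, Tuple[str, ...]]]

# Hypothetical table names stand in for List[EmbeddingShardingInfo] values.
groups: Dict[ShardingKey, List[str]] = {
    ("table_wise", "cuda"): ["table_a"],         # homogeneous: plain string, as before
    ("row_wise", ("cpu", "cuda")): ["table_b"],  # heterogeneous: per-shard tuple
}
print(groups[("row_wise", ("cpu", "cuda"))])  # ['table_b']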

torchrec/distributed/quant_embedding.py

+47 -13

@@ -17,6 +17,7 @@
     IntNBitTableBatchedEmbeddingBagsCodegen,
 )
 from torch import nn
+from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec
 from torchrec.distributed.embedding import (
     create_sharding_infos_by_sharding_device_group,
     EmbeddingShardingInfo,
@@ -83,14 +84,32 @@ def record_stream(self, stream: torch.Stream) -> None:
         ctx.record_stream(stream)


-def get_device_from_parameter_sharding(ps: ParameterSharding) -> str:
-    # pyre-ignore
-    return ps.sharding_spec.shards[0].placement.device().type
+def get_device_from_parameter_sharding(
+    ps: ParameterSharding,
+) -> Union[str, Tuple[str, ...]]:
+    """
+    Returns a tuple with one device type per shard if the table is sharded
+    across different device types, otherwise returns the single device type
+    for the table parameter.
+    """
+    if not isinstance(ps.sharding_spec, EnumerableShardingSpec):
+        raise ValueError("Expected EnumerableShardingSpec as input to the function")
+
+    device_type_list: Tuple[str, ...] = tuple(
+        # pyre-fixme[16]: `Optional` has no attribute `device`
+        [shard.placement.device().type for shard in ps.sharding_spec.shards]
+    )
+    if len(set(device_type_list)) == 1:
+        return device_type_list[0]
+    else:
+        assert (
+            ps.sharding_type == "row_wise"
+        ), "Only row_wise sharding supports sharding across multiple device types for a table"
+        return device_type_list


 def get_device_from_sharding_infos(
     emb_shard_infos: List[EmbeddingShardingInfo],
-) -> str:
+) -> Union[str, Tuple[str, ...]]:
     res = list(
         {
             get_device_from_parameter_sharding(ps.param_sharding)
@@ -101,6 +120,13 @@ def get_device_from_sharding_infos(
     return res[0]


+def get_device_for_first_shard_from_sharding_infos(
+    emb_shard_infos: List[EmbeddingShardingInfo],
+) -> str:
+    device_type = get_device_from_sharding_infos(emb_shard_infos)
+    return device_type[0] if isinstance(device_type, tuple) else device_type
+
+
 def create_infer_embedding_sharding(
     sharding_type: str,
     sharding_infos: List[EmbeddingShardingInfo],
@@ -112,8 +138,8 @@ def create_infer_embedding_sharding(
     List[torch.Tensor],
     List[torch.Tensor],
 ]:
-    device_type_from_sharding_infos: str = get_device_from_sharding_infos(
-        sharding_infos
+    device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
+        get_device_from_sharding_infos(sharding_infos)
     )

     if device_type_from_sharding_infos in ["cuda", "mtia"]:
@@ -132,7 +158,9 @@ def create_infer_embedding_sharding(
             raise ValueError(
                 f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
             )
-    elif device_type_from_sharding_infos == "cpu":
+    elif device_type_from_sharding_infos == "cpu" or isinstance(
+        device_type_from_sharding_infos, tuple
+    ):
         if sharding_type == ShardingType.ROW_WISE.value:
             return InferRwSequenceEmbeddingSharding(
                 sharding_infos=sharding_infos,
@@ -437,13 +465,13 @@ def __init__(
         self._embedding_configs: List[EmbeddingConfig] = module.embedding_configs()

         self._sharding_type_device_group_to_sharding_infos: Dict[
-            Tuple[str, str], List[EmbeddingShardingInfo]
+            Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]
         ] = create_sharding_infos_by_sharding_device_group(
             module, table_name_to_parameter_sharding, fused_params
         )

         self._sharding_type_device_group_to_sharding: Dict[
-            Tuple[str, str],
+            Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
                 InferSequenceShardingContext,
                 InputDistOutputs,
@@ -457,7 +485,11 @@ def __init__(
                 (
                     env
                     if not isinstance(env, Dict)
-                    else env[get_device_from_sharding_infos(embedding_configs)]
+                    else env[
+                        get_device_for_first_shard_from_sharding_infos(
+                            embedding_configs
+                        )
+                    ]
                 ),
                 device if get_propogate_device() else None,
             )
@@ -580,7 +612,7 @@ def tbes_configs(

     def sharding_type_device_group_to_sharding_infos(
         self,
-    ) -> Dict[Tuple[str, str], List[EmbeddingShardingInfo]]:
+    ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], List[EmbeddingShardingInfo]]:
         return self._sharding_type_device_group_to_sharding_infos

     def embedding_configs(self) -> List[EmbeddingConfig]:
@@ -872,7 +904,9 @@ def create_context(self) -> EmbeddingCollectionContext:
         return EmbeddingCollectionContext(sharding_contexts=[])

     @property
-    def shardings(self) -> Dict[Tuple[str, str], FeatureShardingMixIn]:
+    def shardings(
+        self,
+    ) -> Dict[Tuple[str, Union[str, Tuple[str, ...]]], FeatureShardingMixIn]:
         # pyre-ignore [7]
         return self._sharding_type_device_group_to_sharding

@@ -965,7 +999,7 @@ def __init__(
         self,
         input_feature_names: List[str],
         sharding_type_device_group_to_sharding: Dict[
-            Tuple[str, str],
+            Tuple[str, Union[str, Tuple[str, ...]]],
             EmbeddingSharding[
                 InferSequenceShardingContext,
                 InputDistOutputs,
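
Two behaviors in this file are worth calling out: a heterogeneous (tuple) device type is routed down the same path as "cpu" in create_infer_embedding_sharding, and env lookups fall back to the device type of the first shard via get_device_for_first_shard_from_sharding_infos. A standalone sketch of that first-shard fallback follows; it takes the already-resolved device type rather than sharding infos, so it is not the TorchRec function itself.

from typing import Tuple, Union


def first_shard_device(device_type: Union[str, Tuple[str, ...]]) -> str:
    # A heterogeneously sharded table keys env lookups off its first shard's
    # device type; a homogeneous table uses its single device type as-is.
    return device_type[0] if isinstance(device_type, tuple) else device_type


print(first_shard_device("cuda"))           # "cuda"
print(first_shard_device(("cpu", "cuda")))  # "cpu"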

torchrec/distributed/sharding/rw_sequence_sharding.py

+13 -10

@@ -7,7 +7,7 @@

 # pyre-strict

-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -166,11 +166,11 @@ def __init__(
         self,
         device: torch.device,
         world_size: int,
-        device_type_from_sharding_infos: Optional[str] = None,
+        device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
         self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
-        self._device_type_from_sharding_infos: Optional[str] = (
+        self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
             device_type_from_sharding_infos
         )

@@ -179,13 +179,16 @@ def forward(
         local_embs: List[torch.Tensor],
         sharding_ctx: Optional[InferSequenceShardingContext] = None,
     ) -> List[torch.Tensor]:
-        # for cpu sharder, output dist should be a no-op
-        return (
-            local_embs
-            if self._device_type_from_sharding_infos is not None
-            and self._device_type_from_sharding_infos == "cpu"
-            else self._dist(local_embs)
-        )
+        if self._device_type_from_sharding_infos is not None:
+            if isinstance(self._device_type_from_sharding_infos, tuple):
+                # Update the tagging when the tuple has heterogeneous device types.
+                # Done in the next diff in the stack along with full support for
+                # heterogeneous device types.
+                return local_embs
+            elif self._device_type_from_sharding_infos == "cpu":
+                # for cpu sharder, output dist should be a no-op
+                return local_embs
+        return self._dist(local_embs)


 class InferRwSequenceEmbeddingSharding(
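
A self-contained sketch of the new output-dist dispatch in InferRwSequenceEmbeddingDist.forward: plain strings stand in for embedding tensors, and apply_all_to_one is a hypothetical stand-in for the SeqEmbeddingsAllToOne collective.

from typing import Callable, List, Optional, Tuple, Union


def output_dist(
    local_embs: List[str],
    device_type: Optional[Union[str, Tuple[str, ...]]],
    apply_all_to_one: Callable[[List[str]], List[str]],
) -> List[str]:
    if device_type is not None:
        if isinstance(device_type, tuple):
            # Heterogeneous shards: passthrough for now; full support is
            # deferred to a follow-up diff, per the in-code comment.
            return local_embs
        elif device_type == "cpu":
            # CPU sharder: output dist is a no-op.
            return local_embs
    return apply_all_to_one(local_embs)


print(output_dist(["e0", "e1"], ("cpu", "cuda"), lambda xs: xs))           # passthrough
print(output_dist(["e0", "e1"], None, lambda xs: ["gathered:" + x for x in xs]))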

torchrec/distributed/sharding/rw_sharding.py

+2 -2

@@ -118,7 +118,7 @@ def __init__(
         device: Optional[torch.device] = None,
         need_pos: bool = False,
         qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
-        device_type_from_sharding_infos: Optional[str] = None,
+        device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
         self._env = env
@@ -133,7 +133,7 @@ def __init__(
         if device is None:
             device = torch.device("cpu")
         self._device: torch.device = device
-        self._device_type_from_sharding_infos: Optional[str] = (
+        self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
             device_type_from_sharding_infos
         )
         sharded_tables_per_rank = self._shard(sharding_infos)
