
Commit 1f0681e

faran928 authored and facebook-github-bot committed
Propagate device type for heterogeneous sharding of table across different device types (#2606)
Summary:
Pull Request resolved: #2606

For row-wise heterogeneous sharding of tables across cuda and cpu, we propagate the correct device type within each lookup module based on which shard of each table is being looked up / fetched within that module. We also move some of the wrapper functions that enable us to pass batch info correctly across different modules during model split. The changes should be backward compatible and should not impact existing behavior.

Reviewed By: jiayisuse

Differential Revision: D66682124

fbshipit-source-id: 86a8ee004f9481d3891367dfb6011f5a45aff149
1 parent d580841 commit 1f0681e
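Illustrative note (not part of the commit): with row-wise sharding, a single table can have one shard placed on a CUDA rank and another on a CPU rank, so the runtime device type has to be resolved per shard rather than once per table. A minimal Python sketch of that idea, assuming shard placements are available as plain device-type strings (hypothetical helper name):

# Hypothetical sketch: resolve the device type of one shard of a row-wise
# sharded table whose shards live on different device types.
from typing import List, Optional


def resolve_shard_device_type(
    shard_placements: List[str], shard_index: Optional[int] = None
) -> str:
    # Fall back to shard 0 when no index is given (homogeneous sharding).
    index = 0 if shard_index is None else shard_index
    return shard_placements[index]


# One table, row-wise sharded across a CUDA rank and a CPU rank.
placements = ["cuda", "cpu"]
assert resolve_shard_device_type(placements, shard_index=0) == "cuda"
assert resolve_shard_device_type(placements, shard_index=1) == "cpu"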

File tree

7 files changed (+145, -44 lines changed)


torchrec/distributed/embedding_lookup.py (+14, -1)

@@ -677,25 +677,30 @@ def __init__(
         grouped_configs: List[GroupedEmbeddingConfig],
         device: Optional[torch.device] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         # TODO rename to _create_embedding_kernel
         def _create_lookup(
             config: GroupedEmbeddingConfig,
             device: Optional[torch.device] = None,
             fused_params: Optional[Dict[str, Any]] = None,
+            shard_index: Optional[int] = None,
         ) -> BaseBatchedEmbedding[
             Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]
         ]:
             return QuantBatchedEmbedding(
                 config=config,
                 device=device,
                 fused_params=fused_params,
+                shard_index=shard_index,
             )

         super().__init__()
         self._emb_modules: nn.ModuleList = nn.ModuleList()
         for config in grouped_configs:
-            self._emb_modules.append(_create_lookup(config, device, fused_params))
+            self._emb_modules.append(
+                _create_lookup(config, device, fused_params, shard_index)
+            )

         self._feature_splits: List[int] = [
             config.num_features() for config in grouped_configs

@@ -1076,6 +1081,7 @@ def __init__(
         world_size: int,
         fused_params: Optional[Dict[str, Any]] = None,
         device: Optional[torch.device] = None,
+        device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
         self._embedding_lookups_per_rank: List[MetaInferGroupedEmbeddingsLookup] = []

@@ -1089,11 +1095,18 @@ def __init__(
             "meta" if device is not None and device.type == "meta" else "cuda"
         )
         for rank in range(world_size):
+            # propagate shard index to get the correct runtime_device based on shard metadata
+            # in case of heterogeneous sharding of a single table across different device types
+            shard_index = (
+                rank if isinstance(device_type_from_sharding_infos, tuple) else None
+            )
+            device = rank_device(device_type, rank)
             self._embedding_lookups_per_rank.append(
                 MetaInferGroupedEmbeddingsLookup(
                     grouped_configs=grouped_configs_per_rank[rank],
                     device=rank_device(device_type, rank),
                     fused_params=fused_params,
+                    shard_index=shard_index,
                 )
             )

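To make the shard_index plumbing above concrete: the index is only meaningful when device_type_from_sharding_infos is a tuple, i.e. when a single table is row-wise sharded across heterogeneous device types, and in that case the rank doubles as the shard index. A small self-contained sketch (hypothetical helper, not the TorchRec API):

from typing import Optional, Tuple, Union

DeviceTypes = Union[str, Tuple[str, ...]]


def shard_index_for_rank(
    rank: int, device_type_from_sharding_infos: Optional[DeviceTypes]
) -> Optional[int]:
    # Heterogeneous sharding: each rank looks up its own shard's metadata.
    return rank if isinstance(device_type_from_sharding_infos, tuple) else None


assert shard_index_for_rank(1, ("cuda", "cpu")) == 1  # heterogeneous: index by rank
assert shard_index_for_rank(1, "cuda") is None  # homogeneous: shard 0 is representative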
torchrec/distributed/quant_embedding.py (+28, -21)

@@ -57,8 +57,11 @@
     dtype_to_data_type,
     EmbeddingConfig,
 )
-from torchrec.quant.embedding_modules import (
+from torchrec.modules.utils import (
+    _fx_trec_get_feature_length,
     _get_batching_hinted_output,
+)
+from torchrec.quant.embedding_modules import (
     EmbeddingCollection as QuantEmbeddingCollection,
     MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
 )

@@ -67,6 +70,7 @@

 torch.fx.wrap("len")
 torch.fx.wrap("_get_batching_hinted_output")
+torch.fx.wrap("_fx_trec_get_feature_length")

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")

@@ -201,17 +205,6 @@ def _get_unbucketize_tensor_via_length_alignment(
     return bucketize_permute_tensor


-@torch.fx.wrap
-def _fx_trec_get_feature_length(
-    features: KeyedJaggedTensor, embedding_names: List[str]
-) -> torch.Tensor:
-    torch._assert(
-        len(embedding_names) == len(features.keys()),
-        "embedding output and features mismatch",
-    )
-    return features.lengths()
-
-
 def _construct_jagged_tensors_tw(
     embeddings: List[torch.Tensor],
     embedding_names_per_rank: List[List[str]],

@@ -355,6 +348,7 @@ def _construct_jagged_tensors(
     rw_feature_length_after_bucketize: Optional[torch.Tensor],
     cw_features_to_permute_indices: Dict[str, torch.Tensor],
     key_to_feature_permuted_coordinates: Dict[str, torch.Tensor],
+    device_type: Union[str, Tuple[str, ...]],
 ) -> Dict[str, JaggedTensor]:

     # Validating sharding type and parameters

@@ -373,15 +367,24 @@
         features_before_input_dist_length = _fx_trec_get_feature_length(
             features_before_input_dist, embedding_names
         )
-        embeddings = [
-            _get_batching_hinted_output(
-                _fx_trec_get_feature_length(features[i], embedding_names_per_rank[i]),
-                embeddings[i],
-            )
-            for i in range(len(embedding_names_per_rank))
-        ]
+        input_embeddings = []
+        for i in range(len(embedding_names_per_rank)):
+            if isinstance(device_type, tuple) and device_type[i] != "cpu":
+                # batching hint is already propagated and passed for this case
+                # upstream
+                input_embeddings.append(embeddings[i])
+            else:
+                input_embeddings.append(
+                    _get_batching_hinted_output(
+                        _fx_trec_get_feature_length(
+                            features[i], embedding_names_per_rank[i]
+                        ),
+                        embeddings[i],
+                    )
+                )
+
         return _construct_jagged_tensors_rw(
-            embeddings,
+            input_embeddings,
             embedding_names,
             features_before_input_dist_length,
             features_before_input_dist.values() if need_indices else None,

@@ -746,6 +749,9 @@ def input_dist(
                     unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
                     bucket_mapping_tensor=bucket_mapping_tensor_list[i],
                     bucketized_length=bucketized_length_list[i],
+                    embedding_names_per_rank=self._embedding_names_per_rank_per_sharding[
+                        i
+                    ],
                 )
             )
         return input_dist_result_list

@@ -828,7 +834,7 @@ def output_jt_dict(
     ) -> Dict[str, JaggedTensor]:
         jt_dict_res: Dict[str, JaggedTensor] = {}
         for (
-            (sharding_type, _),
+            (sharding_type, device_type),
             emb_sharding,
             features_sharding,
             embedding_names,

@@ -876,6 +882,7 @@
                 ),
                 cw_features_to_permute_indices=self._features_to_permute_indices,
                 key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates,
+                device_type=device_type,
             )
             for embedding_name in embedding_names:
                 jt_dict_res[embedding_name] = jt_dict[embedding_name]

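The branching added to _construct_jagged_tensors boils down to: in the heterogeneous (tuple) case, non-CPU shard outputs already received the batching hint upstream in the output dist, so only CPU shards (and the homogeneous case) apply it here. A standalone sketch with plain-Python stand-ins (apply_batching_hint and select_embeddings are hypothetical names):

from typing import List, Tuple, Union


def apply_batching_hint(lengths: List[int], output: List[float]) -> List[float]:
    # Stand-in for the fx batching-hint rule, which also just returns its output.
    return output


def select_embeddings(
    embeddings: List[List[float]],
    lengths_per_rank: List[List[int]],
    device_type: Union[str, Tuple[str, ...]],
) -> List[List[float]]:
    result = []
    for i, emb in enumerate(embeddings):
        if isinstance(device_type, tuple) and device_type[i] != "cpu":
            # Heterogeneous case: hint already applied upstream in the output dist.
            result.append(emb)
        else:
            result.append(apply_batching_hint(lengths_per_rank[i], emb))
    return result


# Rank 0 on cuda (passed through), rank 1 on cpu (hint applied here).
assert select_embeddings([[1.0], [2.0]], [[1], [1]], ("cuda", "cpu")) == [[1.0], [2.0]]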
torchrec/distributed/quant_embedding_kernel.py (+13, -4)

@@ -120,8 +120,11 @@ def _quantize_weight(


 def _get_runtime_device(
-    device: Optional[torch.device], config: GroupedEmbeddingConfig
+    device: Optional[torch.device],
+    config: GroupedEmbeddingConfig,
+    shard_index: Optional[int] = None,
 ) -> torch.device:
+    index: int = 0 if shard_index is None else shard_index
     if device is not None and device.type != "meta":
         return device
     else:

@@ -136,9 +139,12 @@
             or (
                 table.global_metadata is not None
                 and len(table.global_metadata.shards_metadata)
-                and table.global_metadata.shards_metadata[0].placement is not None
+                and table.global_metadata.shards_metadata[index].placement
+                is not None
                 # pyre-ignore: Undefined attribute [16]
-                and table.global_metadata.shards_metadata[0].placement.device().type
+                and table.global_metadata.shards_metadata[index]
+                .placement.device()
+                .type
                 == "cpu"
             )
             for table in config.embedding_tables

@@ -430,6 +436,7 @@ def __init__(
         pg: Optional[dist.ProcessGroup] = None,
         device: Optional[torch.device] = None,
         fused_params: Optional[Dict[str, Any]] = None,
+        shard_index: Optional[int] = None,
     ) -> None:
         super().__init__(config, pg, device)

@@ -446,7 +453,9 @@
         self._quant_state_dict_split_scale_bias: bool = (
             is_fused_param_quant_state_dict_split_scale_bias(fused_params)
         )
-        self._runtime_device: torch.device = _get_runtime_device(device, config)
+        self._runtime_device: torch.device = _get_runtime_device(
+            device, config, shard_index
+        )
         # 16 for CUDA, 1 for others like CPU and MTIA.
         self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1
         self._emb_module: IntNBitTableBatchedEmbeddingBagsCodegen = (

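The essence of the _get_runtime_device change above: the placement is now read from the shard selected by shard_index instead of always from shard 0, which matters once shards of the same table sit on different device types. A reduced sketch (hypothetical helper; placements modeled as plain strings rather than ShardMetadata):

from typing import List, Optional

import torch


def runtime_device_for_shard(
    shard_placement_types: List[str], shard_index: Optional[int] = None
) -> torch.device:
    index = 0 if shard_index is None else shard_index
    # A CPU placement keeps the kernel on CPU; anything else runs on CUDA here.
    if shard_placement_types[index] == "cpu":
        return torch.device("cpu")
    return torch.device("cuda")


placements = ["cuda", "cpu"]  # one table, two row-wise shards on different devices
assert runtime_device_for_shard(placements, shard_index=1).type == "cpu"
assert runtime_device_for_shard(placements).type == "cuda"  # defaults to shard 0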
torchrec/distributed/sharding/rw_sequence_sharding.py (+63, -11)

@@ -39,8 +39,15 @@
     SequenceShardingContext,
 )
 from torchrec.distributed.types import Awaitable, CommOp, QuantizedCommCodecs
+from torchrec.modules.utils import (
+    _fx_trec_get_feature_length,
+    _get_batching_hinted_output,
+)
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

+torch.fx.wrap("_get_batching_hinted_output")
+torch.fx.wrap("_fx_trec_get_feature_length")
+

 class RwSequenceEmbeddingDist(
     BaseEmbeddingDist[SequenceShardingContext, torch.Tensor, torch.Tensor]

@@ -169,26 +176,70 @@ def __init__(
         device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = None,
     ) -> None:
         super().__init__()
-        self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
         self._device_type_from_sharding_infos: Optional[Union[str, Tuple[str, ...]]] = (
             device_type_from_sharding_infos
         )
+        num_cpu_ranks = 0
+        if self._device_type_from_sharding_infos and isinstance(
+            self._device_type_from_sharding_infos, tuple
+        ):
+            for device_type in self._device_type_from_sharding_infos:
+                if device_type == "cpu":
+                    num_cpu_ranks += 1
+        elif self._device_type_from_sharding_infos == "cpu":
+            num_cpu_ranks = world_size
+
+        self._device_dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(
+            device, world_size - num_cpu_ranks
+        )

     def forward(
         self,
         local_embs: List[torch.Tensor],
         sharding_ctx: Optional[InferSequenceShardingContext] = None,
     ) -> List[torch.Tensor]:
-        if self._device_type_from_sharding_infos is not None:
-            if isinstance(self._device_type_from_sharding_infos, tuple):
-                # Update the tagging when tuple has heterogenous device type
-                # Done in next diff stack along with the full support for
-                # hetergoenous device type
-                return local_embs
-            elif self._device_type_from_sharding_infos == "cpu":
-                # for cpu sharder, output dist should be a no-op
-                return local_embs
-        return self._dist(local_embs)
+        assert (
+            self._device_type_from_sharding_infos is not None
+        ), "_device_type_from_sharding_infos should always be set for InferRwSequenceEmbeddingDist"
+        if isinstance(self._device_type_from_sharding_infos, tuple):
+            assert sharding_ctx is not None
+            assert sharding_ctx.embedding_names_per_rank is not None
+            assert len(self._device_type_from_sharding_infos) == len(
+                local_embs
+            ), "For heterogeneous sharding, the number of local_embs should be equal to the number of device types"
+            non_cpu_local_embs = []
+            # Here looping through local_embs is also compatible with tracing
+            # given the number of lookups / shards within ShardedQuantEmbeddingCollection
+            # are fixed and local_embs is the output of those lookups. However, still
+            # using _device_type_from_sharding_infos to iterate on the local_embs list as
+            # that's a better practice.
+            for i, device_type in enumerate(self._device_type_from_sharding_infos):
+                if device_type != "cpu":
+                    non_cpu_local_embs.append(
+                        _get_batching_hinted_output(
+                            _fx_trec_get_feature_length(
+                                sharding_ctx.features[i],
+                                # pyre-fixme [16]
+                                sharding_ctx.embedding_names_per_rank[i],
+                            ),
+                            local_embs[i],
+                        )
+                    )
+            non_cpu_local_embs_dist = self._device_dist(non_cpu_local_embs)
+            index = 0
+            result = []
+            for i, device_type in enumerate(self._device_type_from_sharding_infos):
+                if device_type == "cpu":
+                    result.append(local_embs[i])
+                else:
+                    result.append(non_cpu_local_embs_dist[index])
+                    index += 1
+            return result
+        elif self._device_type_from_sharding_infos == "cpu":
+            # for cpu sharder, output dist should be a no-op
+            return local_embs
+        else:
+            return self._device_dist(local_embs)


 class InferRwSequenceEmbeddingSharding(

@@ -237,6 +288,7 @@ def create_lookup(
             world_size=self._world_size,
             fused_params=fused_params,
             device=device if device is not None else self._device,
+            device_type_from_sharding_infos=self._device_type_from_sharding_infos,
         )

     def create_output_dist(

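The new forward pass follows a split/dist/merge pattern: CPU shard outputs bypass the output dist (a no-op for them), every other shard output goes through the all-to-one dist sized to the non-CPU ranks, and the results are re-interleaved in the original shard order. A sketch of that pattern outside TorchRec (dist_non_cpu is a hypothetical helper; the real dist is a SeqEmbeddingsAllToOne):

from typing import Callable, List, Tuple


def dist_non_cpu(
    local_embs: List[List[float]],
    device_types: Tuple[str, ...],
    device_dist: Callable[[List[List[float]]], List[List[float]]],
) -> List[List[float]]:
    non_cpu = [e for e, d in zip(local_embs, device_types) if d != "cpu"]
    moved = device_dist(non_cpu)  # only non-CPU outputs are distributed
    merged: List[List[float]] = []
    j = 0
    for e, d in zip(local_embs, device_types):
        if d == "cpu":
            merged.append(e)  # CPU output dist is a no-op
        else:
            merged.append(moved[j])
            j += 1
    return merged


out = dist_non_cpu([[1.0], [2.0], [3.0]], ("cuda", "cpu", "cuda"), lambda xs: xs)
assert out == [[1.0], [2.0], [3.0]]  # original shard order preserved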
torchrec/distributed/sharding/sequence_sharding.py (+1)

@@ -98,6 +98,7 @@ class InferSequenceShardingContext(Multistreamable):
     unbucketize_permute_tensor: Optional[torch.Tensor] = None
     bucket_mapping_tensor: Optional[torch.Tensor] = None
     bucketized_length: Optional[torch.Tensor] = None
+    embedding_names_per_rank: Optional[List[List[str]]] = None

     def record_stream(self, stream: torch.Stream) -> None:
         for feature in self.features:

torchrec/modules/utils.py (+20)

@@ -13,11 +13,14 @@
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union

 import torch
+from torch import Tensor
 from torch.profiler import record_function
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
 from torchrec.streamable import Multistreamable
 from torchrec.types import CacheMixin

+torch.fx.wrap("len")
+

 @dataclass
 class SequenceVBEContext(Multistreamable):

@@ -406,3 +409,20 @@ def reset_module_states_post_sharding(
     for submod in module.modules():
         if isinstance(submod, CacheMixin):
             submod.clear_cache()
+
+
+@torch.fx.wrap
+def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
+    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
+    return output
+
+
+@torch.fx.wrap
+def _fx_trec_get_feature_length(
+    features: KeyedJaggedTensor, embedding_names: List[str]
+) -> torch.Tensor:
+    torch._assert(
+        len(embedding_names) == len(features.keys()),
+        "embedding output and features mismatch",
+    )
+    return features.lengths()

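A short note on why these helpers are registered with torch.fx.wrap at their call sites: wrapped functions stay as opaque call_function nodes during symbolic tracing instead of being traced through, which is what lets the batching hints survive when the traced model is split across modules (per the commit summary). A minimal standalone demo of that torch.fx behavior (not TorchRec code):

import torch
import torch.fx


@torch.fx.wrap
def _passthrough_hint(lengths: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
    # Kept as a single call_function node in the traced graph.
    return output


class Demo(torch.nn.Module):
    def forward(self, lengths: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
        return _passthrough_hint(lengths, output)


graph = torch.fx.symbolic_trace(Demo()).graph
assert any(
    node.op == "call_function"
    and getattr(node.target, "__name__", "") == "_passthrough_hint"
    for node in graph.nodes
)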
torchrec/quant/embedding_modules.py (+6, -7)

@@ -48,7 +48,10 @@
     ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection,
 )
 from torchrec.modules.mc_modules import ManagedCollisionCollection
-from torchrec.modules.utils import construct_jagged_tensors_inference
+from torchrec.modules.utils import (
+    _get_batching_hinted_output,
+    construct_jagged_tensors_inference,
+)
 from torchrec.sparse.jagged_tensor import (
     ComputeKJTToJTDict,
     JaggedTensor,

@@ -58,6 +61,8 @@
 from torchrec.tensor_types import UInt2Tensor, UInt4Tensor
 from torchrec.types import ModuleNoCopyMixin

+torch.fx.wrap("_get_batching_hinted_output")
+
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")

@@ -93,12 +98,6 @@
 DEFAULT_ROW_ALIGNMENT = 16


-@torch.fx.wrap
-def _get_batching_hinted_output(lengths: Tensor, output: Tensor) -> Tensor:
-    # this is a fx rule to help with batching hinting jagged sequence tensor coalescing.
-    return output
-
-
 @torch.fx.wrap
 def _get_feature_length(feature: KeyedJaggedTensor) -> Tensor:
     return feature.lengths()
