
Commit e1b96a6

Boris Sarana authored and facebook-github-bot committed
Reland of D65489998 "Optimize sharding performance of embeddings" (#2664)
Summary:
Pull Request resolved: #2664
X-link: pytorch/FBGEMM#3549
X-link: facebookresearch/FBGEMM#634

This diff is a reland of D65489998 after backout in D66800554.

Reviewed By: iamzainhuda

Differential Revision: D66828907

fbshipit-source-id: ab6e6a9faa8255c4847a69a8efb46182bedc9737
1 parent 504642a commit e1b96a6

5 files changed, +169 -43 lines changed

torchrec/distributed/embedding_types.py (+7)
@@ -9,6 +9,7 @@
 
 import abc
 import copy
+import os
 from dataclasses import dataclass
 from enum import Enum, unique
 from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union

@@ -343,6 +344,12 @@ def __init__(
         self._lookups: List[nn.Module] = []
         self._output_dists: List[nn.Module] = []
 
+        # option to construct ShardedTensor from metadata avoiding expensive all-gather
+        self._construct_sharded_tensor_from_metadata: bool = (
+            os.environ.get("TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA", "0")
+            == "1"
+        )
+
     def prefetch(
         self,
         dist_input: KJTList,
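The flag is read once in __init__, so it has to be set before the sharded module is constructed. A minimal sketch of how a training script might opt in; the environment variable name comes from the diff above, everything else is illustrative:

    import os

    # Set before constructing/sharding the embedding modules; the flag is read
    # once in __init__ (see the embedding_types.py change above) and enables the
    # metadata-based ShardedTensor construction in embeddingbag.py below.
    os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "1"

    # ... build and shard the EmbeddingBagCollection as usual; state-dict
    # ShardedTensors are then assembled from the sharding plan instead of via
    # an all_gather across ranks.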

torchrec/distributed/embeddingbag.py (+42 -9)
@@ -29,6 +29,7 @@
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
+from torch.distributed._shard.sharded_tensor import TensorProperties
 from torch.distributed._tensor import DTensor
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel

@@ -81,6 +82,7 @@
     optimizer_type_to_emb_opt_type,
 )
 from torchrec.modules.embedding_configs import (
+    data_type_to_dtype,
     EmbeddingBagConfig,
     EmbeddingTableConfig,
     PoolingType,

@@ -945,17 +947,48 @@ def _initialize_torch_state(self) -> None: # noqa
             # created ShardedTensors once in init, use in post_state_dict_hook
             # note: at this point kvstore backed tensors don't own valid snapshots, so no read
             # access is allowed on them.
-            self._model_parallel_name_to_sharded_tensor[table_name] = (
-                ShardedTensor._init_from_local_shards(
-                    local_shards,
-                    self._name_to_table_size[table_name],
-                    process_group=(
-                        self._env.sharding_pg
-                        if isinstance(self._env, ShardingEnv2D)
-                        else self._env.process_group
+
+            # create ShardedTensor from local shards and metadata avoiding all_gather collective
+            if self._construct_sharded_tensor_from_metadata:
+                sharding_spec = none_throws(
+                    self.module_sharding_plan[table_name].sharding_spec
+                )
+
+                tensor_properties = TensorProperties(
+                    dtype=(
+                        data_type_to_dtype(
+                            self._table_name_to_config[table_name].data_type
+                        )
                     ),
                 )
-            )
+
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards_and_global_metadata(
+                        local_shards=local_shards,
+                        sharded_tensor_metadata=sharding_spec.build_metadata(
+                            tensor_sizes=self._name_to_table_size[table_name],
+                            tensor_properties=tensor_properties,
+                        ),
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
+                    )
+                )
+            else:
+                # create ShardedTensor from local shards using all_gather collective
+                self._model_parallel_name_to_sharded_tensor[table_name] = (
+                    ShardedTensor._init_from_local_shards(
+                        local_shards,
+                        self._name_to_table_size[table_name],
+                        process_group=(
+                            self._env.sharding_pg
+                            if isinstance(self._env, ShardingEnv2D)
+                            else self._env.process_group
+                        ),
+                    )
+                )
 
         def extract_sharded_kvtensors(
             module: ShardedEmbeddingBagCollection,
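For intuition about what the metadata path saves, here is a standalone, single-rank sketch (not TorchRec code) contrasting the two construction paths using the same torch.distributed APIs the diff relies on. The table size, sharding spec, and shard layout are made up for illustration:

    import os

    import torch
    import torch.distributed as dist
    from torch.distributed._shard.metadata import ShardMetadata
    from torch.distributed._shard.sharded_tensor import (
        Shard,
        ShardedTensor,
        TensorProperties,
    )
    from torch.distributed._shard.sharding_spec import ChunkShardingSpec

    # Single-process "cluster" so the example runs standalone.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # Toy "embedding table": 8 rows x 4 dims, row-wise sharded onto rank 0 only.
    global_size = torch.Size([8, 4])
    spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu"])

    # The shard this rank owns; with world_size == 1 it covers the whole table.
    local_tensor = torch.randn(8, 4)
    local_shards = [
        Shard(
            tensor=local_tensor,
            metadata=ShardMetadata(
                shard_offsets=[0, 0], shard_sizes=[8, 4], placement="rank:0/cpu"
            ),
        )
    ]

    # Metadata path: every rank derives the same global layout from the sharding
    # spec, so no collective is needed to exchange shard metadata.
    st_from_metadata = ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards,
        sharded_tensor_metadata=spec.build_metadata(
            tensor_sizes=global_size,
            tensor_properties=TensorProperties(dtype=local_tensor.dtype),
        ),
    )

    # Collective path: shard metadata is exchanged via an all_gather under the hood.
    st_from_gather = ShardedTensor._init_from_local_shards(local_shards, global_size)

    print(st_from_metadata.size(), st_from_gather.size())
    dist.destroy_process_group()

The trade-off is that the metadata path trusts the sharding plan to describe every rank's shards rather than validating them against an all_gather, which is presumably why the change is kept behind an opt-in environment variable.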
