
Commit 00d8ed2

iamzainhuda authored and facebook-github-bot committed
add size and stride for empty shard DT (#2662)
Summary:
Pull Request resolved: #2662

Bring the DTensor (DT) empty shard on a rank in line with the ShardedTensor (ST) empty-shard behavior. For OT, the current DT approach broke transfer learning because callers expect tensor.size() to return the global shape, so we amend the DT empty-shard init to include the global shape and stride.

Differential Revision: D67727355

fbshipit-source-id: 9823d3e75c7e4bf2dad1b77d8dcbd0ee960205ec
1 parent fc79c7a commit 00d8ed2

File tree

2 files changed: +54, -1 lines changed

torchrec/distributed/embeddingbag.py (+12)

@@ -32,6 +32,7 @@
 from torch.distributed._tensor import DTensor
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
+from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_sharding import (
     EmbeddingSharding,
     EmbeddingShardingContext,
@@ -73,6 +74,7 @@
     add_params_from_parameter_sharding,
     append_prefix,
     convert_to_fbgemm_types,
+    create_global_tensor_shape_stride_from_metadata,
     maybe_annotate_embedding_event,
     merge_fused_params,
     none_throws,
@@ -918,6 +920,14 @@ def _initialize_torch_state(self) -> None:  # noqa
                     )
                 )
             else:
+                shape, stride = create_global_tensor_shape_stride_from_metadata(
+                    none_throws(self.module_sharding_plan[table_name]),
+                    (
+                        self._env.node_group_size
+                        if isinstance(self._env, ShardingEnv2D)
+                        else get_local_size(self._env.world_size)
+                    ),
+                )
                 # empty shard case
                 self._model_parallel_name_to_dtensor[table_name] = (
                     DTensor.from_local(
@@ -927,6 +937,8 @@ def _initialize_torch_state(self) -> None:  # noqa
                         ),
                         device_mesh=self._env.device_mesh,
                         run_check=False,
+                        shape=shape,
+                        stride=stride,
                     )
                 )
             else:
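The empty-shard branch above passes an explicit global shape and stride into DTensor.from_local, so tensor.size() on a rank holding no shard reports the global table shape. The following is a minimal standalone sketch of that pattern, not TorchRec code: it assumes a single-process gloo group, a plain empty tensor in place of TorchRec's local shard wrapper, and a hypothetical 100 x 64 global table.

import os

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor

# Single-process process group, for illustration only.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

mesh = DeviceMesh("cpu", [0])  # one-rank mesh standing in for self._env.device_mesh
local_shard = torch.empty(0)   # this rank holds no rows of the (hypothetical) table

dt = DTensor.from_local(
    local_shard,
    device_mesh=mesh,
    run_check=False,
    shape=torch.Size([100, 64]),  # hypothetical global table shape
    stride=(64, 1),               # contiguous row-major stride for that shape
)
print(dt.size())  # expected: torch.Size([100, 64]), the global shape

dist.destroy_process_group()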

torchrec/distributed/utils.py (+42, -1)

@@ -15,7 +15,7 @@
 from collections import OrderedDict
 from contextlib import AbstractContextManager, nullcontext
 from dataclasses import asdict
-from typing import Any, Dict, List, Optional, Set, Type, TypeVar, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
 
 import torch
 from fbgemm_gpu.split_embedding_configs import EmbOptimType
@@ -511,3 +511,44 @@ def interaction(self, *args, **kwargs) -> None:
         pdb.Pdb.interaction(self, *args, **kwargs)
     finally:
         sys.stdin = _stdin
+
+
+def create_global_tensor_shape_stride_from_metadata(
+    parameter_sharding: ParameterSharding, devices_per_node: Optional[int] = None
+) -> Tuple[torch.Size, Tuple[int, int]]:
+    """
+    Create a global tensor shape and stride from shard metadata.
+
+    Returns:
+        torch.Size: global tensor shape.
+        tuple: global tensor stride.
+    """
+    size = None
+    if parameter_sharding.sharding_type == ShardingType.COLUMN_WISE.value:
+        row_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[0]  # pyre-ignore[16]
+        col_dim = 0
+        for shard in parameter_sharding.sharding_spec.shards:
+            col_dim += shard.shard_sizes[1]
+        size = torch.Size([row_dim, col_dim])
+    elif (
+        parameter_sharding.sharding_type == ShardingType.ROW_WISE.value
+        or parameter_sharding.sharding_type == ShardingType.TABLE_ROW_WISE.value
+    ):
+        row_dim = 0
+        col_dim = parameter_sharding.sharding_spec.shards[0].shard_sizes[1]
+        for shard in parameter_sharding.sharding_spec.shards:
+            row_dim += shard.shard_sizes[0]
+        size = torch.Size([row_dim, col_dim])
+    elif parameter_sharding.sharding_type == ShardingType.TABLE_WISE.value:
+        size = torch.Size(parameter_sharding.sharding_spec.shards[0].shard_sizes)
+    elif parameter_sharding.sharding_type == ShardingType.GRID_SHARD.value:
+        # we need node group size to appropriately calculate global shape from shard
+        assert devices_per_node is not None
+        row_dim, col_dim = 0, 0
+        num_cw_shards = len(parameter_sharding.sharding_spec.shards) // devices_per_node
+        for _ in range(num_cw_shards):
+            col_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[1]
+        for _ in range(devices_per_node):
+            row_dim += parameter_sharding.sharding_spec.shards[0].shard_sizes[0]
+        size = torch.Size([row_dim, col_dim])
+    return size, (size[1], 1) if size else (torch.Size([0, 0]), (0, 1))  # pyre-ignore[7]
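For illustration, a hedged usage sketch of the helper added above. The shard metadata is hypothetical: a 100 x 128 table split column-wise into two 100 x 64 shards on ranks 0 and 1, which should produce a global shape of torch.Size([100, 128]) and a row-major stride of (128, 1).

import torch
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec, ShardMetadata
from torchrec.distributed.types import ParameterSharding, ShardingType
from torchrec.distributed.utils import create_global_tensor_shape_stride_from_metadata

# Hypothetical 100 x 128 table split column-wise into two 100 x 64 shards.
parameter_sharding = ParameterSharding(
    sharding_type=ShardingType.COLUMN_WISE.value,
    compute_kernel="fused",  # placeholder; not used by the helper
    ranks=[0, 1],
    sharding_spec=EnumerableShardingSpec(
        shards=[
            ShardMetadata(shard_offsets=[0, 0], shard_sizes=[100, 64], placement="rank:0/cuda:0"),
            ShardMetadata(shard_offsets=[0, 64], shard_sizes=[100, 64], placement="rank:1/cuda:1"),
        ]
    ),
)

shape, stride = create_global_tensor_shape_stride_from_metadata(parameter_sharding)
print(shape, stride)  # expected: torch.Size([100, 128]) (128, 1)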
