
Commit 6f4bfe2

Boris Sarana authored and facebook-github-bot committed
Move sharding optimization flag to global_settings (#2665)
Summary:
Pull Request resolved: #2665

As per the title, move the configuration flag to a separate module for better abstraction and a simpler rollout.

Reviewed By: iamzainhuda

Differential Revision: D67777011

fbshipit-source-id: 8a659bee7b81d3181c4014fdf2678c69b306b8c1
1 parent e1b96a6 commit 6f4bfe2

File tree

2 files changed: +16 -3 lines changed

torchrec/distributed/embedding_types.py (+4 -3)

@@ -9,7 +9,6 @@
 
 import abc
 import copy
-import os
 from dataclasses import dataclass
 from enum import Enum, unique
 from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
@@ -21,6 +20,9 @@
 from torch.distributed._tensor.placement_types import Placement
 from torch.nn.modules.module import _addindent
 from torch.nn.parallel import DistributedDataParallel
+from torchrec.distributed.global_settings import (
+    construct_sharded_tensor_from_metadata_enabled,
+)
 from torchrec.distributed.types import (
     get_tensor_size_bytes,
     ModuleSharder,
@@ -346,8 +348,7 @@ def __init__(
 
         # option to construct ShardedTensor from metadata avoiding expensive all-gather
         self._construct_sharded_tensor_from_metadata: bool = (
-            os.environ.get("TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA", "0")
-            == "1"
+            construct_sharded_tensor_from_metadata_enabled()
         )
 
     def prefetch(

torchrec/distributed/global_settings.py (+12 -0)

@@ -7,8 +7,14 @@
 
 # pyre-strict
 
+import os
+
 PROPOGATE_DEVICE: bool = False
 
+TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV = (
+    "TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"
+)
+
 
 def set_propogate_device(val: bool) -> None:
     global PROPOGATE_DEVICE
@@ -18,3 +24,9 @@ def set_propogate_device(val: bool) -> None:
 def get_propogate_device() -> bool:
     global PROPOGATE_DEVICE
     return PROPOGATE_DEVICE
+
+
+def construct_sharded_tensor_from_metadata_enabled() -> bool:
+    return (
+        os.environ.get(TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA_ENV, "0") == "1"
+    )
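
For anyone exercising the relocated flag locally, here is a minimal usage sketch of the new helper. It only assumes that torchrec is importable; setting the environment variable in-process before the check is illustrative, not something this commit prescribes.

    import os

    # Opt in before sharded modules are constructed; per the diff above, the flag
    # is read once in __init__ via construct_sharded_tensor_from_metadata_enabled().
    os.environ["TORCHREC_CONSTRUCT_SHARDED_TENSOR_FROM_METADATA"] = "1"

    from torchrec.distributed.global_settings import (
        construct_sharded_tensor_from_metadata_enabled,
    )

    # The helper just checks the environment variable, so this prints True.
    print(construct_sharded_tensor_from_metadata_enabled())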
