Commit 6337072

Allow passing in additional params to be ignored in the DDP wrapper (#2103)
Summary:
Pull Request resolved: #2103

# What
Add an option that lets users pass in additional params to be ignored in DDP.

# Why
Currently the wrapper calls `DistributedDataParallel._set_params_and_buffers_to_ignore_for_model` to ignore all sharded params in the embedding modules. However, if users call `DistributedDataParallel._set_params_and_buffers_to_ignore_for_model` before torchrec does, their params-to-be-ignored will be overwritten by torchrec's call.

Discussion: https://fb.workplace.com/groups/319878845696681/permalink/1199477041070186/

For why users want to call `_set_params_and_buffers_to_ignore_for_model`, please see the diff on top of this one for the motivation.

# How
To mitigate this issue, we have to "batch" the call to `_set_params_and_buffers_to_ignore_for_model`. Therefore, we allow users to pass their params-to-be-ignored to the wrapper, which batches them with the torchrec sharded params.

Reviewed By: dstaay-fb

Differential Revision: D58486022

fbshipit-source-id: 3896e02fec0cec7db528c265c7d0fbfdef1fea87
1 parent 54a51c7 commit 6337072

1 file changed: +6 -1 lines changed

torchrec/distributed/model_parallel.py

@@ -77,11 +77,13 @@ def __init__(
         static_graph: bool = True,
         find_unused_parameters: bool = False,
         allreduce_comm_precision: Optional[str] = None,
+        params_to_ignore: Optional[List[str]] = None,
     ) -> None:
         self._bucket_cap_mb: int = bucket_cap_mb
         self._static_graph: bool = static_graph
         self._find_unused_parameters: bool = find_unused_parameters
         self._allreduce_comm_precision = allreduce_comm_precision
+        self._additional_params_to_ignore: Set[str] = set(params_to_ignore or [])
 
     def _ddp_wrap(
         self,
@@ -136,7 +138,10 @@ def wrap(
         sharded_parameter_names = set(
             DistributedModelParallel._sharded_parameter_names(dmp._dmp_wrapped_module)
         )
-        self._ddp_wrap(dmp, env, device, sharded_parameter_names)
+        params_to_ignore = sharded_parameter_names.union(
+            self._additional_params_to_ignore
+        )
+        self._ddp_wrap(dmp, env, device, params_to_ignore)
 
 
 def get_unwrapped_module(module: nn.Module) -> nn.Module:
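
As a rough illustration of the batching this change introduces, the sketch below uses stand-ins rather than the real torchrec or PyTorch classes. `FakeModule`, `set_ignore_list`, and `SketchWrapper` are hypothetical names invented for this example. The only behavior assumed is that each call to the ignore-list setter replaces whatever was stored before, which is the overwrite problem the commit message describes.

```python
from typing import Iterable, List, Optional, Set


class FakeModule:
    """Hypothetical stand-in for a model; it only stores an ignore list."""

    def __init__(self) -> None:
        self.ignored: Set[str] = set()


def set_ignore_list(module: FakeModule, names: Iterable[str]) -> None:
    # Stand-in for the DDP setter named in the commit message.
    # Assumption: each call replaces the previously stored names.
    module.ignored = set(names)


class SketchWrapper:
    """Hypothetical wrapper mirroring the shape of the change in the diff."""

    def __init__(self, params_to_ignore: Optional[List[str]] = None) -> None:
        self._additional_params_to_ignore: Set[str] = set(params_to_ignore or [])

    def wrap(self, module: FakeModule, sharded_parameter_names: Set[str]) -> None:
        # Batch user-supplied names with the sharded names into one call.
        params_to_ignore = sharded_parameter_names.union(
            self._additional_params_to_ignore
        )
        set_ignore_list(module, params_to_ignore)


# Without batching: the second call discards the user's names.
before = FakeModule()
set_ignore_list(before, ["user.param"])
set_ignore_list(before, {"embedding.weight"})

# With batching: both sets of names survive in a single call.
after = FakeModule()
SketchWrapper(params_to_ignore=["user.param"]).wrap(after, {"embedding.weight"})
```

The parameter names `user.param` and `embedding.weight` are placeholders. In real use the names would need to match the fully qualified parameter names of the wrapped module.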
