
Commit b4f1ab4

mingxzhao authored and pytorchmergebot committed
Docs: fix docstring errors in ddp_comm_hooks (pytorch#116866)
Reopens pytorch#115272
Fixes ddp_comm_hooks errors in pytorch#112644
Pull Request resolved: pytorch#116866
Approved by: https://github.com/awgu
1 parent 16d6929 commit b4f1ab4

File tree

9 files changed, +79 -48 lines


torch/distributed/algorithms/ddp_comm_hooks/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -25,6 +25,8 @@ def _powerSGD_comm_hook_wrapper(
     start_powerSGD_iter=1_000,
 ):
     """
+    Wrap PowerSGD communication hook.
+
     To be consistent with the wrappers of other DDP comm hooks, the input state only needs to be a process group,
     which will be wrapped up with other state info.
     """
@@ -38,6 +40,8 @@ def _powerSGD_comm_hook_wrapper(

 class DDPCommHookType(Enum):
     """
+    Enumerate ``ddp_comm_hooks`` and ``ddp_comm_hook_wrapper`` communication hook types.
+
     DDPCommHookType enumerates the hooks of ``torch.distributed.algorithms.ddp_comm_hooks``
     as names and ``ddp_comm_hook_wrapper`` partials with hook specified. As an example,
     you can register allreduce hook by
@@ -89,6 +93,8 @@ def register_ddp_comm_hook(
     comm_hook_type: DDPCommHookType, model, state=None
 ):
     """
+    Register ``ddp_comm_hooks`` to DDP model.
+
     Registers the hooks of ``torch.distributed.algorithms.ddp_comm_hooks``
     to the DDP model. User can specify the type of hook as an enum
     ``DDPCommHookType`` type using ``comm_hook_type`` input. State input will
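
As the docstring notes, a hook can be registered by its enum name. A minimal sketch, assuming a DDP-wrapped ``model`` and an initialized process group ``process_group`` (both hypothetical names)::

    from torch.distributed.algorithms.ddp_comm_hooks import (
        DDPCommHookType,
        register_ddp_comm_hook,
    )

    # Register the plain allreduce hook; the state (a process group) is
    # wrapped up with other state info by the wrapper partials.
    register_ddp_comm_hook(DDPCommHookType.ALLREDUCE, model, state=process_group)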

torch/distributed/algorithms/ddp_comm_hooks/ddp_zero_hook.py

Lines changed: 20 additions & 24 deletions
@@ -24,7 +24,7 @@ def _perform_local_step(
     rank: int,
 ):
     r"""
-    Performs a local optimizer step using the gradients provided by ``bucket``.
+    Perform a local optimizer step using the gradients provided by ``bucket``.

     Arguments:
         bucket (dist.GradBucket): the bucket providing the gradients.
@@ -98,10 +98,10 @@ def _save_ddp_bucket_info(
     zero: ZeroRedundancyOptimizer,
 ):
     r"""
-    Saves :class:`DistributedDataParallel` gradient bucket information for the
-    :class:`ZeroRedundancyOptimizer` instance ``zero`` to use when overlapping.
+    Save :class:`DistributedDataParallel` gradient bucket information for :class:`ZeroRedundancyOptimizer` instance ``zero``.
+
     In particular, this function is meant to be called upon seeing each
-    gradient bucket, meaning it does not save or compute any global
+    gradient bucket to use when overlapping, meaning it does not save or compute any global
     information.

     Arguments:
@@ -130,8 +130,9 @@ def _hook_with_zero_step_setup(
     bucket: dist.GradBucket,
 ):
     r"""
-    Encapsulates the setup logic for :func:`hook_with_zero_step` and
-    :func:`hook_with_zero_step_interleaved`, meaning the logic to run in the
+    Encapsulate the setup logic for :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.
+
+    This means the logic to run in the
     hook before the backward pass and optimizer step can actually be
     overlapped. This is factored out since it is common to both
     :func:`hook_with_zero_step` and :func:`hook_with_zero_step_interleaved`.
@@ -172,16 +173,14 @@ def hook_with_zero_step(
     shard_buckets: bool = False,
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
     r"""
-    Modifies the given ``hook`` to overlap the :class:`ZeroRedundancyOptimizer`
-    optimizer step with the :class:`DistributedDataParallel` backward pass,
-    where the optimizer step computation begins after the last gradient bucket
-    computation has finished.
+    Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass.

     This approach overlaps the optimizer computation and communication with the
     backward communication. In particular, the backward computation proceeds
     contiguously, and the optimizer computation follows, overlapping with
     outstanding backward communication (i.e. all-reduces) and possibly other
     optimizer communication (i.e. broadcasts).
+    The optimizer step computation begins after the last gradient bucket computation has finished.

     This approach may be preferred over :meth:`hook_with_zero_step_interleaved`
     if communication is relatively slow compared to computation.
@@ -244,11 +243,11 @@ def hook_with_zero_fn(
         bucket: dist.GradBucket,
     ) -> torch.futures.Future[torch.Tensor]:
         r"""
-        Returns a :class:`Future` that gives a gradient bucket tensor and
-        performs the equivalent of a :class:`ZeroRedundancyOptimizer`
-        :meth:`step` if ``bucket`` is the last gradient bucket.
+        Return a :class:`Future` that runs the optimizer step if this corresponds to the last gradient bucket.

-        The function performs additional computation on the iteration that
+        Perform the equivalent of a :class:`ZeroRedundancyOptimizer` :meth:`step` if ``bucket`` is the last gradient bucket.
+        The function gives a gradient bucket tensor and
+        performs additional computation on the iteration that
         the :class:`DistributedDataParallel` buckets are rebuilt to collect
         information used to implement the modified hook.
@@ -331,10 +330,7 @@ def hook_with_zero_step_interleaved(
     shard_buckets: bool = False,
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
     r"""
-    Modifies the given ``hook`` to overlap the :class:`ZeroRedundancyOptimizer`
-    optimizer step with the :class:`DistributedDataParallel` backward pass,
-    where the optimizer step computation interleaves with the backward
-    computation.
+    Modify ``hook`` to overlap :class:`ZeroRedundancyOptimizer` optimizer step with :class:`DistributedDataParallel` backward pass.

     This approach overlaps the optimizer computation and communication with the
     backward computation and communication. In particular, once a bucket's
@@ -404,9 +400,11 @@ def hook_with_zero_interleaved_fn(
         bucket: dist.GradBucket,
     ) -> torch.futures.Future[torch.Tensor]:
         r"""
-        Returns a :class:`Future` that gives a gradient bucket tensor and
-        performs a partial :class:`ZeroRedundancyOptimizer` :meth:`step` using
-        the gradients in that bucket.
+        Return a :class:`Future` that gives a gradient bucket tensor and performs a partial :class:`ZeroRedundancyOptimizer` :meth:`step`.
+
+        This function uses the gradients in the given bucket to perform a partial
+        :class:`ZeroRedundancyOptimizer` :meth:`step`.

         Arguments:
             state: any state for the hook.
             bucket (dist.GradBucket): the :class:`DistributedDataParallel`
@@ -419,9 +417,7 @@ def hook_with_zero_interleaved_fn(

         def zero_step(fut: torch.futures.Future) -> torch.Tensor:
             r"""
-            Performs a partial :class:`ZeroRedundancyOptimizer` :meth:`step`
-            using the gradients in the given :class:`DistributedDataParallel`
-            gradient bucket.
+            Perform a partial :class:`ZeroRedundancyOptimizer` :meth:`step` using the gradients in the given :class:`DistributedDataParallel` gradient bucket.

             Returns:
                 A :class:`torch.Tensor` representing the contents of the
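
These wrappers are applied by registering the modified hook on the DDP model. A minimal sketch, assuming a DDP-wrapped ``model``, a ``ZeroRedundancyOptimizer`` ``zero`` constructed with ``overlap_with_ddp=True``, and an initialized process group ``process_group`` (all hypothetical names)::

    from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import hook_with_zero_step
    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook

    # Wrap allreduce so the optimizer step runs once the last gradient
    # bucket has been communicated, then register the wrapped hook.
    model.register_comm_hook(
        process_group, hook_with_zero_step(allreduce_hook, model, zero)
    )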

torch/distributed/algorithms/ddp_comm_hooks/debugging_hooks.py

Lines changed: 1 addition & 2 deletions
@@ -8,8 +8,7 @@

 def noop_hook(_: Any, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]:
     """
-    This DDP communication hook returns a future that wraps the input,
-    so it is a noop that does not incur any communication overheads.
+    Return a future that wraps the input, so it is a no-op that does not incur any communication overheads.

     This hook should **only** be used for headroom analysis of allreduce optimization,
     instead of the normal gradient synchronization.
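
A minimal sketch of such a headroom measurement, assuming a DDP-wrapped ``model`` (hypothetical name)::

    from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook

    # Skip gradient synchronization entirely to measure the upper bound
    # of any allreduce optimization; the hook needs no state.
    model.register_comm_hook(None, noop_hook)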

torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py

Lines changed: 11 additions & 5 deletions
@@ -8,7 +8,7 @@
 def _allreduce_fut(
     process_group: dist.ProcessGroup, tensor: torch.Tensor
 ) -> torch.futures.Future[torch.Tensor]:
-    "Averages the input gradient tensor by allreduce and returns a future."
+    """Average the input gradient tensor by allreduce and return a future."""
     group_to_use = process_group if process_group is not None else dist.group.WORLD

     # Apply the division first to avoid overflow, especially for FP16.
@@ -25,9 +25,12 @@ def allreduce_hook(
     process_group: dist.ProcessGroup, bucket: dist.GradBucket
 ) -> torch.futures.Future[torch.Tensor]:
     """
-    This DDP communication hook just calls ``allreduce`` using ``GradBucket``
-    tensors. Once gradient tensors are aggregated across all workers, its ``then``
-    callback takes the mean and returns the result. If user registers this hook,
+    Call ``allreduce`` using ``GradBucket`` tensors.
+
+    Once gradient tensors are aggregated across all workers, its ``then``
+    callback takes the mean and returns the result.
+
+    If user registers this DDP communication hook,
     DDP results is expected to be same as the case where no hook was registered.
     Hence, this won't change behavior of DDP and user can use this as a reference
     or modify this hook to log useful information or any other purposes while
@@ -44,6 +47,8 @@ def fp16_compress_hook(
     process_group: dist.ProcessGroup, bucket: dist.GradBucket
 ) -> torch.futures.Future[torch.Tensor]:
     """
+    Compress by casting the ``GradBucket`` tensor to ``torch.float16`` and dividing by the process group size.
+
     This DDP communication hook implements a simple gradient compression
     approach that casts ``GradBucket`` tensor to half-precision floating-point format (``torch.float16``)
     and then divides it by the process group size.
@@ -113,10 +118,11 @@ def fp16_compress_wrapper(
     hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
     """
+    Cast input tensor to ``torch.float16``, cast result of hook back to input dtype.
+
     This wrapper casts the input gradient tensor of a given DDP communication hook to half-precision
     floating point format (``torch.float16``), and casts the resulting tensor of the given hook back to
     the input data type, such as ``float32``.
-
     Therefore, ``fp16_compress_hook`` is equivalent to ``fp16_compress_wrapper(allreduce_hook)``.

     Example::
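
A minimal sketch of both uses, assuming a DDP-wrapped ``model`` (hypothetical name); passing ``None`` as state falls back to the default WORLD process group::

    from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default

    # Register the fp16 compression hook directly ...
    model.register_comm_hook(None, default.fp16_compress_hook)

    # ... or wrap any other comm hook with the same fp16 cast/uncast;
    # fp16_compress_wrapper(allreduce_hook) is equivalent to the above:
    # model.register_comm_hook(None, default.fp16_compress_wrapper(default.allreduce_hook))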

torch/distributed/algorithms/ddp_comm_hooks/mixed_precision_hooks.py

Lines changed: 4 additions & 0 deletions
@@ -10,9 +10,11 @@
 class _AllreduceUpcastHookState:
     """
     State to manage DDP mixed precision in backward / gradient communication.
+
     This contains a weakref to the DDP module for access to reducer and process
     group, and a stream to run parameter and gradient upcasts.
     """
+
     ddp_weakref: Any
     upcast_stream: torch.cuda.Stream
     wait_for_stream_enqueued: bool = False
@@ -22,6 +24,8 @@ def _reducer_allreduce_and_upcast_hook(
     hook_state: _AllreduceUpcastHookState, bucket: dist.GradBucket
 ) -> torch.futures.Future[torch.Tensor]:
     """
+    Perform allreduce in the reduced precision ``reduce_dtype``, then upcast to prepare for the optimizer.
+
     Performs allreduce in the reduced precision given by DDP's mixed precision
     reduce_dtype, and upcasts parameters and gradients to fp32 in preparation
     to run the optimizer.

torch/distributed/algorithms/ddp_comm_hooks/optimizer_overlap_hooks.py

Lines changed: 4 additions & 3 deletions
@@ -13,6 +13,7 @@
 class _OptimizerHookState:
     """
     Holds state for running optimizer in-line after DDP communication hook.
+
     Currently contains only optimizer class which must have a method `step_param`.
     """

@@ -45,6 +46,8 @@ def _apply_optim_in_backward_hook(
     gradient_is_bucket_view: bool
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
     r"""
+    Register hook to apply the optimizer in backward.
+
     If torch.distributed.optim._apply_optimizer_in_backward is used to overlap
     optimizer with backward pass, DDP will run the below hook to run optimizer
     step for parameters after gradient communication has taken place.
@@ -123,9 +126,7 @@ def _hook_then_optimizer(
     hook: Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]],
     optimizer_state: _OptimizerHookState,
 ) -> Callable[[Any, dist.GradBucket], torch.futures.Future[torch.Tensor]]:
-    r"""
-    Runs optimizer in a functional fashion after DDP communication hook.
-    """
+    r"""Run optimizer in a functional fashion after DDP communication hook."""
     has_set_params = (
         hasattr(optimizer_state, 'params_to_optimize')
         and optimizer_state.params_to_optimize is not None

torch/distributed/algorithms/ddp_comm_hooks/post_localSGD_hook.py

Lines changed: 6 additions & 1 deletion
@@ -10,6 +10,8 @@

 class PostLocalSGDState:
     r"""
+    Store state for all-reducing gradients globally until given step, then locally after.
+
     Stores the state for all-reducing gradients globally using ``process_group`` until step ``start_localSGD_iter``,
     and all-reducing gradients locally using ``subgroup`` afterwards.
@@ -35,6 +37,7 @@ def __init__(
         start_localSGD_iter,
         post_local_gradient_allreduce=True,
     ):
+        """Initialize state object with given parameters and log when local SGD starts."""
        logger.info(
            "Local SGD will be started after %s iterations", start_localSGD_iter
        )
@@ -51,6 +54,7 @@
         self.iter = 0

     def maybe_increase_iter(self, bucket):
+        """Track iterations and trigger log message at start of local SGD."""
         # Since bucket 0 is the last bucket to allreduce in an iteration.
         # Only increase `iter` when bucket 0 is processed.
         if bucket.is_last():
@@ -61,11 +65,12 @@ def maybe_increase_iter(self, bucket):
                 "Start to apply local SGD after %s iterations.", self.iter
             )

-
 def post_localSGD_hook(
     state: PostLocalSGDState, bucket: dist.GradBucket
 ) -> torch.futures.Future[torch.Tensor]:
     """
+    Run post-localSGD algorithm.
+
     This DDP communication hook is used for running post-localSGD algorithm,
     by combining with a model averaging component (e.g.,
     :class:`~torch.distributed.algorithms.model_averaging.averagers.PeriodicModelAverager`)
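
A minimal sketch of the post-localSGD setup, assuming a DDP-wrapped ``model`` (hypothetical name) and an already-initialized default process group::

    import torch.distributed as dist
    from torch.distributed.algorithms.ddp_comm_hooks import post_localSGD_hook as post_localSGD

    # Intra-machine subgroups by default; gradients are all-reduced globally
    # for the first 100 iterations, then only within the subgroup.
    subgroup, _ = dist.new_subgroups()
    state = post_localSGD.PostLocalSGDState(
        process_group=None, subgroup=subgroup, start_localSGD_iter=100
    )
    model.register_comm_hook(state, post_localSGD.post_localSGD_hook)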

torch/distributed/algorithms/ddp_comm_hooks/powerSGD_hook.py

Lines changed: 21 additions & 9 deletions
@@ -19,6 +19,7 @@
 def _orthogonalize(matrices, epsilon=0):
     """
     Decide between Gram-Schmidt or QR factorization to orthogonalize a batch of matrices.
+
     QR factorization doesn't work with half-precision, but it is usually faster with a rank > 2.
     """
     assert len(matrices.shape) == 3 and matrices.shape[2] <= matrices.shape[1]
@@ -39,7 +40,8 @@ def _orthogonalize(matrices, epsilon=0):

 def _orthogonalize_gram_schmidt(matrices, epsilon=0):
     """
-    Applies Gram-Schmidt procedure to orthogonalize a batch of matrices.
+    Apply Gram-Schmidt procedure to orthogonalize a batch of matrices.
+
     If epsilon is 0, this is equivalent to `torch.qr(matrices, out=(matrices, _))`,
     """
     num_cols = matrices.shape[2]
@@ -73,6 +75,8 @@ def _should_compress(
     num_rows, num_cols, matrix_approximation_rank, min_compression_rate
 ):
     """
+    Recommend whether the given tensor is worth compressing.
+
     Returns a recommendation as to whether the 2D tensor described by the arguments is worth compressing,
     including statistics describing the expected savings from compression. We consider a tensor worth
     compressing when ``min_compression_rate`` < uncompressed size / compressed size, where
@@ -97,9 +101,7 @@


 def _report_compression_stats(bucket, state):
-    """
-    Report compression stats at the frequency of `compression_stats_logging_frequency` specified in PowerSGD state.
-    """
+    """Report compression stats at frequency of ``compression_stats_logging_frequency`` specified in PowerSGD state."""
     if (
         bucket.is_last()
         and state.iter >= state.next_stats_report
@@ -114,7 +116,8 @@ def _report_compression_stats(bucket, state):

 class PowerSGDState:
     r"""
-    Stores both the algorithm's hyperparameters and the internal state for all the gradients during the training.
+    Store both the algorithm's hyperparameters and internal state for all gradients during training.
+
     Particularly, ``matrix_approximation_rank`` and ``start_powerSGD_iter`` are the main hyperparameters that should be tuned by the user.
     For performance, we suggest to keep binary hyperparameters ``use_error_feedback`` and ``warm_start`` on.
@@ -266,7 +269,8 @@ def __init__(

     def __getstate__(self):
         r"""
-        Returns a ``Dict[str, Any]`` which will be pickled and saved.
+        Return a ``Dict[str, Any]`` which will be pickled and saved.
+
         ``process_group`` is not serializable and excluded from
         a returned state.
         """
@@ -280,7 +284,8 @@ def __getstate__(self):

     def __setstate__(self, state):
         r"""
-        Takes a provided ``state`` and retrieves ``PowerSGDState``.
+        Take a provided ``state`` and set it on this ``PowerSGDState`` instance.
+
         ``process_group`` is set to default.
         """
         self.process_group = distributed_c10d._get_default_group()
@@ -292,6 +297,7 @@ def __setstate__(self, state):
             setattr(self, slot, value)

     def maybe_increase_iter(self, bucket):
+        """Track iterations and trigger log message at start of PowerSGD."""
         # Since bucket 0 is the last bucket to allreduce in an iteration.
         # Only increase `iter` when bucket 0 is processed.
         if bucket.is_last():
@@ -304,7 +310,9 @@ def maybe_increase_iter(self, bucket):

     def compression_stats(self):
         r"""
-        Returns the latest compression statistics as a tuple of the form (compress_rate, numel_before_compression, numel_after_compression), where:
+        Return the latest compression statistics as a tuple.
+
+        Returns a tuple of the form (compress_rate, numel_before_compression, numel_after_compression), where:

         compress_rate is the effective compression rate i.e. (number of elements before compression) / (number of elements after compression);
@@ -328,6 +336,8 @@ def powerSGD_hook(
     state: PowerSGDState, bucket: dist.GradBucket
 ) -> torch.futures.Future[torch.Tensor]:
     r"""
+    Implement PowerSGD algorithm.
+
     This DDP communication hook implements PowerSGD gradient compression
     algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
     Once gradient tensors are aggregated across all workers, this hook applies
@@ -636,6 +646,8 @@ def batched_powerSGD_hook(
     state: PowerSGDState, bucket: dist.GradBucket
 ) -> torch.futures.Future[torch.Tensor]:
     r"""
+    Implement simplified PowerSGD algorithm.
+
     This DDP communication hook implements a simplified PowerSGD gradient compression
     algorithm described in the `paper <https://arxiv.org/abs/1905.13727>`_.
     This variant does not compress the gradients layer by layer,
@@ -750,7 +762,7 @@ def batched_powerSGD_hook(
     )

     def create_low_rank_tensor(fill_random_values, rng):
-        "Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank."
+        """Return a low-rank 2D tensor of square_side_length * matrix_approximation_rank."""
         if fill_random_values:
             with torch.random.fork_rng(devices=[]):
                 # Fork this RNG to avoid changing the seed globally and affecting the random sampling
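
A minimal sketch of enabling PowerSGD compression, assuming a DDP-wrapped ``model`` (hypothetical name) and an initialized default process group::

    from torch.distributed.algorithms.ddp_comm_hooks import powerSGD_hook as powerSGD

    state = powerSGD.PowerSGDState(
        process_group=None,           # default (global) process group
        matrix_approximation_rank=1,  # main accuracy/compression trade-off knob
        start_powerSGD_iter=1_000,    # vanilla allreduce during warm-up
    )
    model.register_comm_hook(state, powerSGD.powerSGD_hook)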
