
Commit 5d3a347

Stashing checkpointing RNG states based on devices of arg tensors (pytorch#14518)

Summary:
This PR intends to address apaszke's concerns in pytorch#14253 (comment). Preserving the RNG state is now controlled by a kwarg rather than a global state, hopefully in a Python 2.7-compatible way.

Additionally, the checkpointing function stashes and restores the RNG states of 1. devices associated with all input tensor args to run_fn as well as 2. the current device. I could easily change this to only save and restore the RNG states associated with 1. alone. This would simplify the logic that creates a [deduplicated, ordered](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R37) list of devices considered active.

I'm wondering if the [get_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R32) and [set_device_states](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) functions are general enough to reside elsewhere (presumably torch/random.py). I'm also wondering if the check on [torch.cuda._initialized](https://github.com/pytorch/pytorch/compare/master...mcarilli:checkpointing_rng_touchup?expand=1#diff-58da227fc9b1d56752b7dfad90428fe0R47) would be better placed within `get_device_states`.

Pull Request resolved: pytorch#14518
Differential Revision: D13356210
Pulled By: ezyang
fbshipit-source-id: afa4cc21ce7862142d5cb1dec3750018df222039
1 parent 25ddd65 commit 5d3a347
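For orientation, a minimal usage sketch of the new kwarg introduced by this commit (the `run_fn` below is illustrative, not part of the diff):

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def run_fn(x):
    # Dropout consumes RNG state, so deterministic recomputation during
    # backward relies on the stashed RNG state being restored first.
    return F.dropout(x, p=0.5, training=True)

x = torch.randn(4, 4, requires_grad=True)

out = checkpoint(run_fn, x)                            # default: stash/restore RNG states
out = checkpoint(run_fn, x, preserve_rng_state=False)  # opt out of the stash/restore
```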

File tree

2 files changed (+65 −29 lines):

docs/source/checkpoint.rst (+10 −2)
torch/utils/checkpoint.py (+55 −27)


docs/source/checkpoint.rst

+10 −2

@@ -11,10 +11,18 @@ torch.utils.checkpoint
     compared to non-checkpointed passes. The logic to stash and restore
     RNG states can incur a moderate performance hit depending on the runtime
     of checkpointed operations. If deterministic output compared to
-    non-checkpointed passes is not required, set the global flag
-    ``torch.utils.checkpoint.preserve_rng_state=False`` to omit stashing and
+    non-checkpointed passes is not required, supply ``preserve_rng_state=False``
+    to ``checkpoint`` or ``checkpoint_sequential`` to omit stashing and
     restoring the RNG state during each checkpoint.
 
+    The stashing logic saves and restores the RNG state for the current device
+    and the device of all cuda Tensor arguments to the ``run_fn``.
+    However, the logic has no way to anticipate if the user will move
+    Tensors to a new device within the ``run_fn`` itself. Therefore, if you move
+    Tensors to a new device ("new" meaning not belonging to the set of
+    [current device + devices of Tensor arguments]) within ``run_fn``, deterministic
+    output compared to non-checkpointed passes is never guaranteed.
+
 .. currentmodule:: torch.utils.checkpoint
 .. autofunction:: checkpoint
 .. autofunction:: checkpoint_sequential
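To make the caveat above concrete, here is a hypothetical `run_fn` that would void the determinism guarantee (it assumes two visible GPUs; names are illustrative):

```python
import torch

def run_fn(x):
    # x lives on cuda:0, so only cuda:0's RNG state (plus the CPU's and the
    # current device's) gets stashed before the forward pass.
    y = x.to("cuda:1")  # "new" device: neither the current device nor an arg device
    # RNG consumption on cuda:1 is not replayed identically during backward:
    return torch.nn.functional.dropout(y, p=0.5, training=True)
```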

torch/utils/checkpoint.py

+55 −27
@@ -21,38 +21,50 @@ def check_backward_validity(inputs):
         warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
 
 
-# Global switch to toggle whether or not checkpointed passes stash and restore
-# the RNG state. If True, any checkpoints making use of RNG should achieve deterministic
-# output compared to non-checkpointed passes.
-preserve_rng_state = True
+# We can't know if the run_fn will internally move some args to different devices,
+# which would require logic to preserve rng states for those devices as well.
+# We could paranoically stash and restore ALL the rng states for all visible devices,
+# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
+# the device of all Tensor args.
+#
+# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
+def get_device_states(*args):
+    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
+    # the conditionals short-circuit.
+    fwd_gpu_devices = list(set(arg.get_device() for arg in args
+                               if isinstance(arg, torch.Tensor) and arg.is_cuda))
+
+    fwd_gpu_states = []
+    for device in fwd_gpu_devices:
+        with torch.cuda.device(device):
+            fwd_gpu_states.append(torch.cuda.get_rng_state())
+
+    return fwd_gpu_devices, fwd_gpu_states
+
+
+def set_device_states(devices, states):
+    for device, state in zip(devices, states):
+        with torch.cuda.device(device):
+            torch.cuda.set_rng_state(state)
 
 
 class CheckpointFunction(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, preserve_rng_state, *args):
         check_backward_validity(args)
         ctx.run_function = run_function
+        ctx.preserve_rng_state = preserve_rng_state
         if preserve_rng_state:
-            # We can't know if the user will transfer some args from the host
-            # to the device during their run_fn. Therefore, we stash both
-            # the cpu and cuda rng states unconditionally.
-            #
-            # TODO:
-            # We also can't know if the run_fn will internally move some args to a device
-            # other than the current device, which would require logic to preserve
-            # rng states for those devices as well. We could paranoically stash and restore
-            # ALL the rng states for all visible devices, but that seems very wasteful for
-            # most cases.
-            ctx.fwd_cpu_rng_state = torch.get_rng_state()
+            ctx.fwd_cpu_state = torch.get_rng_state()
             # Don't eagerly initialize the cuda context by accident.
             # (If the user intends that the context is initialized later, within their
             # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
             # we have no way to anticipate this will happen before we run the function.)
             ctx.had_cuda_in_fwd = False
             if torch.cuda._initialized:
                 ctx.had_cuda_in_fwd = True
-                ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
         ctx.save_for_backward(*args)
         with torch.no_grad():
             outputs = run_function(*args)
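A quick round-trip sketch of the guarantee the two new helpers provide, assuming at least one visible CUDA device (illustration only, not part of the diff):

```python
import torch

a = torch.randn(2, 2, device="cuda:0")
devices, states = get_device_states(a, 3.14, torch.randn(2))  # non-tensor and CPU args are skipped

r1 = torch.randn(8, device="cuda:0")  # advances cuda:0's generator
set_device_states(devices, states)    # rewind cuda:0's generator to the stashed state
r2 = torch.randn(8, device="cuda:0")
assert torch.equal(r1, r2)            # same draws after the rewind
```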
@@ -66,23 +78,25 @@ def backward(ctx, *args):
         # Stash the surrounding rng state, and mimic the state that was
         # present at this time during forward. Restore the surrouding state
         # when we're done.
-        rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
-        with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
-            if preserve_rng_state:
-                torch.set_rng_state(ctx.fwd_cpu_rng_state)
+        rng_devices = []
+        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
+            rng_devices = ctx.fwd_gpu_devices
+        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
+            if ctx.preserve_rng_state:
+                torch.set_rng_state(ctx.fwd_cpu_state)
                 if ctx.had_cuda_in_fwd:
-                    torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
+                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
             detached_inputs = detach_variable(inputs)
             with torch.enable_grad():
                 outputs = ctx.run_function(*detached_inputs)
 
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
         torch.autograd.backward(outputs, args)
-        return (None,) + tuple(inp.grad for inp in detached_inputs)
+        return (None, None) + tuple(inp.grad for inp in detached_inputs)
 
 
-def checkpoint(function, *args):
+def checkpoint(function, *args, **kwargs):
     r"""Checkpoint a model or part of the model
 
     Checkpointing works by trading compute for memory. Rather than storing all
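The backward path leans on `torch.random.fork_rng`, which snapshots the CPU generator plus the generators of the listed CUDA devices and restores them on exit. A sketch of that contract, assuming one visible CUDA device:

```python
import torch

cpu_before = torch.get_rng_state()
with torch.random.fork_rng(devices=[0], enabled=True):
    torch.manual_seed(0)                 # freely perturb the generators inside the fork
    _ = torch.randn(3, device="cuda:0")
# On exit the surrounding RNG state is exactly what it was before the block:
assert torch.equal(cpu_before, torch.get_rng_state())
```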
@@ -120,15 +134,22 @@ def checkpoint(function, *args):
             passed as the tuple. For example, in LSTM, if user passes
             ``(activation, hidden)``, :attr:`function` should correctly use the
             first input as ``activation`` and the second input as ``hidden``
+        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
+            the RNG state during each checkpoint.
         args: tuple containing inputs to the :attr:`function`
 
     Returns:
         Output of running :attr:`function` on :attr:`*args`
     """
-    return CheckpointFunction.apply(function, *args)
+    # Hack to mix *args with **kwargs in a python 2.7-compliant way
+    preserve = kwargs.pop('preserve_rng_state', True)
+    if kwargs:
+        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
+
+    return CheckpointFunction.apply(function, preserve, *args)
 
 
-def checkpoint_sequential(functions, segments, *inputs):
+def checkpoint_sequential(functions, segments, *inputs, **kwargs):
     r"""A helper function for checkpointing sequential models.
 
     Sequential models execute a list of modules/functions in order
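The pop-and-check dance exists because Python 2.7 rejects keyword-only arguments after `*args`. A standalone sketch of the pattern (names are illustrative):

```python
def f(*args, **kwargs):
    # "def f(*args, flag=True)" is a SyntaxError on Python 2.7, so emulate
    # a keyword-only argument by popping it out of **kwargs.
    flag = kwargs.pop('flag', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
    return args, flag

print(f(1, 2, flag=False))  # ((1, 2), False)
f(1, 2, flga=False)         # raises ValueError: misspelled kwargs are caught, not ignored
```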
@@ -154,6 +175,8 @@ def checkpoint_sequential(functions, segments, *inputs):
             functions (comprising the model) to run sequentially.
         segments: Number of chunks to create in the model
         inputs: tuple of Tensors that are inputs to :attr:`functions`
+        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
+            the RNG state during each checkpoint.
 
     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -162,6 +185,10 @@ def checkpoint_sequential(functions, segments, *inputs):
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_sequential(model, chunks, input_var)
     """
+    # Hack to mix *args with **kwargs in a python 2.7-compliant way
+    preserve = kwargs.pop('preserve_rng_state', True)
+    if kwargs:
+        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
 
     def run_function(start, end, functions):
         def forward(*inputs):
@@ -181,7 +208,8 @@ def forward(*inputs):
     end = -1
     for start in range(0, segment_size * (segments - 1), segment_size):
         end = start + segment_size - 1
-        inputs = checkpoint(run_function(start, end, functions), *inputs)
+        inputs = checkpoint(run_function(start, end, functions), *inputs,
+                            preserve_rng_state=preserve)
     if not isinstance(inputs, tuple):
         inputs = (inputs,)
     return run_function(end + 1, len(functions) - 1, functions)(*inputs)
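Finally, the kwarg threads through `checkpoint_sequential` to every per-segment `checkpoint` call. A minimal usage sketch (model and sizes are illustrative):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.5), nn.Linear(16, 4))
x = torch.randn(2, 16, requires_grad=True)

# Checkpoint the model in 3 segments without RNG stash/restore:
out = checkpoint_sequential(model, 3, x, preserve_rng_state=False)
out.sum().backward()
```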
