@@ -21,38 +21,50 @@ def check_backward_validity(inputs):
         warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")


-# Global switch to toggle whether or not checkpointed passes stash and restore
-# the RNG state. If True, any checkpoints making use of RNG should achieve deterministic
-# output compared to non-checkpointed passes.
-preserve_rng_state = True
+# We can't know if the run_fn will internally move some args to different devices,
+# which would require logic to preserve rng states for those devices as well.
+# We could paranoically stash and restore ALL the rng states for all visible devices,
+# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
+# the device of all Tensor args.
+#
+# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
+def get_device_states(*args):
+    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
+    # the conditionals short-circuit.
+    fwd_gpu_devices = list(set(arg.get_device() for arg in args
+                               if isinstance(arg, torch.Tensor) and arg.is_cuda))
+
+    fwd_gpu_states = []
+    for device in fwd_gpu_devices:
+        with torch.cuda.device(device):
+            fwd_gpu_states.append(torch.cuda.get_rng_state())
+
+    return fwd_gpu_devices, fwd_gpu_states
+
+
+def set_device_states(devices, states):
+    for device, state in zip(devices, states):
+        with torch.cuda.device(device):
+            torch.cuda.set_rng_state(state)


 class CheckpointFunction(torch.autograd.Function):

     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, preserve_rng_state, *args):
         check_backward_validity(args)
         ctx.run_function = run_function
+        ctx.preserve_rng_state = preserve_rng_state
         if preserve_rng_state:
-            # We can't know if the user will transfer some args from the host
-            # to the device during their run_fn. Therefore, we stash both
-            # the cpu and cuda rng states unconditionally.
-            #
-            # TODO:
-            # We also can't know if the run_fn will internally move some args to a device
-            # other than the current device, which would require logic to preserve
-            # rng states for those devices as well. We could paranoically stash and restore
-            # ALL the rng states for all visible devices, but that seems very wasteful for
-            # most cases.
-            ctx.fwd_cpu_rng_state = torch.get_rng_state()
+            ctx.fwd_cpu_state = torch.get_rng_state()
             # Don't eagerly initialize the cuda context by accident.
             # (If the user intends that the context is initialized later, within their
             # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
             # we have no way to anticipate this will happen before we run the function.)
             ctx.had_cuda_in_fwd = False
             if torch.cuda._initialized:
                 ctx.had_cuda_in_fwd = True
-                ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
         ctx.save_for_backward(*args)
         with torch.no_grad():
            outputs = run_function(*args)
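A rough sketch of what the two new helpers do in isolation (x and y are illustrative tensors, not names from this patch; assumes two visible CUDA devices):

>>> x = torch.randn(4, device='cuda:0')         # illustrative tensor, not part of the patch
>>> y = torch.randn(4, device='cuda:1')
>>> devices, states = get_device_states(x, y)   # one RNG state per distinct GPU holding a Tensor arg
>>> torch.cuda.manual_seed_all(1234)            # clobber every device's RNG state
>>> set_device_states(devices, states)          # put the stashed states back for a deterministic replay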
@@ -66,23 +78,25 @@ def backward(ctx, *args):
         # Stash the surrounding rng state, and mimic the state that was
         # present at this time during forward. Restore the surrounding state
         # when we're done.
-        rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
-        with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
-            if preserve_rng_state:
-                torch.set_rng_state(ctx.fwd_cpu_rng_state)
+        rng_devices = []
+        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
+            rng_devices = ctx.fwd_gpu_devices
+        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
+            if ctx.preserve_rng_state:
+                torch.set_rng_state(ctx.fwd_cpu_state)
                 if ctx.had_cuda_in_fwd:
-                    torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
+                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
             detached_inputs = detach_variable(inputs)
             with torch.enable_grad():
                 outputs = ctx.run_function(*detached_inputs)

         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
         torch.autograd.backward(outputs, args)
-        return (None,) + tuple(inp.grad for inp in detached_inputs)
+        return (None, None) + tuple(inp.grad for inp in detached_inputs)


-def checkpoint(function, *args):
+def checkpoint(function, *args, **kwargs):
     r"""Checkpoint a model or part of the model

     Checkpointing works by trading compute for memory. Rather than storing all
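The backward pass leans on torch.random.fork_rng to sandbox the replay: anything drawn inside the with block does not perturb the surrounding RNG streams. A minimal sketch of that pattern, not taken from the patch, forking only the CPU state (mirroring the rng_devices = [] default):

>>> cpu_state = torch.get_rng_state()
>>> with torch.random.fork_rng(devices=[], enabled=True):
...     torch.set_rng_state(cpu_state)   # mimic a previously stashed state
...     replay = torch.rand(3)           # draws exactly what forward would have drawn
>>> torch.rand(3)                        # outer stream is unaffected by the draw inside the block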
@@ -120,15 +134,22 @@ def checkpoint(function, *args):
         passed as the tuple. For example, in LSTM, if user passes
         ``(activation, hidden)``, :attr:`function` should correctly use the
         first input as ``activation`` and the second input as ``hidden``
+        preserve_rng_state(bool, optional, default=True):  If ``False``, omit stashing and
+            restoring the RNG state during each checkpoint.
         args: tuple containing inputs to the :attr:`function`

     Returns:
         Output of running :attr:`function` on :attr:`*args`
     """
-    return CheckpointFunction.apply(function, *args)
+    # Hack to mix *args with **kwargs in a python 2.7-compliant way
+    preserve = kwargs.pop('preserve_rng_state', True)
+    if kwargs:
+        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
+
+    return CheckpointFunction.apply(function, preserve, *args)


-def checkpoint_sequential(functions, segments, *inputs):
+def checkpoint_sequential(functions, segments, *inputs, **kwargs):
     r"""A helper function for checkpointing sequential models.

     Sequential models execute a list of modules/functions in order
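With the keyword threaded through **kwargs, a caller can skip the RNG bookkeeping when the checkpointed block uses no randomness. A hypothetical usage sketch, where block is some nn.Module and x its input (neither name comes from the patch):

>>> out = checkpoint(block, x)                            # default: stash and restore RNG state
>>> out = checkpoint(block, x, preserve_rng_state=False)  # skip it; cheaper, but dropout etc. won't replay deterministically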
@@ -154,6 +175,8 @@ def checkpoint_sequential(functions, segments, *inputs):
         functions (comprising the model) to run sequentially.
         segments: Number of chunks to create in the model
         inputs: tuple of Tensors that are inputs to :attr:`functions`
+        preserve_rng_state(bool, optional, default=True):  If ``False``, omit stashing and
+            restoring the RNG state during each checkpoint.

     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -162,6 +185,10 @@ def checkpoint_sequential(functions, segments, *inputs):
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_sequential(model, chunks, input_var)
     """
+    # Hack to mix *args with **kwargs in a python 2.7-compliant way
+    preserve = kwargs.pop('preserve_rng_state', True)
+    if kwargs:
+        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

     def run_function(start, end, functions):
         def forward(*inputs):
@@ -181,7 +208,8 @@ def forward(*inputs):
     end = -1
     for start in range(0, segment_size * (segments - 1), segment_size):
         end = start + segment_size - 1
-        inputs = checkpoint(run_function(start, end, functions), *inputs)
+        inputs = checkpoint(run_function(start, end, functions), *inputs,
+                            preserve_rng_state=preserve)
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
     return run_function(end + 1, len(functions) - 1, functions)(*inputs)
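checkpoint_sequential forwards the same flag to every segment, so a whole pipeline can opt out at once. A hypothetical call with a purely deterministic toy model, where skipping the RNG stash is safe (model, input_var, and the chunk count are made up for illustration, assuming the usual import torch and import torch.nn as nn):

>>> model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))   # illustrative model
>>> input_var = torch.randn(8, 10, requires_grad=True)
>>> out = checkpoint_sequential(model, 2, input_var, preserve_rng_state=False)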