@@ -21,38 +21,50 @@ def check_backward_validity(inputs):
         warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
 
 
-# Global switch to toggle whether or not checkpointed passes stash and restore
-# the RNG state. If True, any checkpoints making use of RNG should achieve deterministic
-# output compared to non-checkpointed passes.
-preserve_rng_state = True
+# We can't know if the run_fn will internally move some args to different devices,
+# which would require logic to preserve rng states for those devices as well.
+# We could paranoically stash and restore ALL the rng states for all visible devices,
+# but that seems very wasteful for most cases. Compromise: Stash the RNG state for
+# the device of all Tensor args.
+#
+# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
+def get_device_states(*args):
+    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
+    # the conditionals short-circuit.
+    fwd_gpu_devices = list(set(arg.get_device() for arg in args
+                               if isinstance(arg, torch.Tensor) and arg.is_cuda))
+
+    fwd_gpu_states = []
+    for device in fwd_gpu_devices:
+        with torch.cuda.device(device):
+            fwd_gpu_states.append(torch.cuda.get_rng_state())
+
+    return fwd_gpu_devices, fwd_gpu_states
+
+
+def set_device_states(devices, states):
+    for device, state in zip(devices, states):
+        with torch.cuda.device(device):
+            torch.cuda.set_rng_state(state)
 
 
 class CheckpointFunction(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, run_function, *args):
+    def forward(ctx, run_function, preserve_rng_state, *args):
         check_backward_validity(args)
         ctx.run_function = run_function
+        ctx.preserve_rng_state = preserve_rng_state
         if preserve_rng_state:
-            # We can't know if the user will transfer some args from the host
-            # to the device during their run_fn. Therefore, we stash both
-            # the cpu and cuda rng states unconditionally.
-            #
-            # TODO:
-            # We also can't know if the run_fn will internally move some args to a device
-            # other than the current device, which would require logic to preserve
-            # rng states for those devices as well. We could paranoically stash and restore
-            # ALL the rng states for all visible devices, but that seems very wasteful for
-            # most cases.
-            ctx.fwd_cpu_rng_state = torch.get_rng_state()
+            ctx.fwd_cpu_state = torch.get_rng_state()
             # Don't eagerly initialize the cuda context by accident.
             # (If the user intends that the context is initialized later, within their
             # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
             # we have no way to anticipate this will happen before we run the function.)
             ctx.had_cuda_in_fwd = False
             if torch.cuda._initialized:
                 ctx.had_cuda_in_fwd = True
-                ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
+                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)
         ctx.save_for_backward(*args)
         with torch.no_grad():
             outputs = run_function(*args)
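As a minimal sketch of the idea behind `get_device_states` / `set_device_states` (not part of this patch; it assumes a CUDA build with at least one device, and that the helpers above are importable from `torch.utils.checkpoint`), stashing and later restoring the per-device RNG state makes an RNG-consuming op such as dropout replay identically:

```python
import torch
from torch.utils.checkpoint import get_device_states, set_device_states  # assumed import path

x = torch.randn(4, device="cuda:0")

# Stash the CUDA RNG state of every device that owns a Tensor arg.
devices, states = get_device_states(x)

first = torch.nn.functional.dropout(x, p=0.5)   # consumes the CUDA RNG

# Restore the stashed per-device states and replay: same mask, same result.
set_device_states(devices, states)
second = torch.nn.functional.dropout(x, p=0.5)

assert torch.equal(first, second)
```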
@@ -66,23 +78,25 @@ def backward(ctx, *args):
         # Stash the surrounding rng state, and mimic the state that was
         # present at this time during forward. Restore the surrounding state
         # when we're done.
-        rng_devices = [torch.cuda.current_device()] if ctx.had_cuda_in_fwd else []
-        with torch.random.fork_rng(devices=rng_devices, enabled=preserve_rng_state):
-            if preserve_rng_state:
-                torch.set_rng_state(ctx.fwd_cpu_rng_state)
+        rng_devices = []
+        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
+            rng_devices = ctx.fwd_gpu_devices
+        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
+            if ctx.preserve_rng_state:
+                torch.set_rng_state(ctx.fwd_cpu_state)
             if ctx.had_cuda_in_fwd:
-                torch.cuda.set_rng_state(ctx.fwd_cuda_rng_state)
+                set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
         detached_inputs = detach_variable(inputs)
         with torch.enable_grad():
             outputs = ctx.run_function(*detached_inputs)
 
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
         torch.autograd.backward(outputs, args)
-        return (None,) + tuple(inp.grad for inp in detached_inputs)
+        return (None, None) + tuple(inp.grad for inp in detached_inputs)
 
 
-def checkpoint(function, *args):
+def checkpoint(function, *args, **kwargs):
     r"""Checkpoint a model or part of the model
 
     Checkpointing works by trading compute for memory. Rather than storing all
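The backward above leans on `torch.random.fork_rng` to sandbox the replay. Roughly the same mechanism can be seen in isolation on the CPU generator (a sketch with a made-up dropout workload, no CUDA required):

```python
import torch

cpu_state = torch.get_rng_state()               # analogous to ctx.fwd_cpu_state
x = torch.ones(8)
first = torch.nn.functional.dropout(x, p=0.5)   # advances the CPU RNG

# Reproduce the same random draw without disturbing the outer RNG stream.
with torch.random.fork_rng(devices=[], enabled=True):
    torch.set_rng_state(cpu_state)
    replayed = torch.nn.functional.dropout(x, p=0.5)

assert torch.equal(first, replayed)
```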
@@ -120,15 +134,22 @@ def checkpoint(function, *args):
             passed as the tuple. For example, in LSTM, if user passes
             ``(activation, hidden)``, :attr:`function` should correctly use the
             first input as ``activation`` and the second input as ``hidden``
+        preserve_rng_state(bool, optional, default=True): If ``False``, omit stashing
+            and restoring the RNG state during each checkpoint.
         args: tuple containing inputs to the :attr:`function`
 
     Returns:
         Output of running :attr:`function` on :attr:`*args`
     """
-    return CheckpointFunction.apply(function, *args)
+    # Hack to mix *args with **kwargs in a python 2.7-compliant way
+    preserve = kwargs.pop('preserve_rng_state', True)
+    if kwargs:
+        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
+
+    return CheckpointFunction.apply(function, preserve, *args)
 
 
-def checkpoint_sequential(functions, segments, *inputs):
+def checkpoint_sequential(functions, segments, *inputs, **kwargs):
     r"""A helper function for checkpointing sequential models.
 
     Sequential models execute a list of modules/functions in order
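A hypothetical call site for the new keyword on `checkpoint` (the function and tensor names here are illustrative only):

```python
import torch
from torch.utils.checkpoint import checkpoint

def run_fn(a, b):
    return torch.nn.functional.dropout(a + b, p=0.1)

inp1 = torch.randn(16, 32, requires_grad=True)
inp2 = torch.randn(16, 32, requires_grad=True)

# Default: RNG state is stashed and restored, so the recomputation in
# backward sees the same dropout mask as the original forward.
out = checkpoint(run_fn, inp1, inp2)
out.sum().backward()

# Opt out of the stash/restore; if run_fn draws random numbers (as the
# dropout here does), the recomputed forward may use a different mask.
out_fast = checkpoint(run_fn, inp1, inp2, preserve_rng_state=False)
```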
@@ -154,6 +175,8 @@ def checkpoint_sequential(functions, segments, *inputs):
             functions (comprising the model) to run sequentially.
         segments: Number of chunks to create in the model
         inputs: tuple of Tensors that are inputs to :attr:`functions`
+        preserve_rng_state(bool, optional, default=True): If ``False``, omit stashing
+            and restoring the RNG state during each checkpoint.
 
     Returns:
         Output of running :attr:`functions` sequentially on :attr:`*inputs`
@@ -162,6 +185,10 @@ def checkpoint_sequential(functions, segments, *inputs):
         >>> model = nn.Sequential(...)
         >>> input_var = checkpoint_sequential(model, chunks, input_var)
     """
+    # Hack to mix *args with **kwargs in a python 2.7-compliant way
+    preserve = kwargs.pop('preserve_rng_state', True)
+    if kwargs:
+        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))
 
     def run_function(start, end, functions):
         def forward(*inputs):
@@ -181,7 +208,8 @@ def forward(*inputs):
     end = -1
     for start in range(0, segment_size * (segments - 1), segment_size):
         end = start + segment_size - 1
-        inputs = checkpoint(run_function(start, end, functions), *inputs)
+        inputs = checkpoint(run_function(start, end, functions), *inputs,
+                            preserve_rng_state=preserve)
         if not isinstance(inputs, tuple):
             inputs = (inputs,)
     return run_function(end + 1, len(functions) - 1, functions)(*inputs)
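To make the segmentation arithmetic concrete: with 10 modules and `segments=3`, `segment_size = 10 // 3 = 3`, so the loop checkpoints modules 0-2 and 3-5, and the remaining modules 6-9 are run directly by the final `run_function` call. A hypothetical end-to-end use of the keyword with `checkpoint_sequential` (layer sizes are made up):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

# 10 small layers split into 3 segments; the first two segments are
# checkpointed, the last one runs outside checkpoint.
model = nn.Sequential(*[nn.Linear(32, 32) for _ in range(10)])
inp = torch.randn(8, 32, requires_grad=True)

out = checkpoint_sequential(model, 3, inp, preserve_rng_state=True)
out.sum().backward()
```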