.. currentmodule:: torch.func

.. _ux-limitations:

UX Limitations
==============

torch.func, like `JAX <https://github.com/google/jax>`_, has restrictions around
what can be transformed. In general, JAX's limitation is that its transforms
only work with pure functions: that is, functions where the output is completely
determined by the input and that do not involve side effects (like mutation).

We have a similar guarantee: our transforms work well with pure functions.
However, we do support certain in-place operations. On one hand, writing code
compatible with function transforms may involve changing how you write PyTorch
code; on the other hand, you may find that our transforms let you express things
that were previously difficult to express in PyTorch.

General limitations
-------------------

All torch.func transforms share a limitation in that a function should not
assign to global variables. Instead, all outputs of a function must be returned
from the function. This restriction comes from how torch.func is implemented:
each transform wraps Tensor inputs in special torch.func Tensor subclasses
that facilitate the transform.

So, instead of the following:

::

    import torch
    from torch.func import grad

    # Don't do this
    intermediate = None

    def f(x):
        global intermediate
        intermediate = x.sin()
        z = intermediate.sin()
        return z

    x = torch.randn([])
    grad_x = grad(f)(x)

Please rewrite ``f`` to return ``intermediate``:

::

    def f(x):
        intermediate = x.sin()
        z = intermediate.sin()
        return z, intermediate

    grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs
-------------------

If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad``
or ``torch.autograd.backward`` inside of a function being transformed by
:func:`vmap` or one of torch.func's AD transforms (:func:`vjp`, :func:`jvp`,
:func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform over it.
If it is unable to do so, you'll receive an error message.

This is a fundamental design limitation in how PyTorch's AD support is implemented
and the reason why we designed the torch.func library. Please instead use the torch.func
equivalents of the ``torch.autograd`` APIs:

- ``torch.autograd.grad``, ``Tensor.backward`` -> ``torch.func.vjp`` or ``torch.func.grad``
- ``torch.autograd.functional.jvp`` -> ``torch.func.jvp``
- ``torch.autograd.functional.jacobian`` -> ``torch.func.jacrev`` or ``torch.func.jacfwd``
- ``torch.autograd.functional.hessian`` -> ``torch.func.hessian``
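
For example, here is a minimal sketch of moving from ``torch.autograd.grad`` to
:func:`grad`. The function ``f`` and the input are made up for illustration;
only the API mapping above comes from this document.

::

    import torch
    from torch.func import grad

    def f(x):
        # a toy scalar-valued loss
        return (x ** 2).sum()

    x = torch.randn(3)

    # torch.autograd style: build a graph from a leaf that requires grad,
    # then differentiate the resulting loss Tensor.
    leaf = x.clone().requires_grad_(True)
    (autograd_gx,) = torch.autograd.grad(f(leaf), leaf)

    # torch.func style: transform the function itself; no requires_grad needed.
    func_gx = grad(f)(x)

    assert torch.allclose(autograd_gx, func_gx)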

vmap limitations
----------------

.. note::
    :func:`vmap` is our most restrictive transform.
    The grad-related transforms (:func:`grad`, :func:`vjp`, :func:`jvp`) do not
    have these limitations. :func:`jacfwd` (and :func:`hessian`, which is
    implemented with :func:`jacfwd`) is a composition of :func:`vmap` and
    :func:`jvp`, so it also has these limitations.

``vmap(func)`` is a transform that returns a function that maps ``func`` over
some new dimension of each input Tensor. The mental model for vmap is that it is
like running a for-loop: for pure functions (i.e. in the absence of side
effects), ``vmap(f)(x)`` is equivalent to:

::

    torch.stack([f(x_i) for x_i in x.unbind(0)])

Mutation: Arbitrary mutation of Python data structures
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

In the presence of side effects, :func:`vmap` no longer acts like it is running
a for-loop. For example, the following function:

::

    def f(x, list):
        list.pop()
        print("hello!")
        return x.sum(0)

    x = torch.randn(3, 1)
    lst = [0, 1, 2, 3]

    result = vmap(f, in_dims=(0, None))(x, lst)

will print "hello!" once and pop only one element from ``lst``.

:func:`vmap` executes ``f`` a single time, so all side effects only happen once.

This is a consequence of how vmap is implemented. torch.func has a special,
internal BatchedTensor class. ``vmap(f)(*inputs)`` takes all Tensor inputs,
turns them into BatchedTensors, and calls ``f(*batched_tensor_inputs)``.
BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized)
behavior for each PyTorch operator.
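
As a small, hypothetical illustration of the point above (the list and function
here are not from the original document), you can observe that the function body
runs only once under :func:`vmap`:

::

    calls = []

    def f(x):
        calls.append(1)   # side effect: record that the body ran
        return x.sum(0)

    vmap(f)(torch.randn(3, 2))
    assert len(calls) == 1  # f ran once on a BatchedTensor, not once per row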


Mutation: in-place PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

You might be here due to receiving an error about vmap-incompatible in-place
operations. :func:`vmap` will raise an error if it encounters an unsupported PyTorch
in-place operation and it will succeed otherwise. Unsupported operations
are those that would cause a Tensor with more elements to be written to a
Tensor with fewer elements. Here's an example of how this can occur:

::

    def f(x, y):
        x.add_(y)
        return x

    x = torch.randn(1)
    y = torch.randn(3, 1)  # When vmapped over, looks like it has shape [1]

    # Raises an error because `x` has fewer elements than `y`.
    vmap(f, in_dims=(None, 0))(x, y)

``x`` is a Tensor with one element, ``y`` is a Tensor with three elements.
``x + y`` has three elements (due to broadcasting), but attempting to write
those three elements back into ``x``, which has only one element, raises an
error.

There is no problem if the Tensor being written to is batched under
:func:`~torch.vmap` (i.e. it is being vmapped over).

::

    def f(x, y):
        x.add_(y)
        return x

    x = torch.randn(3, 1)
    y = torch.randn(3, 1)
    expected = x + y

    # Does not raise an error because x is being vmapped over.
    vmap(f, in_dims=(0, 0))(x, y)
    assert torch.allclose(x, expected)

One common fix for this is to replace calls to factory functions with
their "new_*" equivalent. For example:

- Replace :func:`torch.zeros` with :meth:`Tensor.new_zeros`
- Replace :func:`torch.empty` with :meth:`Tensor.new_empty`

To see why this helps, consider the following.

::

    def diag_embed(vec):
        assert vec.dim() == 1
        result = torch.zeros(vec.shape[0], vec.shape[0])
        result.diagonal().copy_(vec)
        return result

    vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])

    # RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
    vmap(diag_embed)(vecs)

Inside of :func:`~torch.vmap`, ``result`` is a Tensor of shape [3, 3].
However, although ``vec`` looks like it has shape [3], ``vec`` actually has
underlying shape [2, 3].
It is not possible to copy ``vec`` into ``result.diagonal()``, which has
shape [3], because ``vec`` has too many elements.

::

    def diag_embed(vec):
        assert vec.dim() == 1
        result = vec.new_zeros(vec.shape[0], vec.shape[0])
        result.diagonal().copy_(vec)
        return result

    vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
    vmap(diag_embed)(vecs)

Replacing :func:`torch.zeros` with :meth:`Tensor.new_zeros` makes it so that
``result`` has an underlying Tensor of shape [2, 3, 3], so it is now possible
to copy ``vec``, which has underlying shape [2, 3], into ``result.diagonal()``.


Mutation: out= PyTorch Operations
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations.
It will error out gracefully if it encounters that in your code.

This is not a fundamental limitation; we could theoretically support this in the
future but we have chosen not to for now.
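
For illustration, here is a minimal sketch of code that would hit this limitation.
The function below is made up for this example; the exact error text depends on
your PyTorch version.

::

    def f(x):
        return torch.sin(x, out=x)  # out= variants are not supported under vmap

    x = torch.randn(3, 5)
    # Raises an error because of the out= keyword argument.
    vmap(f)(x)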

Data-dependent Python control flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We don't yet support ``vmap`` over data-dependent control flow. Data-dependent
control flow is when the condition of an if-statement, while-loop, or
for-loop is a Tensor that is being ``vmap``'ed over. For example, the
following will raise an error message:

::

    def relu(x):
        if x > 0:
            return x
        return 0

    x = torch.randn(3)
    vmap(relu)(x)
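
One way to sidestep the error in simple cases like this one is to express the
branch with tensor operations instead of Python control flow. The following
rewrite with ``torch.where`` is a sketch that is not part of the original example:

::

    def relu(x):
        # no Python branching on the value of x, so this is vmap-compatible
        return torch.where(x > 0, x, torch.zeros_like(x))

    x = torch.randn(3)
    vmap(relu)(x)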

However, any control flow that is not dependent on the values in ``vmap``'ed
tensors will work:

::

    def custom_dot(x):
        if x.dim() == 1:
            return torch.dot(x, x)
        return (x * x).sum()

    x = torch.randn(3)
    vmap(custom_dot)(x)

JAX supports transforming over
`data-dependent control flow <https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators>`_
using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loop``).
We're investigating adding equivalents of those to PyTorch.

Data-dependent operations (.item())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We do not (and will not) support vmap over a user-defined function that calls
``.item()`` on a Tensor. For example, the following will raise an error message:

::

    def f(x):
        return x.item()

    x = torch.randn(3)
    vmap(f)(x)

Please try to rewrite your code to not use ``.item()`` calls.
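
For instance, if ``.item()`` was only used to pull a scalar out for later
arithmetic, you can usually keep the value as a Tensor instead. The function
below is a hypothetical sketch, not taken from the original document:

::

    def f(x):
        # instead of: scale = x.sum().item(); return x * scale
        scale = x.sum()  # keep the scalar as a 0-dim Tensor
        return x * scale

    x = torch.randn(3, 4)
    vmap(f)(x)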

You may also see an error message about ``.item()`` even though your code never
calls it. In those cases, it is possible that PyTorch internally is calling
``.item()`` -- please file an issue on GitHub and we'll fix PyTorch internals.

Dynamic shape operations (nonzero and friends)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``vmap(f)`` requires that ``f`` applied to every "example" in your input
returns a Tensor with the same shape. Operations such as ``torch.nonzero``
and ``torch.is_nonzero`` are not supported and will error as a result.

To see why, consider the following example:

::

    xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
    vmap(torch.nonzero)(xs)

``torch.nonzero(xs[0])`` returns a Tensor of shape [2, 1],
but ``torch.nonzero(xs[1])`` returns a Tensor of shape [1, 1]:
the per-example results have different shapes.
We are unable to construct a single Tensor as an output;
the output would need to be a ragged Tensor (and PyTorch does not yet have
the concept of a ragged Tensor).
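
For reference, running the per-example loop by hand shows the mismatched shapes:

::

    xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
    torch.nonzero(xs[0])  # tensor([[1], [2]]) -- shape [2, 1]
    torch.nonzero(xs[1])  # tensor([[2]])      -- shape [1, 1]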


Randomness
----------

The user's intention when calling a random operation can be unclear. Specifically, some users may want
the random behavior to be the same across batches while others may want it to differ across batches.
To address this, ``vmap`` takes a randomness flag.

The flag can only be passed to vmap and can take one of three values: "error", "different", or "same",
defaulting to "error". Under "error" mode, any call to a random function will produce an error asking
the user to use one of the other two flags based on their use case.
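
For example, the default mode errors out as soon as a random operation is hit
inside vmap. This snippet is an illustrative sketch of that behavior (it uses the
same kind of ``add_noise`` function as the examples below):

::

    def add_noise(x):
        return x + torch.randn(())

    x = torch.ones(3)
    # randomness defaults to "error", so this raises and asks you to pick
    # randomness="different" or randomness="same"
    vmap(add_noise)(x)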

Under "different" randomness, elements in a batch produce different random values. For instance,

::

    def add_noise(x):
        y = torch.randn(())  # y will be different across the batch
        return x + y

    x = torch.ones(3)
    result = vmap(add_noise, randomness="different")(x)  # we get 3 different values
Under "same" randomness, elements in a batch produce the same random values. For instance,

::

    def add_noise(x):
        y = torch.randn(())  # y will be the same across the batch
        return x + y

    x = torch.ones(3)
    result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times


.. warning::
    Our system only determines the randomness behavior of PyTorch operators and cannot control the
    behavior of other libraries, like numpy. This is similar to the limitation of JAX's solution.

.. note::
    Multiple vmap calls using either type of supported randomness will not produce
    the same results. Like with standard PyTorch, a user can get randomness reproducibility through
    either using ``torch.manual_seed()`` outside of vmap or by using generators.
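
    For example, a minimal sketch of seeding outside of vmap (reusing the
    ``add_noise`` function and ``x`` from above; this snippet is illustrative,
    not from the original text):

    ::

        torch.manual_seed(0)
        out1 = vmap(add_noise, randomness="different")(x)
        torch.manual_seed(0)
        out2 = vmap(add_noise, randomness="different")(x)
        # reseeding outside of vmap reproduces the random values
        assert torch.allclose(out1, out2)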

.. note::
    Finally, our randomness differs from JAX because we aren't using a stateless PRNG, in part because PyTorch
    doesn't have full support for a stateless PRNG. Instead, we've introduced a flag system to allow for the
    most common forms of randomness that we see. If your use case does not fit these forms of randomness, please
    file an issue.