Commit 31e66ca

zou3519 authored and pytorchmergebot committed

[torch.func] Add docs (pytorch#91319)

Docs copy-pasted from functorch docs with minor adjustments. We are keeping the
functorch docs for BC, though that's up for debate -- we could also just say
"see .. in torch.func" for some, but not all, doc pages (we still want to keep
around any examples that use make_functional so that users can tell what the
difference between that and the new functional_call is).

Test Plan:
- docs preview

Pull Request resolved: pytorch#91319
Approved by: https://github.com/samdow

1 parent 6f034dc commit 31e66ca

3 files changed: +584 -2 lines

docs/source/func.rst (+49 -2)

torch.func
==========

.. currentmodule:: torch.func

torch.func, previously known as "functorch", is a library of
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

.. note::
    This library is currently in `beta <https://pytorch.org/blog/pytorch-feature-classification-changes/#beta>`_.
    This means that the features generally work (unless otherwise documented)
    and we (the PyTorch team) are committed to bringing this library forward. However, the APIs
    may change in response to user feedback and we don't yet have full coverage of PyTorch operations.

    If you have suggestions on the API or use cases you'd like covered, please
    open a GitHub issue or reach out. We'd love to hear how you're using the library.

What are composable function transforms?
----------------------------------------

- A "function transform" is a higher-order function that accepts a numerical function
  and returns a new function that computes a different quantity.

- :mod:`torch.func` has auto-differentiation transforms (``grad(f)`` returns a function that
  computes the gradient of ``f``), a vectorization/batching transform (``vmap(f)``
  returns a function that computes ``f`` over batches of inputs), and others.

- These function transforms can compose with each other arbitrarily. For example,
  composing ``vmap(grad(f))`` computes a quantity called per-sample-gradients that
  stock PyTorch cannot efficiently compute today (see the sketch after this list).
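
For concreteness, here is a minimal, hedged sketch of per-sample gradients via
``vmap(grad(...))``. The linear model, squared-error loss, and all names below
are illustrative assumptions, not part of the API:

::

    import torch
    from torch.func import grad, vmap

    def loss(weights, sample, target):
        # Squared error of a linear model on a single sample.
        return (sample @ weights - target) ** 2

    weights = torch.randn(4)
    samples = torch.randn(8, 4)  # a batch of 8 samples
    targets = torch.randn(8)

    # grad(loss) differentiates the per-sample loss with respect to `weights`;
    # vmap maps that computation over the batch dimension of `samples`/`targets`.
    per_sample_grads = vmap(grad(loss), in_dims=(None, 0, 0))(weights, samples, targets)
    print(per_sample_grads.shape)  # torch.Size([8, 4])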

Why composable function transforms?
-----------------------------------

There are a number of use cases that are tricky to do in PyTorch today:

- computing per-sample-gradients (or other per-sample quantities)
- running ensembles of models on a single machine
- efficiently batching together tasks in the inner-loop of MAML
- efficiently computing Jacobians and Hessians
- efficiently computing batched Jacobians and Hessians

Composing :func:`vmap`, :func:`grad`, and :func:`vjp` transforms allows us to express
the above without designing a separate subsystem for each.
This idea of composable function transforms comes from the `JAX framework <https://github.com/google/jax>`_.
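
As a small, hedged sketch of the Jacobian use cases (the function ``f`` and the
shapes here are illustrative only):

::

    import torch
    from torch.func import jacrev, vmap

    def f(x):
        return x.sin()

    x = torch.randn(3)
    jacobian = jacrev(f)(x)                  # Jacobian of f at x, shape [3, 3]

    xs = torch.randn(8, 3)
    batched_jacobians = vmap(jacrev(f))(xs)  # batched Jacobians, shape [8, 3, 3]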

Read More
---------

.. toctree::
    :maxdepth: 2

    func.whirlwind_tour
    func.api
    func.ux_limitations

docs/source/func.ux_limitations.rst (+339)

.. currentmodule:: torch.func

.. _ux-limitations:

UX Limitations
==============

torch.func, like `JAX <https://github.com/google/jax>`_, has restrictions around
what can be transformed. In general, JAX's limitation is that transforms only
work with pure functions: that is, functions where the output is completely
determined by the input and that do not involve side effects (like mutation).

We have a similar guarantee: our transforms work well with pure functions.
However, we do support certain in-place operations. On one hand, writing code
compatible with function transforms may involve changing how you write PyTorch
code; on the other hand, you may find that our transforms let you express things
that were previously difficult to express in PyTorch.

General limitations
-------------------

All torch.func transforms share the limitation that a function should not
assign to global variables. Instead, all outputs of a function must be returned
from the function. This restriction comes from how torch.func is implemented:
each transform wraps Tensor inputs in special torch.func Tensor subclasses
that facilitate the transform.

So, instead of the following:

::

    import torch
    from torch.func import grad

    # Don't do this
    intermediate = None

    def f(x):
        global intermediate
        intermediate = x.sin()
        z = intermediate.sin()
        return z

    x = torch.randn([])
    grad_x = grad(f)(x)

Please rewrite ``f`` to return ``intermediate``:

::

    def f(x):
        intermediate = x.sin()
        z = intermediate.sin()
        return z, intermediate

    grad_x, intermediate = grad(f, has_aux=True)(x)

torch.autograd APIs
-------------------

If you are trying to use a ``torch.autograd`` API like ``torch.autograd.grad``
or ``torch.autograd.backward`` inside of a function being transformed by
:func:`vmap` or one of torch.func's AD transforms (:func:`vjp`, :func:`jvp`,
:func:`jacrev`, :func:`jacfwd`), the transform may not be able to transform over it.
If it is unable to do so, you'll receive an error message.

This is a fundamental design limitation in how PyTorch's AD support is implemented
and the reason why we designed the torch.func library. Please use the torch.func
equivalents of the ``torch.autograd`` APIs instead:

- ``torch.autograd.grad``, ``Tensor.backward`` -> ``torch.func.vjp`` or ``torch.func.grad``
- ``torch.autograd.functional.jvp`` -> ``torch.func.jvp``
- ``torch.autograd.functional.jacobian`` -> ``torch.func.jacrev`` or ``torch.func.jacfwd``
- ``torch.autograd.functional.hessian`` -> ``torch.func.hessian``
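
As a rough, illustrative sketch of this mapping (the function ``f`` below is an
assumption for demonstration, not part of any API):

::

    import torch
    from torch.func import grad

    def f(x):
        return (x ** 2).sum()

    x = torch.randn(3, requires_grad=True)

    # torch.autograd style: differentiate a computed output with respect to an input.
    grad_autograd, = torch.autograd.grad(f(x), x)

    # torch.func style: grad(f) returns a new function that computes the gradient of f.
    grad_func = grad(f)(x.detach())

    assert torch.allclose(grad_autograd, grad_func)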

vmap limitations
----------------

.. note::
    :func:`vmap` is our most restrictive transform.
    The grad-related transforms (:func:`grad`, :func:`vjp`, :func:`jvp`) do not
    have these limitations. :func:`jacfwd` (and :func:`hessian`, which is
    implemented with :func:`jacfwd`) is a composition of :func:`vmap` and
    :func:`jvp`, so it also has these limitations.

``vmap(func)`` is a transform that returns a function that maps ``func`` over
some new dimension of each input Tensor. The mental model for vmap is that it is
like running a for-loop: for pure functions (i.e. in the absence of side
effects), ``vmap(f)(x)`` is equivalent to:

::

    torch.stack([f(x_i) for x_i in x.unbind(0)])
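
To make this equivalence concrete, here is a small, hedged sketch (the function
and shapes are illustrative):

::

    import torch
    from torch.func import vmap

    def f(x):
        return x.sin().sum()

    x = torch.randn(5, 3)

    # For a pure f, vmap(f)(x) matches the explicit for-loop over dimension 0.
    looped = torch.stack([f(x_i) for x_i in x.unbind(0)])
    assert torch.allclose(vmap(f)(x), looped)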
93+
94+
Mutation: Arbitrary mutation of Python data structures
95+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
96+
97+
In the presence of side effects, :func:`vmap` no longer acts like it is running
98+
a for-loop. For example, the following function:
99+
100+
::
101+
102+
def f(x, list):
103+
list.pop()
104+
print("hello!")
105+
return x.sum(0)
106+
107+
x = torch.randn(3, 1)
108+
lst = [0, 1, 2, 3]
109+
110+
result = vmap(f, in_dims=(0, None))(x, lst)
111+
112+
will print "hello!" once and pop only one element from ``lst``.
113+
114+
115+
:func:`vmap` executes ``f`` a single time, so all side effects only happen once.
116+
117+
This is a consequence of how vmap is implemented. torch.func has a special,
118+
internal BatchedTensor class. ``vmap(f)(*inputs)`` takes all Tensor inputs,
119+
turns them into BatchedTensors, and calls ``f(*batched_tensor_inputs)``.
120+
BatchedTensor overrides the PyTorch API to produce batched (i.e. vectorized)
121+
behavior for each PyTorch operator.
122+
123+
124+
Mutation: in-place PyTorch Operations
125+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
126+
127+
You might be here due to receiving an error about vmap-incompatible in-place
128+
operations. :func:`vmap` will raise an error if it encounters an unsupported PyTorch
129+
in-place operation and it will succeed otherwise. Unsupported operations
130+
are those that would cause a Tensor with more elements to be written to a
131+
Tensor with fewer elements. Here's an example of how this can occur:
132+
133+
::
134+
135+
def f(x, y):
136+
x.add_(y)
137+
return x
138+
139+
x = torch.randn(1)
140+
y = torch.randn(3, 1) # When vmapped over, looks like it has shape [1]
141+
142+
# Raises an error because `x` has fewer elements than `y`.
143+
vmap(f, in_dims=(None, 0))(x, y)
144+
145+
``x`` is a Tensor with one element, ``y`` is a Tensor with three elements.
146+
``x + y`` has three elements (due to broadcasting), but attempting to write
147+
three elements back into ``x``, which only has one element, raises an error
148+
due to attempting to write three elements into a Tensor with a single element.
149+
150+
There is no problem if the Tensor being written to is batched under
151+
:func:`~torch.vmap` (i.e. it is being vmapped over).
152+
153+
::
154+
155+
def f(x, y):
156+
x.add_(y)
157+
return x
158+
159+
x = torch.randn(3, 1)
160+
y = torch.randn(3, 1)
161+
expected = x + y
162+
163+
# Does not raise an error because x is being vmapped over.
164+
vmap(f, in_dims=(0, 0))(x, y)
165+
assert torch.allclose(x, expected)
166+
167+
One common fix for this is to replace calls to factory functions with
168+
their "new_*" equivalent. For example:
169+
170+
- Replace :func:`torch.zeros` with :meth:`Tensor.new_zeros`
171+
- Replace :func:`torch.empty` with :meth:`Tensor.new_empty`
172+
173+
To see why this helps, consider the following.
174+
175+
::
176+
177+
def diag_embed(vec):
178+
assert vec.dim() == 1
179+
result = torch.zeros(vec.shape[0], vec.shape[0])
180+
result.diagonal().copy_(vec)
181+
return result
182+
183+
vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
184+
185+
# RuntimeError: vmap: inplace arithmetic(self, *extra_args) is not possible ...
186+
vmap(diag_embed)(vecs)
187+
188+
Inside of :func:`~torch.vmap`, ``result`` is a Tensor of shape [3, 3].
189+
However, although ``vec`` looks like it has shape [3], ``vec`` actually has
190+
underlying shape [2, 3].
191+
It is not possible to copy ``vec`` into ``result.diagonal()``, which has
192+
shape [3], because it has too many elements.
193+
194+
::
195+
196+
def diag_embed(vec):
197+
assert vec.dim() == 1
198+
result = vec.new_zeros(vec.shape[0], vec.shape[0])
199+
result.diagonal().copy_(vec)
200+
return result
201+
202+
vecs = torch.tensor([[0., 1, 2], [3., 4, 5]])
203+
vmap(diag_embed)(vecs)
204+
205+
Replacing :func:`torch.zeros` with :meth:`Tensor.new_zeros` makes it so that
206+
``result`` has an underlying Tensor of shape [2, 3, 3], so it is now possible
207+
to copy ``vec``, which has underlying shape [2, 3], into ``result.diagonal()``.
208+
209+
210+
Mutation: out= PyTorch Operations
211+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
212+
:func:`vmap` doesn't support the ``out=`` keyword argument in PyTorch operations.
213+
It will error out gracefully if it encounters that in your code.
214+
215+
This is not a fundamental limitation; we could theoretically support this in the
216+
future but we have chosen not to for now.
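
As a minimal sketch of the kind of call that is rejected (the function below is
illustrative, and the exact error message may differ):

::

    import torch
    from torch.func import vmap

    def f(x):
        out = torch.zeros_like(x)
        torch.sin(x, out=out)  # out= variant of torch.sin
        return out

    x = torch.randn(3, 4)
    # vmap(f)(x)  # expected to error: vmap does not support out= operations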

Data-dependent Python control flow
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We don't yet support ``vmap`` over data-dependent control flow. Data-dependent
control flow is when the condition of an if-statement, while-loop, or
for-loop is a Tensor that is being ``vmap``'ed over. For example, the
following will raise an error message:

::

    def relu(x):
        if x > 0:
            return x
        return 0

    x = torch.randn(3)
    vmap(relu)(x)

However, any control flow that is not dependent on the values in ``vmap``'ed
tensors will work:

::

    def custom_dot(x):
        if x.dim() == 1:
            return torch.dot(x, x)
        return (x * x).sum()

    x = torch.randn(3)
    vmap(custom_dot)(x)

JAX supports transforming over
`data-dependent control flow <https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators>`_
using special control flow operators (e.g. ``jax.lax.cond``, ``jax.lax.while_loop``).
We're investigating adding equivalents of those to PyTorch.

Data-dependent operations (.item())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We do not (and will not) support vmap over a user-defined function that calls
``.item()`` on a Tensor. For example, the following will raise an error message:

::

    def f(x):
        return x.item()

    x = torch.randn(3)
    vmap(f)(x)

Please try to rewrite your code to not use ``.item()`` calls.

You may also encounter an error message about ``.item()`` even though you did not
use it. In those cases, it is possible that PyTorch is calling ``.item()``
internally -- please file an issue on GitHub and we'll fix PyTorch internals.

Dynamic shape operations (nonzero and friends)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

``vmap(f)`` requires that ``f`` applied to every "example" in your input
returns a Tensor with the same shape. Operations such as ``torch.nonzero``
and ``torch.is_nonzero`` are not supported and will error as a result.

To see why, consider the following example:

::

    xs = torch.tensor([[0, 1, 2], [0, 0, 3]])
    vmap(torch.nonzero)(xs)

``torch.nonzero(xs[0])`` returns a Tensor of shape [2, 1],
but ``torch.nonzero(xs[1])`` returns a Tensor of shape [1, 1].
We are unable to construct a single Tensor as an output;
the output would need to be a ragged Tensor (and PyTorch does not yet have
the concept of a ragged Tensor).

Randomness
----------

The user's intention when calling a random operation can be unclear. Specifically, some users may want
the random behavior to be the same across batches while others may want it to differ across batches.
To address this, ``vmap`` takes a randomness flag.

The flag can only be passed to vmap and can take on 3 values: ``"error"``, ``"different"``, or ``"same"``,
defaulting to ``"error"``. Under ``"error"`` mode, any call to a random function will produce an error asking
the user to use one of the other two flags based on their use case.

Under ``"different"`` randomness, elements in a batch produce different random values. For instance,

::

    def add_noise(x):
        y = torch.randn(())  # y will be different across the batch
        return x + y

    x = torch.ones(3)
    result = vmap(add_noise, randomness="different")(x)  # we get 3 different values

Under ``"same"`` randomness, elements in a batch produce the same random value. For instance,

::

    def add_noise(x):
        y = torch.randn(())  # y will be the same across the batch
        return x + y

    x = torch.ones(3)
    result = vmap(add_noise, randomness="same")(x)  # we get the same value, repeated 3 times

.. warning::
    Our system only determines the randomness behavior of PyTorch operators and cannot control the
    behavior of other libraries, like numpy. This is similar to the limitations of JAX's solution.

.. note::
    Multiple vmap calls using either type of supported randomness will not produce
    the same results. As with standard PyTorch, you can get reproducible randomness
    either by calling ``torch.manual_seed()`` outside of vmap or by using generators.
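
For instance, a small sketch of seeded reproducibility under the ``"different"``
flag (assuming the behavior described in the note above):

::

    import torch
    from torch.func import vmap

    def add_noise(x):
        return x + torch.randn(())

    x = torch.ones(3)

    torch.manual_seed(0)
    first = vmap(add_noise, randomness="different")(x)
    torch.manual_seed(0)
    second = vmap(add_noise, randomness="different")(x)
    assert torch.equal(first, second)  # same seed outside vmap -> same results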

.. note::
    Finally, our randomness differs from JAX because we aren't using a stateless PRNG, in part because PyTorch
    doesn't have full support for a stateless PRNG. Instead, we've introduced a flag system to allow for the
    most common forms of randomness that we see. If your use case does not fit these forms of randomness, please
    file an issue.
