Commit 3591008

Add module swap -> tensor subclass migration tutorial
Adds a migration tutorial from module swap to tensor subclass for expressing basic quantization. This is a simplified version of the existing subclass tutorials in torchao, removing layers of indirection like Layout and TensorImpl for ease of understanding. This commit also removes overlapping content from the existing contributor guide. Work was done with @bdhirsh.
1 parent 4e4f4df commit 3591008

7 files changed: +767 −215 lines

docs/source/contributor_guide.rst

+1 −215
@@ -125,7 +125,7 @@ On the top of the stack will be the final quantization algorithms and quantizati
For demonstration purposes, let's say after the previous step we have ``AffineQuantizedTensor`` and a ``to_affine_quantized`` factory function defined. For simplicity, let's say ``to_affine_quantized`` takes a high precision floating point Tensor and a ``target_dtype`` (e.g. ``torch.int8``) and converts it to an ``AffineQuantizedTensor`` with the corresponding dtype.
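
For example, a minimal usage sketch of the factory contract just described (``to_affine_quantized`` is the hypothetical factory named above; the real import path in torchao may differ)::

    import torch

    float_weight = torch.randn(64, 128, dtype=torch.float32)
    # high precision floating point Tensor in, AffineQuantizedTensor whose
    # stored data uses the target dtype out, per the contract described above
    int8_weight = to_affine_quantized(float_weight, torch.int8)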

-Note: below are all for explaining the concepts, more detailed introduction for utils and examples we provide can be found in ``Tensor Subclass Developer Guide`` section.
+Note: the below is all for explaining the concepts; a more detailed introduction to the utils and examples we provide can be found in the `Writing Your Own Tensor Subclass <subclass_basic.html>`__ tutorial.

Weight Only Quantization
########################
@@ -257,220 +257,6 @@ During Save/Load
Since ``AffineQuantizedTensor`` weight is still a ``torch.Tensor``, save/load works the same way as the original high precision floating point model. See the `serialization doc <https://pytorch.org/ao/stable/serialization.html>`__ for more details.
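
As a minimal sketch (assuming a hypothetical ``model`` whose weights have already been converted to ``AffineQuantizedTensor``), the flow is identical to the floating point case::

    import torch

    # the quantized weights are still torch.Tensor instances, so the regular
    # state_dict save/load machinery applies unchanged
    torch.save(model.state_dict(), "quantized_model.pt")
    state_dict = torch.load("quantized_model.pt")
    # assign=True swaps in the deserialized tensors instead of copying into
    # the existing parameters
    model.load_state_dict(state_dict, assign=True)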

Tensor Subclass Developer Guide
===============================

We have covered the high level overview and how everything is connected together in the previous sections. This section will focus on Tensor subclasses, which are the main extension point we rely on to provide the flexibility to support inference, training and fine tuning with low precision Tensors, as well as composability with torch.compile, autograd and distributed primitives in these scenarios.

Prerequisites
~~~~~~~~~~~~~
Some externally available resources for tensor subclasses:

* `tensor subclass doc <https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor>`__
* `Edward's podcast about tensor subclasses <https://podcasts.apple.com/us/podcast/tensor-subclasses-and-pt2/id1566080008?i=1000646728968>`__
* `Tensor subclass zoo <https://github.com/albanD/subclass_zoo>`__

Why Tensor Subclass?
~~~~~~~~~~~~~~~~~~~~
There are multiple ways to implement quantization techniques or new dtypes; the main motivations for us to recommend the tensor subclass based approach are threefold:

(1) It is natural for quantization to be modeled as a dtype conversion, so implementing it with a tensor subclass means we are not introducing new concepts but reusing existing concepts like dtype and layout that already exist in pytorch core.

(2) Since tensor subclasses intercept computation at the torch function or aten op level, as long as the same function/operator is used we will be able to quantize the model. This allows models that use variants of native modules (e.g. a slightly modified version of nn.Linear) to still be compatible with quantization, as the sketch after this list shows.

(3) Tensor subclass is also the approach adopted by other techniques like sparsity and distributed, so implementing quantization or dtype conversion with a tensor subclass makes it easier to compose with those techniques.
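
As a sketch of point (2): a hypothetical, slightly modified Linear module (not a torchao class) still funnels its computation through ``torch.nn.functional.linear``, which is exactly where a tensor subclass weight can intercept::

    import torch
    import torch.nn.functional as F

    class SlightlyModifiedLinear(torch.nn.Module):
        # hypothetical user module: not stock nn.Linear, but it still calls F.linear
        def __init__(self, in_features: int, out_features: int):
            super().__init__()
            self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # if self.weight is replaced by a quantized tensor subclass, this call
            # is intercepted via __torch_function__ with no module swap needed
            return F.linear(x, self.weight) * 0.5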

Example Code for a new DType
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Please feel free to start with this `tutorial <https://github.com/pytorch/ao/blob/main/tutorials/developer_api_guide/my_dtype_tensor_subclass.py>`__ for an end to end working example that combines everything discussed here, and come back to the doc for clarifications and documentation.

Basic Structure
~~~~~~~~~~~~~~~
A tensor subclass needs to define a few basic methods: ``__new__``, ``__init__``, ``__tensor_flatten__``, ``__tensor_unflatten__``,
and also dispatch functions for torch functions (``__torch_function__``) and aten ops (``__torch_dispatch__``).

Here is an example of the basic structure::
    # check out the docs in https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L437
    from torchao.utils import TorchAOBaseTensor

    class MyDTypeLayout(TorchAOBaseTensor):
        # see the tutorial code for details
        pass

    class MyDtypeTensor(TorchAOBaseTensor):
        """We need to define `__new__` for constructing a new tensor subclass instance and `__init__` to initialize
        the instance. There is no requirement on what the argument list should look like here; the only requirement is
        that `__new__` must return a Tensor instance with a `torch.Tensor._make_wrapper_subclass(cls, shape, ...)` call.
        """

        @staticmethod
        def __new__(
            cls,
            tensor_impl: MyDTypeLayout,
            shape: torch.Size,
            dtype: Optional[torch.dtype] = None,
        ):
            ...  # assemble `kwargs` (e.g. dtype, device) from the arguments above
            return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

        def __init__(
            self,
            tensor_impl: MyDTypeLayout,
            shape: torch.Size, ...
        ):
            self.tensor_impl = tensor_impl

        """`__tensor_flatten__` and `__tensor_unflatten__` are used to desugar the tensor into native Tensors/attributes and
        to reconstruct the tensor subclass instance from the desugared tensors and attributes; these are required to define
        a Tensor subclass with torch.compile support.
        """
        def __tensor_flatten__(self):
            return ["tensor_impl"], [self.shape]

        """see https://github.com/pytorch/pytorch/blob/3bc2004f9123a32f381ef64202252d59109507f3/torch/utils/_python_dispatch.py#L289 for documentation of outer_size and outer_stride
        """
        @classmethod
        def __tensor_unflatten__(
            cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
        ):
            tensor_impl = tensor_data_dict["tensor_impl"]
            shape, = tensor_attributes
            return cls(
                tensor_impl,
                shape if outer_size is None else outer_size,
            )

        """classmethod that converts from a floating point Tensor (fp32/fp16/bf16) to the current dtype
        """
        @classmethod
        def from_float(
            cls,
            input_float: torch.Tensor,
        ):
            mapping_type = MappingType.SYMMETRIC
            block_size = input_float.shape
            dtype = torch.int16
            scale, _ = choose_qparams_affine(input_float, mapping_type, block_size, dtype)
            int_data = (input_float / scale).to(torch.int8)
            tensor_impl = MyDTypeLayout.from_plain(int_data, scale)
            return cls(tensor_impl, input_float.shape)

        """[Optional] see the docs for `Layout/Packing` under the `Quantized Tensors` section to understand what layout_type is
        """
        @property
        def _layout(self) -> LayoutType:
            return self.tensor_impl._layout
"""There are two entry points that we can modify the behavior of a pytorch op: torch_function and torch_dispatch:
364-
365-
__torch_function__: will be called whenever a torch level function is called on the Tensor object, for example: torch.nn.functional.linear,
366-
tensor.detach, tensor.reshape, tensor.t etc.
367-
368-
__torch_dispatch__: will be called in the C++ dispatcher, when an aten operator is called on the Tensor object, for example:
369-
aten.mm, aten.addmm, aten.detach.default, aten.t.default etc.
370-
you can checkout https://github.com/pytorch/ao/blob/e283743b3cc4612bb641b88dca3670231724d396/torchao/utils.py#L361-L389 to understand what `__torch_function__` and `__torch_dispatch__` are doing, but with `TorchAoBaseTensor` user can use
371-
some helper functions directly (see next section)
372-
373-
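
Putting the structure above together, a hypothetical usage sketch (assuming the elided pieces of ``MyDtypeTensor`` are filled in as in the tutorial)::

    float_weight = torch.randn(128, 256)
    my_tensor = MyDtypeTensor.from_float(float_weight)

    # the wrapper behaves like a regular Tensor from the outside: it reports the
    # original shape while holding the int8 data and scale internally
    assert isinstance(my_tensor, torch.Tensor)
    assert my_tensor.shape == float_weight.shape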
Operator Support
~~~~~~~~~~~~~~~~
There are two types of operator support: torch functions and aten ops. For torch functions (e.g. ``torch.nn.functional.linear``), we'll need to overwrite the ``__torch_function__`` callback in the Tensor subclass; for aten ops (e.g. ``torch.ops.aten.mm``), we'll need to overwrite the ``__torch_dispatch__`` callback function.

For a new dtype, we'd like people to define the following decorator; if your dtype class inherits from ``torchao.utils.TorchAOBaseTensor``, you can do::

    implements = my_dtype_tensor_cls.implements

And we can implement the operator dispatch with the following::

    # Example of torch_function dispatch for torch.nn.functional.linear
    def _quantized_linear_op(input_tensor, weight_tensor, bias):
        if isinstance(input_tensor, MyDtypeTensor):
            input_tensor = input_tensor.dequantize()
        if isinstance(weight_tensor, MyDtypeTensor):
            weight_tensor = weight_tensor.dequantize()
        return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

    @implements(torch.nn.functional.linear)
    def _(func, types, args, kwargs):
        input_tensor, weight_tensor, bias = (
            args[0],
            args[1],
            args[2] if len(args) > 2 else None,
        )
        # using try/except here so that we can have a general fallback when input_tensor/weight_tensor
        # is not picked up by any of the dispatch paths in `_quantized_linear_op`; this allows us to
        # make the branches easier to understand in `_quantized_linear_op`
        try:
            return _quantized_linear_op(input_tensor, weight_tensor, bias)
        except NotImplementedError:
            if isinstance(input_tensor, MyDtypeTensor):
                input_tensor = input_tensor.dequantize()
            if isinstance(weight_tensor, MyDtypeTensor):
                weight_tensor = weight_tensor.dequantize()
            return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

    # Example of aten op dispatch for aten.detach.default (`aten` is `torch.ops.aten`)
    @implements(aten.detach.default)
    def _(func, types, args, kwargs):
        # `return_and_correct_aliasing` should be used by wrapper tensor ``__torch_dispatch__`` subclasses that would like to
        # work with torch.compile. It ensures that the subclass properly implements the aliasing behavior of every op,
        # which is needed for correctness in AOTAutograd.

        # `_apply_fn_to_data` just applies the function to the tensor data in `args[0]`; `args[0]` is a tensor subclass
        # of `my_dtype`
        return return_and_correct_aliasing(
            func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
        )
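
To see the registrations above in action, a hypothetical check (assuming ``MyDtypeTensor`` from the previous section and the handlers just defined)::

    x = torch.randn(2, 10)
    q_weight = MyDtypeTensor.from_float(torch.randn(10, 10))

    # __torch_function__ sees a MyDtypeTensor argument and routes the call to the
    # handler registered via @implements(torch.nn.functional.linear)
    out = torch.nn.functional.linear(x, q_weight)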
What ops do we need to overwrite? This depends on the model we are trying to quantize; commonly overwritten ops are:

* ``__torch_function__``: ``torch.nn.functional.linear``
* ``__torch_dispatch__``: ``torch.ops.aten.addmm.default``, ``torch.ops.aten.mm.default``, ``torch.ops.aten.detach.default``, ``torch.ops.aten.t.default``

You can also discover the ops that can be overwritten in ``__torch_function__`` or ``__torch_dispatch__`` with the following code: start with a model that you want to optimize, overwrite just the important ops like linear first, and gradually expand the coverage until the test runs and you get the expected optimized generated code (see the Optimized Operators section for more details)::
    import torch

    class M(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.linear = torch.nn.Linear(10, 10)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.linear(x) + x

    m = M()
    example_inputs = (torch.randn(2, 10),)

    from torch.overrides import TorchFunctionMode

    class TorchFunctionLoggingMode(TorchFunctionMode):
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            print(f"TORCH_FUNC={str(func)}")
            return func(*args, **kwargs)

    with TorchFunctionLoggingMode():
        m(*example_inputs)

    ## Example output
    # TORCH_FUNC=<built-in function linear>
    # TORCH_FUNC=<method 'add' of 'torch._C.TensorBase' objects>

    from torch.utils._python_dispatch import TorchDispatchMode

    class TorchDispatchLoggingMode(TorchDispatchMode):
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            if kwargs is None:
                kwargs = {}
            print(f"ATEN_FUNC={str(func)}")
            return func(*args, **kwargs)

    with TorchDispatchLoggingMode():
        m(*example_inputs)

    ## Example output
    # ATEN_FUNC=aten.t.default
    # ATEN_FUNC=aten.addmm.default
    # ATEN_FUNC=aten.add.Tensor

    # or see a more polished logging tool for torch_dispatch (aten) ops: https://github.com/albanD/subclass_zoo/blob/main/logging_mode.py

Alternatively, you can run a test example (e.g. use your quantized model with tensor parallelism, FSDP, etc.), discover the missing ops, and add them until the test passes.

We are still working on a table that lists, for each feature, the operators that need to be supported.
Adding Efficient Kernels
~~~~~~~~~~~~~~~~~~~~~~~~

docs/source/index.rst

+2
@@ -37,3 +37,5 @@ for an overall introduction to the library and recent highlight and updates.
   :caption: Tutorials

   serialization
+  subclass_basic
+  subclass_advanced

docs/source/subclass_advanced.rst

+4
@@ -0,0 +1,4 @@
+Writing Your Own Quantized Tensor (advanced)
+--------------------------------------------
+
+Coming soon!
