Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move tensor_flatten/unflatten from AQT -> TensorAOBaseClass #1615

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 67 additions & 3 deletions torchao/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import torch
import torch.nn.utils.parametrize as parametrize
from torch.utils._python_dispatch import return_and_correct_aliasing

__all__ = [
"benchmark_model",
Expand Down Expand Up @@ -512,9 +513,15 @@ def _get_tensor_impl_constructor(
return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class]


"""
TensorAOBase subclass is a util tensor subclass that provides commonly used functions, and should be inherited to define a new tensor subclass
"""


class TorchAOBaseTensor(torch.Tensor):
"""A util tensor subclass that provides commonly used functions
new tensor subclass can inherit it to get all the utility functions
new tensor subclass can inherit it to get all the utility functions, and
should be inherited to define a new tensor subclass

class MyTensor(TorchAOBaseTensor):
pass
Expand All @@ -539,12 +546,17 @@ def _(func, types, args, kwargs):
class PlainAQTTensorImpl(...):
...

`get_tensor_impl_constructor`:
`get_tensor_impl_constructor`:
get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor
# in constructor of MyTensor:
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)

`__tensor_flatten__` and `__tensor_unflatten__`: Used for tensor serialization/deserialization, must be defined in the subclass

`__repr__`: Used for tensor representation, must be defined in the subclass

`_apply_fn_to_data`: Used for applying a function to the data of the tensor, must be defined in the subclass
"""

implements = classmethod(_implements)
Expand Down Expand Up @@ -573,6 +585,58 @@ def _get_to_kwargs(self, *args, **kwargs):
}
return kwargs

def __tensor_flatten__(self):
raise NotImplementedError("Subclasses must implement __tensor_flatten__")

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
raise NotImplementedError("Subclasses must implement __tensor_unflatten__")

def __repr__(self):
raise NotImplementedError("Subclasses must implement __repr__")

def _apply_fn_to_data(self, fn):
raise NotImplementedError("Subclasses must implement _apply_fn_to_data")

def get_layout(self):
if not hasattr(self, "_layout"):
return None
return self._layout


implements = TorchAOBaseTensor.implements
aten = torch.ops.aten


@implements(aten.detach.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)


@implements(aten.clone.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)


@implements(aten.t.default)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
)


@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.slice)
)


def fill_defaults(args, n, defaults_tail):
"""
Expand All @@ -599,7 +663,7 @@ def fill_defaults(args, n, defaults_tail):
return r


## Deprecated, will be deleted in the future
# Deprecated, will be deleted in the future
def _torch_version_at_least(min_version):
return is_fbcode() or version("torch") >= min_version

Expand Down
Loading