diff --git a/torchao/utils.py b/torchao/utils.py index 7a17c1b104..5575fd96b4 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -9,6 +9,7 @@ import torch import torch.nn.utils.parametrize as parametrize +from torch.utils._python_dispatch import return_and_correct_aliasing __all__ = [ "benchmark_model", @@ -512,9 +513,15 @@ def _get_tensor_impl_constructor( return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class] +""" +TensorAOBase subclass is a util tensor subclass that provides commonly used functions, and should be inherited to define a new tensor subclass +""" + + class TorchAOBaseTensor(torch.Tensor): """A util tensor subclass that provides commonly used functions - new tensor subclass can inherit it to get all the utility functions + new tensor subclass can inherit it to get all the utility functions, and + should be inherited to define a new tensor subclass class MyTensor(TorchAOBaseTensor): pass @@ -539,12 +546,17 @@ def _(func, types, args, kwargs): class PlainAQTTensorImpl(...): ... - `get_tensor_impl_constructor`: + `get_tensor_impl_constructor`: get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor # in constructor of MyTensor: tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout) + `__tensor_flatten__` and `__tensor_unflatten__`: Used for tensor serialization/deserialization, must be defined in the subclass + + `__repr__`: Used for tensor representation, must be defined in the subclass + + `_apply_fn_to_data`: Used for applying a function to the data of the tensor, must be defined in the subclass """ implements = classmethod(_implements) @@ -573,6 +585,58 @@ def _get_to_kwargs(self, *args, **kwargs): } return kwargs + def __tensor_flatten__(self): + raise NotImplementedError("Subclasses must implement __tensor_flatten__") + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + raise NotImplementedError("Subclasses must implement __tensor_unflatten__") + + def __repr__(self): + raise NotImplementedError("Subclasses must implement __repr__") + + def _apply_fn_to_data(self, fn): + raise NotImplementedError("Subclasses must implement _apply_fn_to_data") + + def get_layout(self): + if not hasattr(self, "_layout"): + return None + return self._layout + + +implements = TorchAOBaseTensor.implements +aten = torch.ops.aten + + +@implements(aten.detach.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements(aten.clone.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@implements(aten.t.default) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.t) + ) + + +@implements(aten.slice.Tensor) +def _(func, types, args, kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.slice) + ) + def fill_defaults(args, n, defaults_tail): """ @@ -599,7 +663,7 @@ def fill_defaults(args, n, defaults_tail): return r -## Deprecated, will be deleted in the future +# Deprecated, will be deleted in the future def _torch_version_at_least(min_version): return is_fbcode() or version("torch") >= min_version