Skip to content

Commit 457bdf1

Browse files
committed
Tensor subclass generalization
1 parent b2fb664 commit 457bdf1

File tree

1 file changed

+41
-19
lines changed

1 file changed

+41
-19
lines changed

torchao/utils.py

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -512,6 +512,27 @@ def _get_tensor_impl_constructor(
512512
return tensor_class._LAYOUT_CONSTRUCTOR_TABLE[layout_class]
513513

514514

515+
def _get_to_kwargs(self, *args, **kwargs):
516+
# `torch._C._nn._parse_to` can't handle `layout` argument
517+
for arg in args:
518+
if isinstance(arg, torch.layout):
519+
args.remove(arg)
520+
if "layout" in kwargs:
521+
kwargs.pop("layout")
522+
# ignoring `non_blocking` and `memory_format` args since these are not
523+
# very useful for most of the tensor subclasses
524+
# if in the future there are use cases that need these, we'd recommend
525+
# to override `_get_to_kwargs` and return these args
526+
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
527+
device = self.device if device is None else device
528+
dtype = self.dtype if dtype is None else dtype
529+
kwargs = {
530+
"device": device,
531+
"dtype": dtype,
532+
}
533+
return kwargs
534+
535+
515536
class TorchAOBaseTensor(torch.Tensor):
516537
"""A util tensor subclass that provides commonly used functions
517538
new tensor subclass can inherit it to get all the utility functions
@@ -552,26 +573,27 @@ class PlainAQTTensorImpl(...):
552573
__torch_function__ = classmethod(_dispatch__torch_function__)
553574
register_layout = classmethod(_register_layout)
554575
get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor)
576+
_get_to_kwargs = _get_to_kwargs
577+
578+
def __tensor_flatten__(self):
579+
raise NotImplementedError("Subclasses must implement __tensor_flatten__")
580+
581+
@classmethod
582+
def __tensor_unflatten__(
583+
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
584+
):
585+
raise NotImplementedError("Subclasses must implement __tensor_unflatten__")
586+
587+
def __repr__(self):
588+
raise NotImplementedError("Subclasses must implement __repr__")
589+
590+
def _apply_fn_to_data(self, fn):
591+
raise NotImplementedError("Subclasses must implement _apply_fn_to_data")
555592

556-
def _get_to_kwargs(self, *args, **kwargs):
557-
# `torch._C._nn._parse_to` can't handle `layout` argument
558-
for arg in args:
559-
if isinstance(arg, torch.layout):
560-
args.remove(arg)
561-
if "layout" in kwargs:
562-
kwargs.pop("layout")
563-
# ignoring `non_blocking` and `memory_format` args since these are not
564-
# very useful for most of the tensor subclasses
565-
# if in the future there are use cases that need these, we'd recommend
566-
# to override `_get_to_kwargs` and return these args
567-
device, dtype, _, _ = torch._C._nn._parse_to(*args, **kwargs)
568-
device = self.device if device is None else device
569-
dtype = self.dtype if dtype is None else dtype
570-
kwargs = {
571-
"device": device,
572-
"dtype": dtype,
573-
}
574-
return kwargs
593+
def get_layout(self):
594+
if not hasattr(self, "_layout"):
595+
return None
596+
return self._layout
575597

576598

577599
def fill_defaults(args, n, defaults_tail):

0 commit comments

Comments
 (0)