99
1010import torch
1111import torch .nn .utils .parametrize as parametrize
12+ from torch .utils ._python_dispatch import return_and_correct_aliasing
1213
1314__all__ = [
1415 "benchmark_model" ,
@@ -512,9 +513,36 @@ def _get_tensor_impl_constructor(
512513 return tensor_class ._LAYOUT_CONSTRUCTOR_TABLE [layout_class ]
513514
514515
516+ def _get_to_kwargs (self , * args , ** kwargs ):
517+ # `torch._C._nn._parse_to` can't handle `layout` argument
518+ for arg in args :
519+ if isinstance (arg , torch .layout ):
520+ args .remove (arg )
521+ if "layout" in kwargs :
522+ kwargs .pop ("layout" )
523+ # ignoring `non_blocking` and `memory_format` args since these are not
524+ # very useful for most of the tensor subclasses
525+ # if in the future there are use cases that need these, we'd recommend
526+ # to override `_get_to_kwargs` and return these args
527+ device , dtype , _ , _ = torch ._C ._nn ._parse_to (* args , ** kwargs )
528+ device = self .device if device is None else device
529+ dtype = self .dtype if dtype is None else dtype
530+ kwargs = {
531+ "device" : device ,
532+ "dtype" : dtype ,
533+ }
534+ return kwargs
535+
536+
537+ """
538+ TensorAOBase subclass is a util tensor subclass that provides commonly used functions, and should be inherited to define a new tensor subclass
539+ """
540+
541+
515542class TorchAOBaseTensor (torch .Tensor ):
516543 """A util tensor subclass that provides commonly used functions
517- new tensor subclass can inherit it to get all the utility functions
544+ new tensor subclass can inherit it to get all the utility functions, and
545+ should be inherited to define a new tensor subclass
518546
519547 class MyTensor(TorchAOBaseTensor):
520548 pass
@@ -539,39 +567,45 @@ def _(func, types, args, kwargs):
539567 class PlainAQTTensorImpl(...):
540568 ...
541569
542- `get_tensor_impl_constructor`:
570+ `get_tensor_impl_constructor`:
543571 get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor
544572 # in constructor of MyTensor:
545573 tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
546574 tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
547575
576+ `__tensor_flatten__` and `__tensor_unflatten__`: Used for tensor serialization/deserialization, must be defined in the subclass
577+
578+ `__repr__`: Used for tensor representation, must be defined in the subclass
579+
580+ `_apply_fn_to_data`: Used for applying a function to the data of the tensor, must be defined in the subclass
548581 """
549582
550583 implements = classmethod (_implements )
551584 __torch_dispatch__ = classmethod (_dispatch__torch_dispatch__ )
552585 __torch_function__ = classmethod (_dispatch__torch_function__ )
553586 register_layout = classmethod (_register_layout )
554587 get_tensor_impl_constructor = classmethod (_get_tensor_impl_constructor )
588+ _get_to_kwargs = _get_to_kwargs
589+
590+ def __tensor_flatten__ (self ):
591+ raise NotImplementedError ("Subclasses must implement __tensor_flatten__" )
592+
593+ @classmethod
594+ def __tensor_unflatten__ (
595+ cls , tensor_data_dict , tensor_attributes , outer_size , outer_stride
596+ ):
597+ raise NotImplementedError ("Subclasses must implement __tensor_unflatten__" )
598+
599+ def __repr__ (self ):
600+ raise NotImplementedError ("Subclasses must implement __repr__" )
601+
602+ def _apply_fn_to_data (self , fn ):
603+ raise NotImplementedError ("Subclasses must implement _apply_fn_to_data" )
555604
556- def _get_to_kwargs (self , * args , ** kwargs ):
557- # `torch._C._nn._parse_to` can't handle `layout` argument
558- for arg in args :
559- if isinstance (arg , torch .layout ):
560- args .remove (arg )
561- if "layout" in kwargs :
562- kwargs .pop ("layout" )
563- # ignoring `non_blocking` and `memory_format` args since these are not
564- # very useful for most of the tensor subclasses
565- # if in the future there are use cases that need these, we'd recommend
566- # to override `_get_to_kwargs` and return these args
567- device , dtype , _ , _ = torch ._C ._nn ._parse_to (* args , ** kwargs )
568- device = self .device if device is None else device
569- dtype = self .dtype if dtype is None else dtype
570- kwargs = {
571- "device" : device ,
572- "dtype" : dtype ,
573- }
574- return kwargs
605+ def get_layout (self ):
606+ if not hasattr (self , "_layout" ):
607+ return None
608+ return self ._layout
575609
576610
577611def fill_defaults (args , n , defaults_tail ):
@@ -599,7 +633,7 @@ def fill_defaults(args, n, defaults_tail):
599633 return r
600634
601635
602- ## Deprecated, will be deleted in the future
636+ # Deprecated, will be deleted in the future
603637def _torch_version_at_least (min_version ):
604638 return is_fbcode () or version ("torch" ) >= min_version
605639
0 commit comments