@@ -512,6 +512,27 @@ def _get_tensor_impl_constructor(
512512 return tensor_class ._LAYOUT_CONSTRUCTOR_TABLE [layout_class ]
513513
514514
515+ def _get_to_kwargs (self , * args , ** kwargs ):
516+ # `torch._C._nn._parse_to` can't handle `layout` argument
517+ for arg in args :
518+ if isinstance (arg , torch .layout ):
519+ args .remove (arg )
520+ if "layout" in kwargs :
521+ kwargs .pop ("layout" )
522+ # ignoring `non_blocking` and `memory_format` args since these are not
523+ # very useful for most of the tensor subclasses
524+ # if in the future there are use cases that need these, we'd recommend
525+ # to override `_get_to_kwargs` and return these args
526+ device , dtype , _ , _ = torch ._C ._nn ._parse_to (* args , ** kwargs )
527+ device = self .device if device is None else device
528+ dtype = self .dtype if dtype is None else dtype
529+ kwargs = {
530+ "device" : device ,
531+ "dtype" : dtype ,
532+ }
533+ return kwargs
534+
535+
515536class TorchAOBaseTensor (torch .Tensor ):
516537 """A util tensor subclass that provides commonly used functions
517538 new tensor subclass can inherit it to get all the utility functions
@@ -552,26 +573,27 @@ class PlainAQTTensorImpl(...):
552573 __torch_function__ = classmethod (_dispatch__torch_function__ )
553574 register_layout = classmethod (_register_layout )
554575 get_tensor_impl_constructor = classmethod (_get_tensor_impl_constructor )
576+ _get_to_kwargs = _get_to_kwargs
577+
578+ def __tensor_flatten__ (self ):
579+ raise NotImplementedError ("Subclasses must implement __tensor_flatten__" )
580+
581+ @classmethod
582+ def __tensor_unflatten__ (
583+ cls , tensor_data_dict , tensor_attributes , outer_size , outer_stride
584+ ):
585+ raise NotImplementedError ("Subclasses must implement __tensor_unflatten__" )
586+
587+ def __repr__ (self ):
588+ raise NotImplementedError ("Subclasses must implement __repr__" )
589+
590+ def _apply_fn_to_data (self , fn ):
591+ raise NotImplementedError ("Subclasses must implement _apply_fn_to_data" )
555592
556- def _get_to_kwargs (self , * args , ** kwargs ):
557- # `torch._C._nn._parse_to` can't handle `layout` argument
558- for arg in args :
559- if isinstance (arg , torch .layout ):
560- args .remove (arg )
561- if "layout" in kwargs :
562- kwargs .pop ("layout" )
563- # ignoring `non_blocking` and `memory_format` args since these are not
564- # very useful for most of the tensor subclasses
565- # if in the future there are use cases that need these, we'd recommend
566- # to override `_get_to_kwargs` and return these args
567- device , dtype , _ , _ = torch ._C ._nn ._parse_to (* args , ** kwargs )
568- device = self .device if device is None else device
569- dtype = self .dtype if dtype is None else dtype
570- kwargs = {
571- "device" : device ,
572- "dtype" : dtype ,
573- }
574- return kwargs
593+ def get_layout (self ):
594+ if not hasattr (self , "_layout" ):
595+ return None
596+ return self ._layout
575597
576598
577599def fill_defaults (args , n , defaults_tail ):
0 commit comments