@@ -512,6 +512,27 @@ def _get_tensor_impl_constructor(
512
512
return tensor_class ._LAYOUT_CONSTRUCTOR_TABLE [layout_class ]
513
513
514
514
515
+ def _get_to_kwargs (self , * args , ** kwargs ):
516
+ # `torch._C._nn._parse_to` can't handle `layout` argument
517
+ for arg in args :
518
+ if isinstance (arg , torch .layout ):
519
+ args .remove (arg )
520
+ if "layout" in kwargs :
521
+ kwargs .pop ("layout" )
522
+ # ignoring `non_blocking` and `memory_format` args since these are not
523
+ # very useful for most of the tensor subclasses
524
+ # if in the future there are use cases that need these, we'd recommend
525
+ # to override `_get_to_kwargs` and return these args
526
+ device , dtype , _ , _ = torch ._C ._nn ._parse_to (* args , ** kwargs )
527
+ device = self .device if device is None else device
528
+ dtype = self .dtype if dtype is None else dtype
529
+ kwargs = {
530
+ "device" : device ,
531
+ "dtype" : dtype ,
532
+ }
533
+ return kwargs
534
+
535
+
515
536
class TorchAOBaseTensor (torch .Tensor ):
516
537
"""A util tensor subclass that provides commonly used functions
517
538
new tensor subclass can inherit it to get all the utility functions
@@ -552,26 +573,27 @@ class PlainAQTTensorImpl(...):
552
573
__torch_function__ = classmethod (_dispatch__torch_function__ )
553
574
register_layout = classmethod (_register_layout )
554
575
get_tensor_impl_constructor = classmethod (_get_tensor_impl_constructor )
576
+ _get_to_kwargs = _get_to_kwargs
577
+
578
+ def __tensor_flatten__ (self ):
579
+ raise NotImplementedError ("Subclasses must implement __tensor_flatten__" )
580
+
581
+ @classmethod
582
+ def __tensor_unflatten__ (
583
+ cls , tensor_data_dict , tensor_attributes , outer_size , outer_stride
584
+ ):
585
+ raise NotImplementedError ("Subclasses must implement __tensor_unflatten__" )
586
+
587
+ def __repr__ (self ):
588
+ raise NotImplementedError ("Subclasses must implement __repr__" )
589
+
590
+ def _apply_fn_to_data (self , fn ):
591
+ raise NotImplementedError ("Subclasses must implement _apply_fn_to_data" )
555
592
556
- def _get_to_kwargs (self , * args , ** kwargs ):
557
- # `torch._C._nn._parse_to` can't handle `layout` argument
558
- for arg in args :
559
- if isinstance (arg , torch .layout ):
560
- args .remove (arg )
561
- if "layout" in kwargs :
562
- kwargs .pop ("layout" )
563
- # ignoring `non_blocking` and `memory_format` args since these are not
564
- # very useful for most of the tensor subclasses
565
- # if in the future there are use cases that need these, we'd recommend
566
- # to override `_get_to_kwargs` and return these args
567
- device , dtype , _ , _ = torch ._C ._nn ._parse_to (* args , ** kwargs )
568
- device = self .device if device is None else device
569
- dtype = self .dtype if dtype is None else dtype
570
- kwargs = {
571
- "device" : device ,
572
- "dtype" : dtype ,
573
- }
574
- return kwargs
593
+ def get_layout (self ):
594
+ if not hasattr (self , "_layout" ):
595
+ return None
596
+ return self ._layout
575
597
576
598
577
599
def fill_defaults (args , n , defaults_tail ):
0 commit comments