9
9
10
10
import torch
11
11
import torch .nn .utils .parametrize as parametrize
12
+ from torch .utils ._python_dispatch import return_and_correct_aliasing
12
13
13
14
__all__ = [
14
15
"benchmark_model" ,
@@ -512,9 +513,15 @@ def _get_tensor_impl_constructor(
512
513
return tensor_class ._LAYOUT_CONSTRUCTOR_TABLE [layout_class ]
513
514
514
515
516
+ """
517
+ TensorAOBase subclass is a util tensor subclass that provides commonly used functions, and should be inherited to define a new tensor subclass
518
+ """
519
+
520
+
515
521
class TorchAOBaseTensor (torch .Tensor ):
516
522
"""A util tensor subclass that provides commonly used functions
517
- new tensor subclass can inherit it to get all the utility functions
523
+ new tensor subclass can inherit it to get all the utility functions, and
524
+ should be inherited to define a new tensor subclass
518
525
519
526
class MyTensor(TorchAOBaseTensor):
520
527
pass
@@ -539,12 +546,17 @@ def _(func, types, args, kwargs):
539
546
class PlainAQTTensorImpl(...):
540
547
...
541
548
542
- `get_tensor_impl_constructor`:
549
+ `get_tensor_impl_constructor`:
543
550
get_tensor_impl_constructor = MyTensor.get_tensor_impl_constructor
544
551
# in constructor of MyTensor:
545
552
tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
546
553
tensor_impl = tensor_impl_ctr(data, scale, zero_point, _layout)
547
554
555
+ `__tensor_flatten__` and `__tensor_unflatten__`: Used for tensor serialization/deserialization, must be defined in the subclass
556
+
557
+ `__repr__`: Used for tensor representation, must be defined in the subclass
558
+
559
+ `_apply_fn_to_data`: Used for applying a function to the data of the tensor, must be defined in the subclass
548
560
"""
549
561
550
562
implements = classmethod (_implements )
@@ -573,6 +585,58 @@ def _get_to_kwargs(self, *args, **kwargs):
573
585
}
574
586
return kwargs
575
587
588
+ def __tensor_flatten__ (self ):
589
+ raise NotImplementedError ("Subclasses must implement __tensor_flatten__" )
590
+
591
+ @classmethod
592
+ def __tensor_unflatten__ (
593
+ cls , tensor_data_dict , tensor_attributes , outer_size , outer_stride
594
+ ):
595
+ raise NotImplementedError ("Subclasses must implement __tensor_unflatten__" )
596
+
597
+ def __repr__ (self ):
598
+ raise NotImplementedError ("Subclasses must implement __repr__" )
599
+
600
+ def _apply_fn_to_data (self , fn ):
601
+ raise NotImplementedError ("Subclasses must implement _apply_fn_to_data" )
602
+
603
+ def get_layout (self ):
604
+ if not hasattr (self , "_layout" ):
605
+ return None
606
+ return self ._layout
607
+
608
+
609
+ implements = TorchAOBaseTensor .implements
610
+ aten = torch .ops .aten
611
+
612
+
613
+ @implements (aten .detach .default )
614
+ def _ (func , types , args , kwargs ):
615
+ return return_and_correct_aliasing (
616
+ func , args , kwargs , args [0 ]._apply_fn_to_data (torch .detach )
617
+ )
618
+
619
+
620
+ @implements (aten .clone .default )
621
+ def _ (func , types , args , kwargs ):
622
+ return return_and_correct_aliasing (
623
+ func , args , kwargs , args [0 ]._apply_fn_to_data (torch .clone )
624
+ )
625
+
626
+
627
+ @implements (aten .t .default )
628
+ def _ (func , types , args , kwargs ):
629
+ return return_and_correct_aliasing (
630
+ func , args , kwargs , args [0 ]._apply_fn_to_data (torch .t )
631
+ )
632
+
633
+
634
+ @implements (aten .slice .Tensor )
635
+ def _ (func , types , args , kwargs ):
636
+ return return_and_correct_aliasing (
637
+ func , args , kwargs , args [0 ]._apply_fn_to_data (torch .slice )
638
+ )
639
+
576
640
577
641
def fill_defaults (args , n , defaults_tail ):
578
642
"""
@@ -599,7 +663,7 @@ def fill_defaults(args, n, defaults_tail):
599
663
return r
600
664
601
665
602
- ## Deprecated, will be deleted in the future
666
+ # Deprecated, will be deleted in the future
603
667
def _torch_version_at_least (min_version ):
604
668
return is_fbcode () or version ("torch" ) >= min_version
605
669
0 commit comments