Skip to content

Commit f872136

Browse files
committed
Move boilerplate code to TensorAO base class
1 parent 9ecdb3b commit f872136

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

torchao/utils.py

+29
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,35 @@ def _get_to_kwargs(self, *args, **kwargs):
573573
}
574574
return kwargs
575575

576+
def __tensor_flatten__(self):
577+
return ["tensor_impl"], [
578+
self.block_size,
579+
self.shape,
580+
self.quant_min,
581+
self.quant_max,
582+
self.zero_point_domain,
583+
self.dtype,
584+
]
585+
586+
@classmethod
587+
def __tensor_unflatten__(
588+
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
589+
):
590+
tensor_impl = tensor_data_dict["tensor_impl"]
591+
block_size, shape, quant_min, quant_max, zero_point_domain, dtype = (
592+
tensor_attributes
593+
)
594+
return cls(
595+
tensor_impl,
596+
block_size,
597+
shape if outer_size is None else outer_size,
598+
quant_min,
599+
quant_max,
600+
zero_point_domain,
601+
dtype=dtype,
602+
strides=outer_stride,
603+
)
604+
576605

577606
def fill_defaults(args, n, defaults_tail):
578607
"""

0 commit comments

Comments
 (0)