forked from pytorch/ao
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathobserver.py
248 lines (200 loc) · 8.53 KB
/
observer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import torch
from .quant_primitives import (
_get_reduction_params,
choose_qparams_affine_with_min_max,
MappingType,
ZeroPointDomain,
)
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Tuple, Optional, Any
from functools import partial
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
@dataclass(frozen=True)
class GranularityType:
    """
    Common parent of every quantization granularity marker.

    Concrete subclasses (per-tensor, per-axis, ...) describe how a tensor is
    partitioned when its quantization parameters are computed. Instances are
    immutable (frozen dataclass), so they compare by value and are hashable.
    """
@dataclass(frozen=True)
class PerTensor(GranularityType):
    """
    Represents per-tensor granularity in quantization.

    This granularity type calculates the quantization parameters
    based on the entire tensor, producing a single set of parameters.
    """

    pass
@dataclass(frozen=True)
class PerAxis(GranularityType):
    """
    Represents per-axis granularity in quantization.

    This granularity type calculates different quantization parameters
    for each index along a specified axis of the tensor.

    For example, if the input tensor has shape [8, 16] and axis=0, then
    quantization parameters are calculated for each row of the tensor,
    giving a total of 8 quantization parameters.

    Attributes:
        axis (int): The axis that is preserved. One set of quantization
            parameters is computed per index along this axis, reducing over
            all of the other dimensions.
    """

    axis: int
@dataclass(frozen=True)
class PerGroup(GranularityType):
    """
    Represents per-channel group granularity in quantization.

    This granularity type calculates different quantization parameters
    for each group of <group_size> elements.

    For example, if the input tensor has shape [8, 16] and the group size is 4,
    then the input tensor is reshaped to [32, 4] (8 * 16 / 4 = 32 groups) and
    quantization parameters are calculated for each group of 4 elements,
    giving a total of 32 quantization parameters.

    Attributes:
        group_size (int): The size of each quantization group
    """

    group_size: int
@dataclass(frozen=True)
class PerRow(GranularityType):
    """
    Represents row-wise granularity in quantization.

    This is a special case of per-axis quantization and is unique to Float8 matmuls
    where the input is quantized with a block_size of (1, ..., input.shape[-1]). And the weight
    is quantized with a block_size of (1, weight.shape[1]).

    Declared as a frozen dataclass like every other granularity type, so that
    instances compare equal by value (``PerRow() == PerRow()``) and hash consistently.
    """

    pass
# borrowed from torch.ao.quantization.observer
class _PartialWrapper:
    """Callable holder for a pre-bound constructor (a ``functools.partial``).

    Calling the wrapper forwards every argument to the held callable, and
    ``with_args`` binds further arguments by wrapping this wrapper again.
    """

    def __init__(self, p):
        # The held callable; `_with_args` always passes a functools.partial here.
        self.p = p

    def __call__(self, *args, **keywords):
        target = self.p
        return target(*args, **keywords)

    def __repr__(self):
        held = self.p
        return held.__repr__()

    def with_args(self, *args, **kwargs):
        """Return a new wrapper with ``args``/``kwargs`` bound on top of this one."""
        return _with_args(self, *args, **kwargs)
def _with_args(cls_or_self, *args, **kwargs):
    r"""Bind constructor arguments now and defer construction until called.

    Useful for building factories that create many distinct instances sharing
    the same constructor arguments. The result exposes ``with_args`` itself,
    so bindings can be chained.

    Example::

        >>> # xdoctest: +SKIP("Undefined vars")
        >>> Foo.with_args = classmethod(_with_args)
        >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
        >>> foo_instance1 = foo_builder()
        >>> foo_instance2 = foo_builder()
        >>> id(foo_instance1) == id(foo_instance2)
        False
    """
    bound = partial(cls_or_self, *args, **kwargs)
    return _PartialWrapper(bound)
def get_block_size(
    input_shape: Tuple[int, ...], granularity_type: GranularityType
) -> Tuple[int, ...]:
    """Get the block size based on the input shape and granularity type.

    Each entry of the result is the extent of one quantization block along the
    matching dimension of ``input_shape``; one set of quantization parameters
    is produced per block.

    Args:
        input_shape: The input tensor shape possibly more than 2 dimensions
        granularity_type: The granularity type of the quantization

    Returns:
        A tuple with one entry per dimension of ``input_shape``.

    Raises:
        ValueError: if ``granularity_type`` is not a supported granularity, or
            if a ``PerGroup`` group size is not positive or does not evenly
            divide the last dimension of ``input_shape``.
    """
    if isinstance(granularity_type, PerTensor):
        # A single block that spans the whole tensor.
        return input_shape
    elif isinstance(granularity_type, PerAxis):
        # Size 1 along `axis` keeps that axis; every other dim is reduced.
        block_size = list(input_shape)
        block_size[granularity_type.axis] = 1
        return tuple(block_size)
    elif isinstance(granularity_type, PerRow):
        return (1,) * (len(input_shape) - 1) + (input_shape[-1],)
    elif isinstance(granularity_type, PerGroup):
        # Contiguous groups of `group_size` elements along the last dimension,
        # e.g. shape (8, 16) with group_size 4 -> 32 groups of 4 elements.
        group_size = granularity_type.group_size
        if (
            len(input_shape) == 0
            or group_size <= 0
            or input_shape[-1] % group_size != 0
        ):
            raise ValueError(
                f"PerGroup group_size={group_size} must be positive and evenly "
                f"divide the last dimension of input shape {tuple(input_shape)}"
            )
        return (1,) * (len(input_shape) - 1) + (group_size,)
    raise ValueError(f"Unsupported GranularityType: {granularity_type}")
# Abstract base class built directly from ABCMeta so that subclasses can declare
# @abstractmethod members; with no explicit bases, ABCMeta uses (object,).
ABC: Any = ABCMeta("ABC", (), {})
class AffineQuantizedObserverBase(ABC, torch.nn.Module):
    """Observer module for affine quantization (https://github.com/pytorch/ao/tree/main/torchao/quantization#affine-quantization)

    Subclasses implement ``forward`` (record statistics from each input and
    return the input unchanged) and ``calculate_qparams`` (turn the recorded
    statistics into scale and zero_point tensors).

    Args:
        `granularity_type`: The granularity of the quantization; must not be None.
            The block size used for reduction is derived from it with
            `get_block_size`, which defines the supported granularity types.
        other args: please see `:class:torchao.dtypes.AffineQuantizedTensor`
    """

    # Lets callers pre-bind constructor arguments, e.g.
    # `factory = SomeObserver.with_args(...); observer = factory()`.
    with_args = classmethod(_with_args)

    def __init__(
        self,
        mapping_type: MappingType,
        target_dtype: torch.dtype,
        granularity_type: GranularityType,
        quant_min: Optional[int] = None,
        quant_max: Optional[int] = None,
        eps: Optional[float] = None,
        scale_dtype: Optional[torch.dtype] = None,
        zero_point_dtype: Optional[torch.dtype] = None,
        preserve_zero: bool = True,
        zero_point_domain: Optional[ZeroPointDomain] = ZeroPointDomain.INT,
    ):
        super().__init__()
        # There is no block_size fallback: the block size is always derived
        # from granularity_type, so it is required.
        assert granularity_type is not None, "granularity_type is None"
        self.mapping_type = mapping_type
        self.target_dtype = target_dtype
        self.granularity_type = granularity_type
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.eps = eps
        self.scale_dtype = scale_dtype
        self.zero_point_dtype = zero_point_dtype
        self.preserve_zero = preserve_zero
        self.zero_point_domain = zero_point_domain

    @abstractmethod
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """forward function should take the input tensor
        and updates internal stats and return the original input Tensor
        """
        pass

    @abstractmethod
    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Calculate quantization parameter based on the stats attached to the observer module
        and returns a tuple of scale and zero_point Tensor
        """
        pass
class AffineQuantizedMinMaxObserver(AffineQuantizedObserverBase):
    """Observer that tracks running per-block min/max statistics.

    Each ``forward`` call reduces the input to per-block minima and maxima
    (blocks derived from ``granularity_type`` via ``get_block_size``) and merges
    them element-wise into ``self.min_val`` / ``self.max_val``.
    ``calculate_qparams`` then derives scale and zero_point from those stats.
    """

    def forward(self, input: torch.Tensor):
        """Update the running min/max with ``input`` and return ``input`` unchanged.

        Empty tensors are ignored. Statistics are computed from a detached
        tensor, so no autograd history is recorded.

        Raises:
            ValueError: propagated from ``get_block_size`` for an unsupported
                granularity type.
            AssertionError: if the reduced statistics' shape differs from the
                shape recorded by earlier calls.
        """
        if input.numel() == 0:
            return input

        input_detached = input.detach()
        assert self.granularity_type is not None, "granularity_type is None"
        block_size = get_block_size(input_detached.shape, self.granularity_type)

        shape_for_reduction, reduction_dims = _get_reduction_params(
            block_size, input_detached.size()
        )
        # `reshape` rather than `view`: identical result when a view is possible,
        # but non-contiguous inputs whose split block layout is not expressible as
        # a view are copied instead of raising a RuntimeError.
        input_detached = input_detached.reshape(shape_for_reduction)
        min_val = torch.amin(input_detached, dim=reduction_dims, keepdim=False)
        max_val = torch.amax(input_detached, dim=reduction_dims, keepdim=False)
        if not hasattr(self, "min_val") or not hasattr(self, "max_val"):
            # First observation: adopt the reduced stats as the running state.
            self.min_val = min_val
            self.max_val = max_val
        else:
            assert self.min_val.shape == min_val.shape, f"Can't update existing min_val - shape mismatch, self.min_val:{self.min_val.shape} != min_val:{min_val.shape}"
            assert self.max_val.shape == max_val.shape, f"Can't update existing max_val - shape mismatch, self.max_val {self.max_val.shape} != max_val:{max_val.shape}"
            # Merge element-wise and write back in place into the stored tensors.
            self.min_val.copy_(torch.min(self.min_val, min_val))
            self.max_val.copy_(torch.max(self.max_val, max_val))
        # returning original input
        return input

    def calculate_qparams(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(scale, zero_point)`` computed from the recorded min/max.

        Raises:
            AssertionError: if ``forward`` has not recorded any statistics yet.
        """
        assert (
            hasattr(self, "min_val") and hasattr(self, "max_val")
        ), "Expecting the observer has min_val and max_val, please run the observer before calling calculate_qparams"
        return choose_qparams_affine_with_min_max(
            self.min_val,
            self.max_val,
            self.mapping_type,
            [],  # BlockSize is not needed because the min/max are already reduced
            self.target_dtype,
            self.quant_min,
            self.quant_max,
            self.eps,
            self.scale_dtype,
            self.zero_point_dtype,
            self.preserve_zero,
            self.zero_point_domain,
        )
# Register the granularity markers with torch's weights_only allow-list so that a
# model with LinearActivationQuantizedTensor weights can be loaded with
# `weights_only=True`. Guarded by TORCH_VERSION_AT_LEAST_2_5.
if TORCH_VERSION_AT_LEAST_2_5:
    torch.serialization.add_safe_globals(
        [
            PerRow,
            PerTensor,
        ]
    )