Commit 896f61b

Revert D73201409 (#2105)
Summary: This diff reverts D73201409. Depends on D73414112 (no further context, such as a Sandcastle job, Task, or SEV, was provided). Reviewed By: navsud. Differential Revision: D73414124
1 parent 657a00f commit 896f61b

2 files changed: +39 −64


torchao/quantization/qat/embedding.py

Lines changed: 17 additions & 30 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional, Tuple
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
@@ -196,40 +196,15 @@ def convert(
         """
         self._convert_helper(model)
         return model
-
-    @staticmethod
-    def quantize_weights(
-        weight: torch.Tensor,
-        bit_width: int,
-        group_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Helper function to quantize weights
-        """
-        (qmin, qmax) = _get_qmin_qmax(bit_width)
-        (s, zp) = get_group_qparams_symmetric(
-            weight, bit_width, group_size
-        )
-        from torchao._executorch_ops import (
-            _quantized_decomposed_quantize_per_channel_group_wrapper,
-        )
-        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-            weight,
-            s,
-            zp,
-            qmin,
-            qmax,
-            torch.int8,
-            group_size,
-        )
-        return (q_weight, s, zp)
-
 
     def _convert_helper(self, module: torch.nn.Module):
         """
         Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
         modules with `Int4WeightOnlyEmbedding`
         """
+        from torchao._executorch_ops import (
+            _quantized_decomposed_quantize_per_channel_group_wrapper,
+        )
 
         for name, child in module.named_children():
             if isinstance(child, Int4WeightOnlyQATEmbedding):
@@ -255,8 +230,20 @@ def _convert_helper(self, module: torch.nn.Module):
                 )
                 setattr(module, name, quantized_embedding)
 
-                q_weight, s, zp = self.quantize_weights(child.weight, self.bit_width, group_size)
                 # Load weights and qparams into quantized embedding
+                (qmin, qmax) = _get_qmin_qmax(self.bit_width)
+                (s, zp) = get_group_qparams_symmetric(
+                    child.weight, self.bit_width, group_size
+                )
+                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+                    child.weight,
+                    s,
+                    zp,
+                    qmin,
+                    qmax,
+                    torch.int8,
+                    group_size,
+                )
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scale = s.to(scale_precision)
                 quantized_embedding.zero_point = zp.to(zero_point_precision)
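
The inlined block above performs grouped symmetric quantization: `_get_qmin_qmax(bit_width)` gives the signed integer range, `get_group_qparams_symmetric` computes one scale (with a zero zero-point) per group of `group_size` weights, and the ExecuTorch wrapper quantizes the weight into int8 storage. Below is a minimal pure-PyTorch sketch of that computation for illustration only; it is not torchao's implementation, and the exact scale formula and dtype handling in torchao may differ.

import torch

def grouped_symmetric_quantize(weight: torch.Tensor, bit_width: int, group_size: int):
    # Illustrative sketch only -- not torchao's implementation.
    # Signed symmetric range, e.g. bit_width=4 -> qmin=-8, qmax=7.
    qmin, qmax = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    out_features, in_features = weight.shape
    assert in_features % group_size == 0
    # One quantization group per `group_size` consecutive values in each row.
    grouped = weight.reshape(out_features, in_features // group_size, group_size)
    # Symmetric scheme: scale from the max absolute value, zero point fixed at 0.
    max_abs = grouped.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6)
    scale = max_abs / qmax
    zero_point = torch.zeros(out_features, in_features // group_size, dtype=torch.int32)
    # Quantize, clamp to the representable range, and store in int8 (as the diff does).
    q = torch.clamp(torch.round(grouped / scale), qmin, qmax).to(torch.int8)
    return q.reshape(out_features, in_features), scale.squeeze(-1), zero_point

Hypothetical usage, assuming a 2D embedding weight: `q, s, zp = grouped_symmetric_quantize(embedding.weight, bit_width=4, group_size=32)`.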

torchao/quantization/qat/linear.py

Lines changed: 22 additions & 34 deletions
@@ -4,7 +4,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, Optional, Tuple
+from typing import Any, Optional
 
 import torch
 import torch.nn.functional as F
@@ -197,36 +197,6 @@ def convert(
     ) -> torch.nn.Module:
         self._convert_qat_linear_8da4w(model)
         return model
-
-    @staticmethod
-    def quantize_weights(
-        weight: torch.Tensor,
-        group_size: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        """
-        Helper function to quantize weights
-        """
-        # Load weights and qparams into quantized linear
-        n_bit = 4
-        (qmin, qmax) = _get_qmin_qmax(n_bit)
-        (s, zp) = get_group_qparams_symmetric(
-            weight, n_bit, group_size
-        )
-        from torchao._executorch_ops import (
-            _quantized_decomposed_quantize_per_channel_group_wrapper,
-        )
-
-        q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
-            weight,
-            s,
-            zp,
-            qmin,
-            qmax,
-            torch.int8,
-            group_size,
-        )
-        return (q_weight, s, zp)
-
 
     def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
         """
@@ -245,10 +215,28 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 )
                 setattr(module, name, quantized_linear)
 
-                q_weight, scales, zeros = self.quantize_weights(child.weight, config.group_size)
+                # Load weights and qparams into quantized linear
+                n_bit = 4
+                (qmin, qmax) = _get_qmin_qmax(n_bit)
+                (s, zp) = get_group_qparams_symmetric(
+                    child.weight, n_bit, config.group_size
+                )
+                from torchao._executorch_ops import (
+                    _quantized_decomposed_quantize_per_channel_group_wrapper,
+                )
+
+                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
+                    child.weight,
+                    s,
+                    zp,
+                    qmin,
+                    qmax,
+                    torch.int8,
+                    config.group_size,
+                )
                 quantized_linear.weight = q_weight
-                quantized_linear.scales = scales
-                quantized_linear.zeros = zeros
+                quantized_linear.scales = s
+                quantized_linear.zeros = zp
                 if child.bias is not None:
                     quantized_linear.bias = child.bias
                 else:
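
As in the embedding file, the `_convert_qat_linear_8da4w` hunk walks the module tree via `named_children()`, swaps each QAT linear for its quantized counterpart with `setattr`, and fills in the quantized weight and qparams in place. A generic sketch of that recursive swap pattern follows; `FakeQuantLinear` and `QuantizedLinear` are hypothetical placeholder classes, not torchao's.

import torch
import torch.nn as nn

class FakeQuantLinear(nn.Linear):
    # Stand-in for a QAT linear that still carries float (fake-quantized) weights.
    pass

class QuantizedLinear(nn.Module):
    # Stand-in for a converted module holding int8 weights plus qparams.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.register_buffer("weight", torch.zeros(out_features, in_features, dtype=torch.int8))
        self.register_buffer("scales", torch.ones(out_features))
        self.register_buffer("zeros", torch.zeros(out_features))

def convert_helper(module: nn.Module) -> None:
    # Recursively replace FakeQuantLinear children, mirroring the loop structure
    # visible in the diff (named_children + isinstance check + setattr).
    for name, child in module.named_children():
        if isinstance(child, FakeQuantLinear):
            quantized = QuantizedLinear(child.in_features, child.out_features)
            # ...quantize child.weight here and copy weight/scales/zeros over...
            setattr(module, name, quantized)
        else:
            convert_helper(child)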
