
Commit 657a00f

Revert "[lint] reformat qat files" (#2104)
Revert "[lint] reformat qat files (#2090)" This reverts commit 10de698.
1 parent b8206d7 · commit 657a00f

2 files changed: +12 -11 lines changed

torchao/quantization/qat/embedding.py

Lines changed: 6 additions & 6 deletions
@@ -196,7 +196,7 @@ def convert(
         """
         self._convert_helper(model)
         return model
-
+
     @staticmethod
     def quantize_weights(
         weight: torch.Tensor,
@@ -207,11 +207,12 @@ def quantize_weights(
         Helper function to quantize weights
         """
         (qmin, qmax) = _get_qmin_qmax(bit_width)
-        (s, zp) = get_group_qparams_symmetric(weight, bit_width, group_size)
+        (s, zp) = get_group_qparams_symmetric(
+            weight, bit_width, group_size
+        )
         from torchao._executorch_ops import (
             _quantized_decomposed_quantize_per_channel_group_wrapper,
         )
-
         q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
             weight,
             s,
@@ -223,6 +224,7 @@ def quantize_weights(
         )
         return (q_weight, s, zp)
 
+
     def _convert_helper(self, module: torch.nn.Module):
         """
         Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
@@ -253,9 +255,7 @@ def _convert_helper(self, module: torch.nn.Module):
                 )
                 setattr(module, name, quantized_embedding)
 
-                q_weight, s, zp = self.quantize_weights(
-                    child.weight, self.bit_width, group_size
-                )
+                q_weight, s, zp = self.quantize_weights(child.weight, self.bit_width, group_size)
                 # Load weights and qparams into quantized embedding
                 quantized_embedding.weight = q_weight
                 quantized_embedding.scale = s.to(scale_precision)
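
For reference, a minimal sketch (not part of this commit) of calling the static helper shown in the diff above directly. Only the quantize_weights(weight, bit_width, group_size) signature comes from the diff; the quantizer class name Int4WeightOnlyEmbeddingQATQuantizer, the example shapes, and the group size are assumptions for illustration.

import torch

from torchao.quantization.qat.embedding import Int4WeightOnlyEmbeddingQATQuantizer

# Example embedding weight; shape and group size are illustrative only.
weight = torch.randn(128, 256)
bit_width, group_size = 4, 32

# Signature as shown in the diff above (a @staticmethod, so no instance needed);
# assumes the helper lives on the embedding QAT quantizer class.
q_weight, s, zp = Int4WeightOnlyEmbeddingQATQuantizer.quantize_weights(
    weight, bit_width, group_size
)
print(q_weight.shape, s.shape, zp.shape)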

torchao/quantization/qat/linear.py

Lines changed: 6 additions & 5 deletions
@@ -197,7 +197,7 @@ def convert(
     ) -> torch.nn.Module:
         self._convert_qat_linear_8da4w(model)
         return model
-
+
     @staticmethod
     def quantize_weights(
         weight: torch.Tensor,
@@ -209,7 +209,9 @@ def quantize_weights(
         # Load weights and qparams into quantized linear
         n_bit = 4
         (qmin, qmax) = _get_qmin_qmax(n_bit)
-        (s, zp) = get_group_qparams_symmetric(weight, n_bit, group_size)
+        (s, zp) = get_group_qparams_symmetric(
+            weight, n_bit, group_size
+        )
         from torchao._executorch_ops import (
             _quantized_decomposed_quantize_per_channel_group_wrapper,
         )
@@ -225,6 +227,7 @@ def quantize_weights(
         )
         return (q_weight, s, zp)
 
+
     def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
         """
         Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
@@ -242,9 +245,7 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
                 )
                 setattr(module, name, quantized_linear)
 
-                q_weight, scales, zeros = self.quantize_weights(
-                    child.weight, config.group_size
-                )
+                q_weight, scales, zeros = self.quantize_weights(child.weight, config.group_size)
                 quantized_linear.weight = q_weight
                 quantized_linear.scales = scales
                 quantized_linear.zeros = zeros
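
Similarly, a hedged sketch for the linear path. The diff confirms the quantize_weights(weight, group_size) signature, with n_bit fixed to 4 inside the helper; the class name Int8DynActInt4WeightQATQuantizer and the example shapes are assumptions for illustration.

import torch

from torchao.quantization.qat.linear import Int8DynActInt4WeightQATQuantizer

# Example linear weight; shape and group size are illustrative only.
weight = torch.randn(1024, 1024)
group_size = 256

# Signature as shown in the diff above; n_bit = 4 is hard-coded in the helper.
q_weight, scales, zeros = Int8DynActInt4WeightQATQuantizer.quantize_weights(
    weight, group_size
)
print(q_weight.shape, scales.shape, zeros.shape)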
