
Commit 10de698

[lint] reformat qat files (#2090)
reformat qat files
Parent: c52dcd4

File tree

2 files changed (+11, -12 lines)


torchao/quantization/qat/embedding.py

Lines changed: 6 additions & 6 deletions
@@ -196,7 +196,7 @@ def convert(
         """
         self._convert_helper(model)
         return model
-
+
     @staticmethod
     def quantize_weights(
         weight: torch.Tensor,
@@ -207,12 +207,11 @@ def quantize_weights(
         Helper function to quantize weights
         """
         (qmin, qmax) = _get_qmin_qmax(bit_width)
-        (s, zp) = get_group_qparams_symmetric(
-            weight, bit_width, group_size
-        )
+        (s, zp) = get_group_qparams_symmetric(weight, bit_width, group_size)
         from torchao._executorch_ops import (
             _quantized_decomposed_quantize_per_channel_group_wrapper,
         )
+
         q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
             weight,
             s,
@@ -224,7 +223,6 @@ def quantize_weights(
         )
         return (q_weight, s, zp)

-
     def _convert_helper(self, module: torch.nn.Module):
         """
         Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
@@ -255,7 +253,9 @@ def _convert_helper(self, module: torch.nn.Module):
             )
             setattr(module, name, quantized_embedding)

-            q_weight, s, zp = self.quantize_weights(child.weight, self.bit_width, group_size)
+            q_weight, s, zp = self.quantize_weights(
+                child.weight, self.bit_width, group_size
+            )
             # Load weights and qparams into quantized embedding
             quantized_embedding.weight = q_weight
             quantized_embedding.scale = s.to(scale_precision)
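
For reference, the `quantize_weights` helper reformatted above performs per-group symmetric quantization: `get_group_qparams_symmetric` derives one scale (and a zero zero-point) per group of `group_size` weights, and the per-channel-group wrapper rounds each group into the `[qmin, qmax]` range returned by `_get_qmin_qmax(bit_width)`. Below is a minimal self-contained sketch of that scheme, not torchao's implementation; `quantize_per_group_symmetric` and the shapes are illustrative assumptions.

import torch

def quantize_per_group_symmetric(weight: torch.Tensor, bit_width: int, group_size: int):
    # Symmetric signed range, e.g. [-8, 7] for bit_width=4
    qmin, qmax = -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    # View each row as consecutive groups of `group_size` elements
    grouped = weight.reshape(weight.shape[0], -1, group_size)
    max_abs = grouped.abs().amax(dim=-1, keepdim=True)
    s = (max_abs / qmax).clamp(min=1e-6)  # one scale per group
    q = torch.clamp(torch.round(grouped / s), qmin, qmax).to(torch.int8)
    zp = torch.zeros(s.shape[:-1], dtype=torch.int32)  # symmetric: zero-point is 0
    return q.reshape_as(weight), s.squeeze(-1), zp

# Example: 4-bit weights in groups of 32 along the embedding dimension
w = torch.randn(8, 64)
q_w, scales, zps = quantize_per_group_symmetric(w, bit_width=4, group_size=32)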

torchao/quantization/qat/linear.py

Lines changed: 5 additions & 6 deletions
@@ -197,7 +197,7 @@ def convert(
     ) -> torch.nn.Module:
         self._convert_qat_linear_8da4w(model)
         return model
-
+
     @staticmethod
     def quantize_weights(
         weight: torch.Tensor,
@@ -209,9 +209,7 @@ def quantize_weights(
         # Load weights and qparams into quantized linear
         n_bit = 4
         (qmin, qmax) = _get_qmin_qmax(n_bit)
-        (s, zp) = get_group_qparams_symmetric(
-            weight, n_bit, group_size
-        )
+        (s, zp) = get_group_qparams_symmetric(weight, n_bit, group_size)
         from torchao._executorch_ops import (
             _quantized_decomposed_quantize_per_channel_group_wrapper,
         )
@@ -227,7 +225,6 @@ def quantize_weights(
         )
         return (q_weight, s, zp)

-
     def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
         """
         Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
@@ -245,7 +242,9 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
             )
             setattr(module, name, quantized_linear)

-            q_weight, scales, zeros = self.quantize_weights(child.weight, config.group_size)
+            q_weight, scales, zeros = self.quantize_weights(
+                child.weight, config.group_size
+            )
             quantized_linear.weight = q_weight
             quantized_linear.scales = scales
             quantized_linear.zeros = zeros
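
For context, these helpers back the quantizer's two-step API: `prepare` swaps `torch.nn.Linear` for `Int8DynActInt4WeightQATLinear` (fake quantization during training), and `convert`, shown in the first hunk above, calls `_convert_qat_linear_8da4w` to produce real `Int8DynActInt4WeightLinear` modules. A hedged usage sketch, assuming the documented torchao QAT flow; the toy model and `groupsize` value are illustrative.

import torch
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

# Toy model for illustration; any module tree containing nn.Linear works
model = torch.nn.Sequential(torch.nn.Linear(64, 64))

quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
model = quantizer.prepare(model)  # swaps in Int8DynActInt4WeightQATLinear (fake quant)
# ... QAT fine-tuning loop would run here ...
model = quantizer.convert(model)  # calls _convert_qat_linear_8da4w, shown in the diff above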
