
Commit 5b9d876

Update
[ghstack-poisoned]
1 parent: 24114ce

File tree: 3 files changed (+25, -47 lines)

3 files changed

+25
-47
lines changed

test/dtypes/test_affine_quantized.py (+6, -1)
@@ -8,13 +8,15 @@
     run_tests,
 )
 
+from torchao.core.config import AOBaseWorkflowConfig
 from torchao.dtypes import CutlassInt4PackedLayout, Int4CPULayout, SemiSparseLayout
 from torchao.quantization import (
     float8_weight_only,
     int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
+    quantize_,
 )
 from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain
 from torchao.utils import (
@@ -186,7 +188,10 @@ def test_flatten_unflatten(self, device, dtype):
         apply_quant_list = get_quantization_functions(False, True, device)
         for apply_quant in apply_quant_list:
             linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-            ql = apply_quant(linear)
+            if isinstance(apply_quant, AOBaseWorkflowConfig):
+                quantize_(linear, apply_quant)
+            else:
+                ql = apply_quant(linear)
             lp_tensor = ql.weight
             tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
             tensor_data_dict = {
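Side note on the hunk above: `quantize_` modifies the module in place, so a test helper can normalize the two call styles exercised here. The sketch below is illustrative and not part of this commit; `maybe_quantize` is a hypothetical name, while `AOBaseWorkflowConfig` and `quantize_` are the imports added in this file.

def maybe_quantize(linear, apply_quant):
    # hypothetical helper: normalize config objects and legacy callables
    # to "return the quantized module"
    if isinstance(apply_quant, AOBaseWorkflowConfig):
        quantize_(linear, apply_quant)  # in-place: linear now holds the quantized weight
        return linear
    return apply_quant(linear)  # legacy callable returns the converted module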

test/quantization/test_quant_api.py (+3, -4)
@@ -762,17 +762,16 @@ def reset_memory():
             assert param.is_cuda
         self.assertLess(memory_streaming, memory_baseline)
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int4_weight_only_numerics(self):
         """
         Simple test of e2e int4_weight_only workflow, comparing numerics
         to a bfloat16 baseline.
         """
-        # TODO(before land) skip on cpu-only
-        # TODO(before land) support other inference techniques?
-
         # set up inputs
         x = torch.randn(128, 128, device="cuda", dtype=torch.bfloat16)
-        # TODO: model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
+        # TODO(future): model in float32 leads to error: https://gist.github.com/vkuzo/63b3bcd7818393021a6e3fb4ccf3c469
         # is that expected?
         m_ref = torch.nn.Sequential(torch.nn.Linear(128, 128)).cuda().bfloat16()
         m_int4_wo = copy.deepcopy(m_ref)
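Side note: a sketch of how a numerics test like this typically finishes. This continuation is an assumption, not part of the commit: quantize the copied model with the int4 weight-only config, run both models on the same input, and require a reasonable signal-to-quantization-noise ratio. The 20 dB bar is illustrative.

from torchao.quantization import int4_weight_only, quantize_

quantize_(m_int4_wo, int4_weight_only())
with torch.no_grad():
    y_ref = m_ref(x)
    y_int4 = m_int4_wo(x)
# SQNR (dB) between the bfloat16 baseline and the int4 weight-only output
err = (y_ref - y_int4).float()
sqnr = 10 * torch.log10(y_ref.float().pow(2).mean() / err.pow(2).mean())
assert sqnr.item() > 20.0  # illustrative threshold, not from the test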

torchao/quantization/quant_api.py (+16, -42)
@@ -262,7 +262,8 @@ def _replace_with_custom_fn_if_matches_filter(
         model = replacement_fn(model, *extra_args)
         return model
     else:
-        for name, child in model.named_children():
+        named_children_list = list(model.named_children())
+        for name, child in named_children_list:
             new_child = _replace_with_custom_fn_if_matches_filter(
                 child,
                 replacement_fn,
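Side note on the hunk above: a simplified sketch (argument plumbing and device handling omitted, names are illustrative) of why the children are snapshotted into a list: `replacement_fn` may swap a child via `setattr` while the tree is being walked, and iterating over a copied list keeps that mutation away from the live iterator.

import torch.nn as nn

def replace_matching_children(model, replacement_fn, filter_fn, cur_fqn=""):
    # simplified stand-in for _replace_with_custom_fn_if_matches_filter
    if filter_fn(model, cur_fqn[:-1]):
        return replacement_fn(model)
    # snapshot first, mirroring `named_children_list` in the diff above
    for name, child in list(model.named_children()):
        new_child = replace_matching_children(
            child, replacement_fn, filter_fn, f"{cur_fqn}{name}."
        )
        if new_child is not child:
            setattr(model, name, new_child)
    return model

# usage: replace every nn.Linear in a toy model with nn.Identity
m = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
replace_matching_children(
    m,
    replacement_fn=lambda mod: nn.Identity(),
    filter_fn=lambda mod, fqn: isinstance(mod, nn.Linear),
)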
@@ -480,20 +481,19 @@ def insert_subclass(lin):
 
 def quantize_(
     model: torch.nn.Module,
-    # apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
-    apply_tensor_subclass: Union[
-        Callable[[torch.nn.Module], torch.nn.Module], AOBaseWorkflowConfig
+    config: Union[
+        AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]
     ],
     filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
     set_inductor_config: bool = True,
    device: Optional[torch.types.Device] = None,
 ):
-    """Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
+    """Convert the weight of linear modules in the model with `config`, model is modified inplace
 
     Args:
         model (torch.nn.Module): input model
-        apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor)
-        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
+        config (Union[AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]]): either (1) a workflow configuration object or (2) a function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor). Note: (2) will be deleted in a future release.
+        filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `config` on
             the weight of the module
         set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
         device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`.
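Side note: under the new signature, the first positional argument may be any `AOBaseWorkflowConfig` subclass that has a registered handler. A hypothetical config is sketched below (class and field names are illustrative, not from this commit):

from dataclasses import dataclass

from torchao.core.config import AOBaseWorkflowConfig

@dataclass
class MyWeightOnlyConfig(AOBaseWorkflowConfig):
    # illustrative field; real workflow configs define their own parameters
    group_size: int = 128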
@@ -505,7 +505,7 @@ def quantize_(
         import torch.nn as nn
         from torchao import quantize_
 
-        # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
+        # quantize with some predefined `config` method that corresponds to
         # optimized execution paths or kernels (e.g. int4 tinygemm kernel)
         # also customizable with arguments
         # currently options are
@@ -518,43 +518,13 @@ def quantize_(
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         quantize_(m, int4_weight_only(group_size=32))
 
-        # 2. write your own new apply_tensor_subclass
-        # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
-        # on weight
-
-        from torchao.dtypes import to_affine_quantized_intx
-
-        # weight only uint4 asymmetric groupwise quantization
-        groupsize = 32
-        apply_weight_quant = lambda x: to_affine_quantized_intx(
-            x, "asymmetric", (1, groupsize), torch.int32, 0, 15, 1e-6,
-            zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain="float")
-
-        def apply_weight_quant_to_linear(linear):
-            linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False)
-            return linear
-
-        # apply to modules under block0 submodule
-        def filter_fn(module: nn.Module, fqn: str) -> bool:
-            return isinstance(module, nn.Linear)
-
-        m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
-        quantize_(m, apply_weight_quant_to_linear, filter_fn)
-
     """
     if set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
 
-    if isinstance(apply_tensor_subclass, AOBaseWorkflowConfig):
-        # new behavior
-
-        # make the variable name make sense
-        config = apply_tensor_subclass
+    if isinstance(config, AOBaseWorkflowConfig):
         handler = _QUANTIZE_CONFIG_HANDLER[type(config)]
-
         # for each linear in the model, apply the transform if filtering passes
-        # key difference from old is that `config_with_transform` is easily
-        # inspectable
         _replace_with_custom_fn_if_matches_filter(
             model,
             handler,
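Side note on the hunk above: a minimal, self-contained sketch of the dispatch that `handler = _QUANTIZE_CONFIG_HANDLER[type(config)]` relies on. Everything except the registry name is an assumption; the real registration mechanism in torchao may differ.

_QUANTIZE_CONFIG_HANDLER = {}  # config class -> module-level transform

def _register_handler(config_cls):
    # hypothetical decorator used only for this sketch
    def decorator(handler):
        _QUANTIZE_CONFIG_HANDLER[config_cls] = handler
        return handler
    return decorator

class DemoConfig:
    # stand-in config class for the sketch
    pass

@_register_handler(DemoConfig)
def _demo_transform(module, config):
    # a handler takes one module plus its config and returns the transformed module
    return module

config = DemoConfig()
handler = _QUANTIZE_CONFIG_HANDLER[type(config)]  # the lookup quantize_ performs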
@@ -564,8 +534,12 @@ def filter_fn(module: nn.Module, fqn: str) -> bool:
         )
 
     else:
-        # old behavior, for now keep for BC purposes
-        # TODO(after discussion): flesh the BC story out more
+        # old behavior, keep to avoid breaking BC
+        warnings.warn("""Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead.""")
+
+        # make the variable name make sense
+        apply_tensor_subclass = config
+
         _replace_with_custom_fn_if_matches_filter(
             model,
             apply_tensor_subclass,
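Side note: a usage sketch of the two call styles after this change, assuming `int4_weight_only(group_size=32)` returns a workflow config as the docstring example above suggests. The callable form still runs but now emits the deprecation warning added in this hunk.

import torch.nn as nn

from torchao.quantization import int4_weight_only, quantize_

# new path: pass a workflow config object (no warning)
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m, int4_weight_only(group_size=32))

# legacy path: a bare callable is still accepted, but warns
m2 = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m2, lambda mod: mod)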
@@ -773,7 +747,7 @@ def _int4_weight_only_transform(
         logger.info(
             f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
         )
-        return weight
+        return module
 
     mapping_type = MappingType.ASYMMETRIC
     block_size = (1, group_size)
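Side note on the fix above: the handler operates at module granularity, so every exit path has to return a module. A simplified sketch of that contract follows (hypothetical body, only the early-exit shape is taken from the diff):

def _int4_weight_only_transform_sketch(module, config):
    # simplified stand-in for _int4_weight_only_transform
    weight = module.weight
    if weight.shape[-1] % config.group_size != 0:
        # incompatible shape: return the untouched module, not the raw weight
        return module
    # ... build the int4 affine-quantized weight and assign it back to module ...
    return module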
