"""Convert the weight of linear modules in the model with `apply_tensor_subclass`, model is modified inplace
491
+
"""Convert the weight of linear modules in the model with `config`, model is modified inplace
492
492
493
493
Args:
494
494
model (torch.nn.Module): input model
495
-
apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]): function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor)
496
-
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `apply_tensor_subclass` on
495
+
config (Union[AOBaseWorkflowConfig, Callable[[torch.nn.Module], torch.nn.Module]]): either (1) a workflow configuration object or (2) a function that applies tensor subclass conversion to the weight of a module and return the module (e.g. convert the weight tensor of linear to affine quantized tensor). Note: (2) will be deleted in a future release.
496
+
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `config` on
497
497
the weight of the module
498
498
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to True)
499
499
device (device, optional): Device to move module to before applying `filter_fn`. This can be set to `"cuda"` to speed up quantization. The final model will be on the specified `device`.
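A minimal usage sketch of the signature documented above, including the `filter_fn` hook that the docstring example below does not exercise. The import location of `int4_weight_only` mirrors the docstring example (which targets the int4 tinygemm kernel, so a CUDA + bfloat16 setup is assumed at runtime); verify against your installed torchao version.

```python
# Sketch only: exercises the new `config` + `filter_fn` signature described
# in the Args above. Assumes quantize_ and int4_weight_only are importable
# from torchao.quantization, as in the docstring example.
import torch.nn as nn

from torchao.quantization import int4_weight_only, quantize_

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))

def only_first_linear(module: nn.Module, fqn: str) -> bool:
    # fqn is the fully qualified module name; "0" is the first child
    # of an nn.Sequential
    return isinstance(module, nn.Linear) and fqn == "0"

# int4_weight_only(...) is the workflow config; only_first_linear restricts
# which module weights it is applied to
quantize_(m, int4_weight_only(group_size=32), filter_fn=only_first_linear)
```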
@@ -505,7 +505,7 @@ def quantize_(
         import torch.nn as nn
         from torchao import quantize_
 
-        # 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
+        # quantize with some predefined `config` method that corresponds to
         # optimized execution paths or kernels (e.g. int4 tinygemm kernel)
         # also customizable with arguments
         # currently options are
@@ -518,43 +518,13 @@ def quantize_(
         m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
         quantize_(m, int4_weight_only(group_size=32))
 
-        # 2. write your own new apply_tensor_subclass
-        # You can also add your own apply_tensor_subclass by manually calling tensor subclass constructor
-        # on weight
-
-        from torchao.dtypes import to_affine_quantized_intx
-
-        # weight only uint4 asymmetric groupwise quantization
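Since this hunk deletes the "write your own apply_tensor_subclass" walkthrough, here is a rough sketch of what the config-based equivalent might look like. Everything below except the `AOBaseWorkflowConfig` name (which appears in the Args section above) is an assumption to be checked against the PR.

```python
# Hypothetical sketch of a user-defined workflow config replacing the deleted
# apply_tensor_subclass example. The import path and the existence of a
# per-config handler registry are assumptions; see
# https://github.com/pytorch/ao/pull/1595 for the actual mechanism.
from dataclasses import dataclass

from torchao.core.config import AOBaseWorkflowConfig  # path is an assumption

@dataclass
class MyUInt4GroupwiseConfig(AOBaseWorkflowConfig):
    # mirrors the group_size knob of the deleted uint4 groupwise example
    group_size: int = 32

# A handler registered for this config type (registration API not shown)
# would perform the weight conversion that apply_tensor_subclass used to do,
# e.g. via torchao.dtypes.to_affine_quantized_intx.
```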
...
+        # TODO(after discussion): flesh the BC story out more
+        # old behavior, keep to avoid breaking BC
+        warnings.warn("""Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/pull/1595 for instructions on how to pass in workflow configuration instead.""")
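For completeness, a sketch contrasting the two argument types this BC shim distinguishes; the passthrough callable is a hypothetical no-op standing in for a real tensor-subclass transform.

```python
# Sketch: both call styles accepted after this change. The Callable path
# still works but should trip the "no longer recommended" warning added above.
import torch.nn as nn

from torchao.quantization import int4_weight_only, quantize_

# (1) recommended: pass a workflow configuration object
m1 = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m1, int4_weight_only(group_size=32))

# (2) deprecated: pass a bare Callable[[nn.Module], nn.Module]; this
# passthrough is a hypothetical stand-in for a real weight transform
def passthrough(module: nn.Module) -> nn.Module:
    return module

m2 = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m2, passthrough)  # emits the deprecation warning, still runs
```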
f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"