Skip to content

Deprecating Int8DynActInt4WeightQuantizer #1332

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions torchchat/generate.py
Original file line number Diff line number Diff line change
@@ -1160,17 +1160,9 @@ def callback(x, *, done_generating=False):
t - aggregate_metrics.get("time_to_first_token", 0)
)

if jit_compile:
print(
f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds"
)
aggregate_metrics["tokens_per_sec_jit_compile"] = tokens_sec
# Don't continue here.... because we need to report and reset
# continue
else:
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
aggregate_metrics["first_token_per_sec"].append(first_token_sec)
aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)
aggregate_metrics["tokens_per_sec"].append(tokens_sec)
aggregate_metrics["first_token_per_sec"].append(first_token_sec)
aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec)

logging.info(
f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\
25 changes: 16 additions & 9 deletions torchchat/utils/quantize.py
Original file line number Diff line number Diff line change
@@ -37,8 +37,9 @@
from torchao.quantization.quant_api import (
int4_weight_only,
Int4WeightOnlyQuantizer,
Int8DynActInt4WeightQuantizer,
quantize_,
int8_dynamic_activation_int4_weight,
Int8DynActInt4WeightQuantizer,
)
from torchao.utils import unwrap_tensor_subclass
from torchchat.utils.build_utils import (
@@ -63,10 +64,10 @@
def get_named_parameters(func: Callable) -> List[str]:
# Get the signature of the function
signature = inspect.signature(func)

# Extract the parameters from the signature
parameters = signature.parameters

# Filter and return named parameters
named_params = [
name for name, param in parameters.items()
@@ -80,8 +81,8 @@ def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer:
print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.")
del q_kwargs[key]
return q_kwargs


#########################################################################
### torchchat quantization API ###

@@ -116,15 +117,21 @@ def quantize_model(
if not support_tensor_subclass:
unwrap_tensor_subclass(model)
continue


# if quantizer == "linear:a8w4dq":
# quantize_(model, int8_dynamic_activation_int4_weight(group_size=q_kwargs["groupsize"]))
# if not support_tensor_subclass:
# unwrap_tensor_subclass(model)
# continue

if quantizer in ["linear:a8wxdq", "embedding:wx"]:
# These quantizers require float32 input weights. Note that after quantization,
# the weights will no longer be float32 but low-bit integers.
if get_precision() != torch.float32:
print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.")
set_precision(torch.float32)
# We set global precision from quantize options if it is specified at cli.py:485

# We set global precision from quantize options if it is specified at cli.py:485
# so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat
precision = get_precision()

@@ -936,7 +943,7 @@ class ErrorHandler(QuantHandler):
def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None):
# Never builds a usable handler: construction always raises, reporting the
# error held in the module-level `torchao_experimental_load_error`.
# The arguments (model, device, precision) are accepted but not used here.
global torchao_experimental_load_error
raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}")

torchao_experimental_load_error = e
quantizer_class_dict["linear:a8wxdq"] = ErrorHandler
quantizer_class_dict["embedding:wx"] = ErrorHandler