diff --git a/torchchat/generate.py b/torchchat/generate.py index 397f9e801..f4634899b 100644 --- a/torchchat/generate.py +++ b/torchchat/generate.py @@ -1160,17 +1160,9 @@ def callback(x, *, done_generating=False): t - aggregate_metrics.get("time_to_first_token", 0) ) - if jit_compile: - print( - f"just-in-time compilation time (incl run time): {compilation_time:.2} seconds" - ) - aggregate_metrics["tokens_per_sec_jit_compile"] = tokens_sec - # Don't continue here.... because we need to report and reset - # continue - else: - aggregate_metrics["tokens_per_sec"].append(tokens_sec) - aggregate_metrics["first_token_per_sec"].append(first_token_sec) - aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) + aggregate_metrics["tokens_per_sec"].append(tokens_sec) + aggregate_metrics["first_token_per_sec"].append(first_token_sec) + aggregate_metrics["next_tokens_per_sec"].append(next_tokens_sec) logging.info( f"\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\ diff --git a/torchchat/utils/quantize.py b/torchchat/utils/quantize.py index 31c639dfd..fb47f9a19 100644 --- a/torchchat/utils/quantize.py +++ b/torchchat/utils/quantize.py @@ -37,8 +37,9 @@ from torchao.quantization.quant_api import ( int4_weight_only, Int4WeightOnlyQuantizer, - Int8DynActInt4WeightQuantizer, quantize_, + int8_dynamic_activation_int4_weight, + Int8DynActInt4WeightQuantizer, ) from torchao.utils import unwrap_tensor_subclass from torchchat.utils.build_utils import ( @@ -63,10 +64,10 @@ def get_named_parameters(func: Callable) -> List[str]: # Get the signature of the function signature = inspect.signature(func) - + # Extract the parameters from the signature parameters = signature.parameters - + # Filter and return named parameters named_params = [ name for name, param in parameters.items() @@ -80,8 +81,8 @@ def validate_args(named_params: List[str], q_kwargs: Dict[str, Any], quantizer: print(f"Specification for quantizer {quantizer} has extraneous key {key}. Ignoring.") del q_kwargs[key] return q_kwargs - - + + ######################################################################### ### torchchat quantization API ### @@ -116,15 +117,21 @@ def quantize_model( if not support_tensor_subclass: unwrap_tensor_subclass(model) continue - + + # if quantizer == "linear:a8w4dq": + # quantize_(model, int8_dynamic_activation_int4_weight(group_size=q_kwargs["groupsize"])) + # if not support_tensor_subclass: + # unwrap_tensor_subclass(model) + # continue + if quantizer in ["linear:a8wxdq", "embedding:wx"]: # These quantizers require float32 input weights. Note that after quantization, # the weights will no longer be float32, but lowbit integers if get_precision() != torch.float32: print(f"Quantizer {quantizer} requires float32 inputs, but received {get_precision()}. Changing dtype to float32. Note that after quantization, the weights will be lowbit integers, not float32.") set_precision(torch.float32) - - # We set global precision from quantize options if it is specified at cli.py:485 + + # We set global precision from quantize options if it is specified at cli.py:485 # so the precision returned by get_precision() is always the authoritative precision/dtype in torchchat precision = get_precision() @@ -936,7 +943,7 @@ class ErrorHandler(QuantHandler): def __init__(self, model: Optional[nn.Module]=None, device="cpu", precision=None): global torchao_experimental_load_error raise Exception(f"Note: Failed to load torchao experimental quantizer with error: {torchao_experimental_load_error}") - + torchao_experimental_load_error = e quantizer_class_dict["linear:a8wxdq"] = ErrorHandler quantizer_class_dict["embedding:wx"] = ErrorHandler