-
Notifications
You must be signed in to change notification settings - Fork 317
Make SmoothQuant more General #2728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
c482371
e16edc2
5ec0dcf
2475ad1
ccb7b84
ba89d03
a6df6af
0fc6539
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,26 @@ | ||
# SmothQuant quantization | ||
# SmoothQuant quantization | ||
This is a native PyTorch implementation of the algorithm described in [this paper](https://arxiv.org/abs/2211.10438). | ||
|
||
In this implementation, weights are smoothed (equalized) and quantized to int8 during quantization. Activations are smoothed and quantized to int8 at runtime. Quantization is done either dynamically or statically. If activations are dynamically quantized, qparams (i.e., scales) are found at runtime while qparams are found during quantization for static quantization. For dynamic quantization, activations are quantized per token. And for static quantization, activations are quantized per tensor. Generally, dynamic quantization produces better accuracy while static quantization has better latency. In both cases, weights and activations are symmetrically quantized. | ||
|
||
## Quick start | ||
Run the example code with | ||
```bash | ||
python example.py -m MODLE_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> | ||
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> | ||
# An example | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you need to update the API here as well? since we now take a config instead of just static/dynamic quant mode flag There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes now we can update here because there is no need for |
||
python example.py -m meta-llama/Llama-2-7b-hf --device=cuda --quant-mode=dynamic | ||
``` | ||
To use the `torch.compile` for speedup, add `--compile`. You may want to export `TORCHINDUCTOR_FREEZING=1` for even better performance. | ||
```bash | ||
TORCHINDUCTOR_FREEZING=1 python example.py -m MODLE_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --compile | ||
TORCHINDUCTOR_FREEZING=1 python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --compile | ||
``` | ||
To save a quantized model for reuse, specify `--model-save-path` | ||
```bash | ||
python example.py -m MODLE_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-save-path ./quantized_model.pt | ||
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-save-path ./quantized_model.pt | ||
``` | ||
And load it by `--model-load-path` | ||
```bash | ||
python example.py -m MODLE_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-load-path ./quantized_model.pt | ||
python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static> --model-load-path ./quantized_model.pt | ||
``` | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,13 @@ | ||
from .api import ( | ||
SmoothQuantConfig, | ||
insert_smooth_quant_observer_, | ||
load_smooth_quant_recipe, | ||
save_smooth_quant_recipe, | ||
from .api import SmoothQuantConfig | ||
from .core import ( | ||
SmoothQuantObservedLinear, | ||
SmoothQuantObserver, | ||
SmoothQuantStep, | ||
) | ||
from .core import SmoothQuantObservedLinear | ||
|
||
__all__ = [ | ||
"insert_smooth_quant_observer_", | ||
"load_smooth_quant_recipe", | ||
"save_smooth_quant_recipe", | ||
"SmoothQuantConfig", | ||
"SmoothQuantStep", | ||
"SmoothQuantObserver", | ||
"SmoothQuantObservedLinear", | ||
] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it you want to really make it general, the
quant_mode
has to be changed to base_config, and we'll do a general quantization like thisao/torchao/prototype/awq/api.py
Lines 104 to 106 in 2eae09b
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I missed it. Directly using the Quantization API is a choice.