Add wint4afp8 & fix fp quant bugs (#288)
helloyongyang authored Jan 8, 2025
1 parent 3977a82 commit dbfd394
Showing 19 changed files with 84 additions and 31 deletions.
3 changes: 2 additions & 1 deletion configs/quantization/backend/sglang/fp8/awq_fp8.yml
@@ -26,14 +26,15 @@ eval:
     inference_per_block: False
 quant:
     method: Awq
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True
         granularity: per_channel
         use_qtorch: True
     act:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True

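This hunk (and the other FP8 backend configs below) moves `quant_type: float-quant` from the top-level `quant:` section into the `weight:` and `act:` sections. For readers unfamiliar with what these settings describe, here is a minimal, illustrative sketch of symmetric per-channel FP8 e4m3 weight fake quantization in PyTorch. It is not llmc's implementation; it only assumes a PyTorch build that provides the `torch.float8_e4m3fn` dtype.

```python
import torch

def fake_quant_weight_fp8_e4m3_per_channel(w: torch.Tensor) -> torch.Tensor:
    # granularity: per_channel -> one scale per output channel (row of the weight).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max  # 448 for e4m3
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / fp8_max
    # symmetric: True -> no zero point; round by casting to e4m3,
    # then cast back and rescale (i.e. fake quantization in high precision).
    q = (w / scale).float().to(torch.float8_e4m3fn)
    return q.to(w.dtype) * scale

w = torch.randn(4096, 4096, dtype=torch.bfloat16)
w_fq = fake_quant_weight_fp8_e4m3_per_channel(w)
```
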
3 changes: 2 additions & 1 deletion configs/quantization/backend/sglang/fp8/awq_fp8_static.yml
@@ -26,14 +26,15 @@ eval:
     inference_per_block: False
 quant:
     method: Awq
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True
         granularity: per_tensor
         use_qtorch: True
     act:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True

3 changes: 2 additions & 1 deletion configs/quantization/backend/sglang/fp8/gptq_fp8.yml
@@ -26,14 +26,15 @@ eval:
     inference_per_block: False
 quant:
     method: GPTQ
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True
         granularity: per_channel
         use_qtorch: True
     act:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True

3 changes: 2 additions & 1 deletion configs/quantization/backend/sglang/fp8/rtn_fp8.yml
@@ -17,13 +17,14 @@ eval:
     inference_per_block: False
 quant:
     method: RTN
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         bit: e4m3
         symmetric: True
         granularity: per_channel
         use_qtorch: True
     act:
+        quant_type: float-quant
         bit: e4m3
         symmetric: True
         granularity: per_token

3 changes: 2 additions & 1 deletion configs/quantization/backend/sglang/fp8/smoothquant_fp8.yml
@@ -22,14 +22,15 @@ eval:
     seq_len: 2048
 quant:
     method: SmoothQuant
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True
         granularity: per_channel
         use_qtorch: True
     act:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True

3 changes: 2 additions & 1 deletion configs/quantization/backend/vllm/fp8/awq_fp8.yml
@@ -26,14 +26,15 @@ eval:
     inference_per_block: False
 quant:
     method: Awq
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True
         granularity: per_channel
         use_qtorch: True
     act:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True

3 changes: 2 additions & 1 deletion configs/quantization/backend/vllm/fp8/awq_fp8_static.yml
@@ -26,14 +26,15 @@ eval:
     inference_per_block: False
 quant:
     method: Awq
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True
         granularity: per_tensor
         use_qtorch: True
     act:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True

3 changes: 2 additions & 1 deletion configs/quantization/backend/vllm/fp8/gptq_fp8.yml
@@ -26,14 +26,15 @@ eval:
     inference_per_block: False
 quant:
     method: GPTQ
-    quant_type: float_quant
     weight:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True
         granularity: per_channel
         use_qtorch: True
     act:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True

3 changes: 2 additions & 1 deletion configs/quantization/backend/vllm/fp8/rtn_fp8.yml
@@ -17,13 +17,14 @@ eval:
     inference_per_block: False
 quant:
     method: RTN
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         bit: e4m3
         symmetric: True
         granularity: per_channel
         use_qtorch: True
     act:
+        quant_type: float-quant
         bit: e4m3
         symmetric: True
         granularity: per_token

3 changes: 2 additions & 1 deletion configs/quantization/backend/vllm/fp8/smoothquant_fp8.yml
@@ -22,14 +22,15 @@ eval:
     seq_len: 2048
 quant:
     method: SmoothQuant
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True
         granularity: per_channel
         use_qtorch: True
     act:
+        quant_type: float-quant
         # Support ["e4m3", "e5m2"]
         bit: e4m3
         symmetric: True

3 changes: 2 additions & 1 deletion configs/quantization/methods/FP_Quant/awq_we2m1a16_g128.yml
@@ -25,14 +25,15 @@ eval:
     inference_per_block: False
 quant:
     method: Awq
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         bit: e2m1
         symmetric: False
         granularity: per_group
         group_size: 128
         use_qtorch: True
     special:
+        quant_type: float-quant
         trans: True
         # The options for "trans_version" include "v1" and "v2".
         trans_version: v2

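The `e2m1` configs quantize weights to 4-bit floats (1 sign, 2 exponent, 1 mantissa bit), whose representable magnitudes are {0, 0.5, 1, 1.5, 2, 3, 4, 6}. PyTorch has no native FP4 dtype, so a per-group scheme like the one configured above can be sketched with an explicit grid lookup. This is only an illustration (it ignores the zero-point handling implied by `symmetric: False`), not llmc's code.

```python
import torch

E2M1_GRID = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])  # e2m1 magnitudes

def fake_quant_weight_e2m1_per_group(w: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    g = w.reshape(-1, group_size)                          # granularity: per_group
    scale = g.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / E2M1_GRID[-1]
    x = (g / scale).unsqueeze(-1)                          # broadcast against the grid
    idx = (x.abs() - E2M1_GRID).abs().argmin(dim=-1)       # nearest representable magnitude
    q = E2M1_GRID[idx] * g.sign()
    return (q * scale).reshape(w.shape)

w = torch.randn(4096, 4096)
w_fq = fake_quant_weight_e2m1_per_group(w, group_size=128)
```
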
3 changes: 2 additions & 1 deletion configs/quantization/methods/FP_Quant/gptq_we2m1a16_g128.yml
@@ -26,14 +26,15 @@ eval:
     inference_per_block: False
 quant:
     method: GPTQ
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         bit: e2m1
         symmetric: True
         granularity: per_group
         group_size: 128
         use_qtorch: True
     special:
+        quant_type: float-quant
         actorder: True
         static_groups: False
         percdamp: 0.01

@@ -16,8 +16,8 @@ eval:
     inference_per_block: False
 quant:
     method: RTN
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         bit: e2m1
         symmetric: True
         granularity: per_group

3 changes: 2 additions & 1 deletion configs/quantization/methods/FP_Quant/rtn_we2m1ae2m1.yml
@@ -16,12 +16,13 @@ eval:
     inference_per_block: False
 quant:
     method: RTN
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         bit: e2m1
         symmetric: True
         granularity: per_channel
     act:
+        quant_type: float-quant
         bit: e2m1
         symmetric: True
         granularity: per_token

27 changes: 14 additions & 13 deletions configs/quantization/methods/FP_Quant/rtn_we4m3ae4m3.yml
@@ -1,27 +1,28 @@
 base:
     seed: &seed 42
 model:
-    type: model_type
-    path: model path
+    type: Llama
+    path: /mnt/nvme1/yongyang/models/llama2-7b
     torch_dtype: auto
-eval:
-    eval_pos: [pretrain, fake_quant]
-    name: wikitext2
-    download: False
-    path: eval data path
-    seq_len: 2048
-    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
-    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
-    bs: 1
-    inference_per_block: False
+# eval:
+#     eval_pos: [pretrain, fake_quant]
+#     name: wikitext2
+#     download: False
+#     path: /mnt/nvme0/yongyang/llm_datasets/llmc/eval/wikitext2
+#     seq_len: 2048
+#     # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
+#     # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
+#     bs: 1
+#     inference_per_block: False
 quant:
     method: RTN
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         bit: e4m3
         symmetric: True
         granularity: per_channel
     act:
+        quant_type: float-quant
         bit: e4m3
         symmetric: True
         granularity: per_token

3 changes: 2 additions & 1 deletion configs/quantization/methods/FP_Quant/rtn_we5m2ae5m2.yml
@@ -16,12 +16,13 @@ eval:
     inference_per_block: False
 quant:
     method: RTN
-    quant_type: float-quant
     weight:
+        quant_type: float-quant
         bit: e5m2
         symmetric: True
         granularity: per_channel
     act:
+        quant_type: float-quant
         bit: e5m2
         symmetric: True
         granularity: per_token

40 changes: 40 additions & 0 deletions configs/quantization/methods/RTN/rtn_w_a_wint4afp8.yml
@@ -0,0 +1,40 @@
+base:
+    seed: &seed 42
+model:
+    type: Llama
+    path: /mnt/nvme1/yongyang/models/llama2-7b
+    torch_dtype: auto
+eval:
+    eval_pos: [pretrain, fake_quant]
+    name: wikitext2
+    download: False
+    path: /mnt/nvme0/yongyang/llm_datasets/llmc/eval/wikitext2
+    seq_len: 2048
+    # For 7B / 13B model eval, bs can be set to "1", and inference_per_block can be set to "False".
+    # For 70B model eval, bs can be set to "20", and inference_per_block can be set to "True".
+    bs: 1
+    inference_per_block: False
+quant:
+    method: RTN
+    weight:
+        bit: 48
+        bit4:
+            symmetric: False
+            granularity: per_group
+            group_size: 128
+            scales_bit: 8
+            scales_symmetric: True
+            zeros_bit: 8
+            zeros_symmetric: True
+        bit8:
+            symmetric: True
+            granularity: per_channel
+            int_range: [-120, 120]
+    act:
+        quant_type: float-quant
+        bit: e4m3
+        symmetric: True
+        granularity: per_token
+save:
+    save_fake: False
+    save_path: /path/to/save/

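The new `wint4afp8` config keeps the mixed int4/int8 weight recipe of `rtn_w_a_wint4aint8.yml` (the next file in this diff) but switches activations to FP8 e4m3 with per-token scales. A rough sketch of what dynamic per-token e4m3 activation fake quantization could look like, again assuming the `torch.float8_e4m3fn` dtype and not reflecting llmc's actual kernels:

```python
import torch

def fake_quant_act_fp8_e4m3_per_token(x: torch.Tensor) -> torch.Tensor:
    # granularity: per_token -> one scale per token position, i.e. per row
    # of the flattened (batch * seq, hidden) activation.
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / fp8_max
    q = (x / scale).float().to(torch.float8_e4m3fn)    # symmetric, dynamic scaling
    return q.to(x.dtype) * scale

x = torch.randn(8, 2048, 4096, dtype=torch.float16)    # (batch, seq_len, hidden)
x_fq = fake_quant_act_fp8_e4m3_per_token(x)
```
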
2 changes: 0 additions & 2 deletions configs/quantization/methods/RTN/rtn_w_a_wint4aint8.yml
@@ -17,7 +17,6 @@ eval:
 quant:
     method: RTN
     weight:
-        quant_type: int-quant
         bit: 48
         bit4:
             symmetric: False
@@ -32,7 +31,6 @@ quant:
             granularity: per_channel
             int_range: [-120, 120]
     act:
-        quant_type: int-quant
         bit: 8
         symmetric: True
         granularity: per_token

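In both wint4 configs the int8 path uses `int_range: [-120, 120]` rather than the full [-128, 127]. Assuming `int_range` simply narrows the clipping range of the symmetric int8 quantizer (an interpretation, not confirmed by this diff), the effect can be sketched as:

```python
import torch

def fake_quant_int8_per_channel(w: torch.Tensor, int_range=(-120, 120)) -> torch.Tensor:
    qmin, qmax = int_range
    # Symmetric per-channel scale chosen so that max |w| maps to qmax (here 120).
    scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-12) / qmax
    q = torch.round(w / scale).clamp(qmin, qmax)
    return q * scale
```
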
2 changes: 1 addition & 1 deletion llmc/compression/quantization/quant.py
@@ -37,7 +37,7 @@ def __init__(self, bit, symmetric, granularity, **kwargs):
         # hist config
         self.bins = self.kwargs.get('bins', 2048)
         self.hist_threshold = self.kwargs.get('hist_threshold', 1)
-        self.dst_nbins = 2**bit
+        self.dst_nbins = 2**bit if isinstance(bit, int) else None
         self.upsample_rate = (
             16  # used to reduce quantization errors when upscaling histogram
         )

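The one-line change in `quant.py` guards the histogram-calibration field against non-integer bit specs: with the float-quant configs above, `bit` arrives as a format string such as 'e4m3', so `2**bit` raised a `TypeError`. An illustration of the guarded expression:

```python
for bit in (8, 'e4m3'):
    dst_nbins = 2**bit if isinstance(bit, int) else None
    print(bit, dst_nbins)   # 8 -> 256; 'e4m3' -> None (previously a TypeError)
```
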
