
Commit 5eb4e7f

jpablomch and Yuan0320 committed: Mamba-Shedder release
Co-authored-by: Yuan0320 <[email protected]>
1 parent 766dd35


49 files changed: +5503 −6 lines

MS/README.md

+138
@@ -0,0 +1,138 @@
# Mamba-Shedder

Official implementation of [Mamba-Shedder: Post-Transformer Compression for Efficient Selective Structured State Space Models]().

This repo contains the code for Mamba-Shedder, which explores the compression of the new Mamba-series architectures (and their hybrids).
We study the sensitivity of these models to the removal of selected components at different granularities to reduce model size and computational overhead, thereby improving their efficiency while maintaining accuracy.
Please refer to our paper for more details.

## News
- **[2025.01.23]** Added support for the new hybrid architecture model **Hymba**; please refer to [Hymba-Pruning](./hybrid/Hymba-Pruning).
- **[2025.01.23]** Added support for Zamba2 ([Zamba2-Pruning](./hybrid/Zamba2-Pruning)).
- **[2025.01.22]** Released the code for **Mamba-Shedder**. :tada:
## Released Pruned Models 🤗

Compressed models by Mamba-Shedder:

| Source Model | Components Removed | Recovery Tuning | Relative Acc. (vs. source) | Pruned Model Link | Inference Speedup |
|--------------|--------------------|-----------------|----------------------------|-------------------|-------------------|
| [Hymba-1.5B-Base](https://huggingface.co/nvidia/Hymba-1.5B-Base) | 7 Hymba Blocks | | 97% | [Link]() | ~1.2x |
| [Hymba-1.5B-Base](https://huggingface.co/nvidia/Hymba-1.5B-Base) | 7 Hymba Blocks | | 99% | [Link]() | ~1.2x |
| [mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) | 14 Mamba Blocks | | 90% | [Link]() | ~1.3x |
| [mamba2-2.7b](https://huggingface.co/state-spaces/mamba2-2.7b) | 22 SSMs | | 96% | [Link]() | ~1.2x |
| [mamba2-2.7b](https://huggingface.co/state-spaces/mamba2-2.7b) | 22 SSMs | | 99% | [Link]() | ~1.2x |
## Setup

Use the following instructions to create a virtual environment with the required dependencies.

```bash
# install dependencies
bash install.sh
```
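
If you prefer a manual setup, a minimal alternative (inferred from the imports used by the scripts in this repo, not an official requirements list) is:

```bash
# Packages imported by eval.py and extract/extract_mamba.py:
# torch, transformers, mamba_ssm, lm_eval
pip install torch transformers mamba-ssm lm-eval
```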

## Run

### Evaluation before Pruning

```bash
python eval.py --model_path <path to mamba model>
```

### Prune

#### Mamba Block Pruning

An example command for [mamba-2.8b](https://huggingface.co/state-spaces/mamba-2.8b) with Mamba Block Pruning:

```bash
python prune.py \
  --model_path state-spaces/mamba-2.8b \
  --do_prune \
  --output_path <path to pruning results> \
  --prune_target mamba_block \
  --target_pruning_steps 10 \
  --importance_metric ppl \
  --calibration_dataset alpaca \
  --num_calibration_samples 256 \
  --do_eval
```

- `model_path`: Path to the pre-trained Mamba model.
- `do_prune`: Flag to indicate whether to perform pruning.
- `output_path`: Directory to save the pruning and evaluation results.
- `prune_target`: Pruning granularity, either `mamba_block` or `ssm`.
- `target_pruning_steps`: Number of target modules (Mamba blocks or SSMs) to prune.
- `importance_metric`: Metric for calculating block importance; currently only PPL (perplexity) is supported (see the sketch after this list).
- `calibration_dataset`: Calibration dataset name (`alpaca`, `c4`, `ptb`, or `wikitext2`).
- `num_calibration_samples`: Number of calibration samples used for pruning.
- `do_eval`: Flag to indicate whether to perform evaluation.
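
The iterative search that `prune.py` performs can be pictured with the following minimal sketch of PPL-based block importance. The helper names and the model/block interfaces here are simplified assumptions; `prune.py` remains the reference implementation.

```python
import torch


@torch.no_grad()
def perplexity(model, input_ids):
    """Perplexity of the model on a batch of calibration token ids."""
    logits = model(input_ids).logits
    loss = torch.nn.functional.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        input_ids[:, 1:].reshape(-1),
    )
    return loss.exp().item()


def least_important_block(model, blocks, calib_ids):
    """Return the index of the block whose removal raises PPL the least.

    `blocks` is the list of residual blocks (e.g. model.backbone.layers);
    each block is temporarily bypassed by an identity on the hidden states.
    """
    scores = {}
    for idx, block in enumerate(blocks):
        original_forward = block.forward
        # Skip the block: pass hidden states through unchanged.
        block.forward = lambda hidden_states, *args, **kwargs: hidden_states
        scores[idx] = perplexity(model, calib_ids)
        block.forward = original_forward
    return min(scores, key=scores.get)
```

Removing one module per step and repeating this scan for `target_pruning_steps` iterations yields the pruning configuration that the extraction step below consumes.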
#### SSM Pruning

An example command for [mamba2-2.7b](https://huggingface.co/state-spaces/mamba2-2.7b) with SSM Pruning:

```bash
python prune.py \
  --model_path state-spaces/mamba2-2.7b \
  --do_prune \
  --output_path <path to pruning results> \
  --prune_target ssm \
  --target_pruning_steps 20 \
  --importance_metric ppl \
  --calibration_dataset alpaca \
  --num_calibration_samples 256 \
  --do_eval
```
### Extract the Pruned Model

Extract the pruned model based on the optimal pruning configuration obtained from Mamba-Shedder.
For more details, please refer to [here](./extract).
Here is an example to extract a pruned [mamba2-2.7b](https://huggingface.co/state-spaces/mamba2-2.7b):

```bash
python extract/extract_mamba.py \
  --model_path state-spaces/mamba2-2.7b \
  --pruned_model_config_file <path to pruning results>/pruning_config.json \
  --output_path <path to compressed model>
```
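
For reference, `extract/extract_mamba.py` reads two optional keys from this configuration file, `pruned_mamba_block_idx` and `pruned_ssm_idx`, so a `pruning_config.json` has the following shape (the indices below are illustrative, not real results):

```json
{
    "pruned_mamba_block_idx": [12, 17, 23],
    "pruned_ssm_idx": [4, 9, 30]
}
```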
### Recovery Fine-tuning

After obtaining the pruned model, we can use the [Alpaca](https://huggingface.co/datasets/yahma/alpaca-cleaned) dataset for recovery fine-tuning:

```bash
# Finetune the compressed Mamba-2
python recovery/finetune_mamba.py \
  --model_path <path to compressed model> \
  --do_train \
  --batch_size 32 \
  --gradient_accumulation_steps 1 \
  --num_train_epochs 1 \
  --learning_rate 5e-5 \
  --output_path <path to trained model> \
  --do_eval
```
## Results

All run commands and pruning results can be found [here](./results).

### Loading the compressed model for evaluation

```bash
python eval.py --model_path <path to compressed model>
```
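
To load the compressed checkpoint in your own code rather than through `eval.py`, the same loading path that `eval.py` uses applies (a sketch; the tokenizer choice mirrors the repo's scripts):

```python
import torch
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer

# Mamba checkpoints reuse the GPT-NeoX tokenizer, as in eval.py.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
model = MambaLMHeadModel.from_pretrained(
    "<path to compressed model>", device="cuda", dtype=torch.float16
)
```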

## Citation

If you find Mamba-Shedder's code and paper helpful, please kindly cite:

```bibtex
@article{munoz2025mambashedder,
  title   = {Mamba-Shedder: Post-Transformer Compression for Efficient Selective Structured State Space Models},
  author  = {J. Pablo Munoz and Jinjie Yuan and Nilesh Jain},
  journal = {},
  year    = {2025}
}
```

MS/eval.py

+45
@@ -0,0 +1,45 @@
import argparse
import json
import logging

import torch
from transformers import AutoTokenizer
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

from lm_eval import evaluator
from lm_eval.models.mamba_lm import MambaLMWrapper

TASKS = ["lambada_openai", "hellaswag", "piqa", "arc_easy", "arc_challenge", "winogrande", "openbookqa"]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to the (possibly compressed) Mamba model.",
    )
    args = parser.parse_args()
    model_path = args.model_path

    # Ensure the INFO-level logs below are actually emitted.
    logging.basicConfig(level=logging.INFO)

    # Mamba checkpoints reuse the GPT-NeoX tokenizer.
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    model = MambaLMHeadModel.from_pretrained(model_path, device="cuda", dtype=torch.float16)
    # MambaLMWrapper expects the model to expose a `device` attribute.
    model.device = model.lm_head.weight.device
    lm = MambaLMWrapper(pretrained=model, tokenizer=tokenizer, batch_size=64)

    # Evaluate on selected tasks
    logging.info(f"Selected Tasks: {TASKS}")
    results = evaluator.simple_evaluate(lm, tasks=TASKS, log_samples=False)['results']

    metric_vals = {}
    for task, result in results.items():
        # TODO: fix (all are `acc_norm,none`)
        res = result['acc,none'] if task == 'arc_easy' else result.get('acc_norm,none', result['acc,none'])
        # Scale before rounding so the reported percentage has no float noise.
        metric_vals[task] = round(res * 100, 1)
        if task == "lambada_openai":
            metric_vals[task + "_ppl"] = result['perplexity,none']

    logging.info(json.dumps(metric_vals, indent=4))


if __name__ == "__main__":
    main()

MS/extract/README.md

+21
@@ -0,0 +1,21 @@
## Extract the Compressed Model from Mamba-Shedder

The final compressed model can be extracted based on the optimal pruning configuration obtained from Mamba-Shedder.

```bash
# Mamba-1 (Mamba Block Pruning)
python extract/extract_mamba.py \
  --model_path state-spaces/mamba-2.8b \
  --output_path <path to pruned model> \
  --pruned_model_config_file <path to pruning result>/pruning_config.json # Or specify the config file of a pruning step from the `pruned_model_configs` folder, e.g., <path to pruning result>/pruned_model_configs/config.mamba_block.${eval_step}.json

# Mamba-2 (SSM Pruning)
python extract/extract_mamba.py \
  --model_path state-spaces/mamba2-2.7b \
  --output_path <path to pruned model> \
  --pruned_model_config_file <path to pruning result>/pruning_config.json # Or specify the config file of a pruning step from the `pruned_model_configs` folder, e.g., <path to pruning result>/pruned_model_configs/config.ssm.${eval_step}.json
```

- `model_path`: Path to the pre-trained model.
- `pruned_model_config_file`: JSON file for the pruned model configuration.
- `output_path`: Directory to save the compressed model.

MS/extract/extract_mamba.py

+88
@@ -0,0 +1,88 @@
import argparse
import json
import logging
import os
import torch

from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from transformers import AutoTokenizer


# Parameters that make up one Mamba block; `*` is replaced by the layer index.
MAMBA_MODULES = [
    "backbone.layers.*.mixer.dt_bias",
    "backbone.layers.*.mixer.A_log",
    "backbone.layers.*.mixer.D",
    "backbone.layers.*.mixer.in_proj.weight",
    "backbone.layers.*.mixer.conv1d.weight",
    "backbone.layers.*.mixer.conv1d.bias",
    "backbone.layers.*.mixer.norm.weight",
    "backbone.layers.*.mixer.out_proj.weight",
    "backbone.layers.*.mixer.dt_proj.weight",  # Mamba-1
    "backbone.layers.*.mixer.dt_proj.bias",  # Mamba-1
    "backbone.layers.*.mixer.x_proj.weight",  # Mamba-1
    "backbone.layers.*.norm.weight",
]

# SSM-specific parameters; only for Mamba-2
SSM_MODULES = [
    "backbone.layers.*.mixer.D",
    "backbone.layers.*.mixer.dt_bias",
]


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        help="Path to the Mamba model."
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Directory to save the compressed model."
    )
    parser.add_argument(
        "--pruned_model_config_file",
        type=str,
        help="Path to the pruned model configuration file."
    )

    args = parser.parse_args()
    model_path = args.model_path
    output_path = args.output_path
    # Create output directory if it doesn't exist
    os.makedirs(output_path, exist_ok=True)
    pruned_model_config_file = args.pruned_model_config_file

    # Ensure the INFO-level log below is actually emitted.
    logging.basicConfig(level=logging.INFO)

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    model = MambaLMHeadModel.from_pretrained(model_path, device="cuda", dtype=torch.float16)

    # Load pruning results
    with open(pruned_model_config_file, "r") as f:
        pruned_config = json.load(f)
    logging.info(f"Detected a pruned model config: {pruned_config}")
    state_dict = model.state_dict()

    def prune_modules(state_dict, idx, module_names):
        # Drop every parameter of the pruned module from the state dict.
        for module_name in module_names:
            module_name = module_name.replace("*", str(idx))
            if module_name in state_dict:
                del state_dict[module_name]

    if pruned_config.get("pruned_mamba_block_idx"):
        pruned_mamba_block_idx = pruned_config["pruned_mamba_block_idx"]
        for idx in pruned_mamba_block_idx:
            prune_modules(state_dict, idx, MAMBA_MODULES)
    if pruned_config.get("pruned_ssm_idx"):
        pruned_ssm_idx = pruned_config["pruned_ssm_idx"]
        for idx in pruned_ssm_idx:
            prune_modules(state_dict, idx, SSM_MODULES)

    # Save the compressed checkpoint and the tokenizer.
    model.save_pretrained(output_path, state_dict=state_dict)
    tokenizer.save_pretrained(output_path)


if __name__ == "__main__":
    main()
