-
Notifications
You must be signed in to change notification settings - Fork 454
Expand file tree
/
Copy pathmain.py
More file actions
177 lines (142 loc) · 6.14 KB
/
Copy pathmain.py
File metadata and controls
177 lines (142 loc) · 6.14 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Main script for running the puzzletron algorithm on large language models (based on Puzzle paper https://arxiv.org/abs/2411.19146).
This script provides three modes:
1. Default mode: Runs the full puzzletron pipeline
2. MIP-only mode: Runs only the MIP search and realize models phase
3. MIP sweep mode: Runs MIP for multiple memory compression rates (enabled via config)
Usage:
# Full puzzletron pipeline
torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml
# Only MIP search and realize models phase
torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only
# MIP sweep mode (set mip.sweep.enabled: true in config)
torchrun main.py --config ./configs/llama_3.2_1B_pruneffn_memory.yaml --mip-only
"""
import argparse
from datetime import timedelta
from pathlib import Path
import modelopt.torch.nas as mtn
import modelopt.torch.puzzletron as mtpz
import modelopt.torch.utils.distributed as dist
def parse_args():
"""Parse command line arguments."""
parser = argparse.ArgumentParser(
description="Compress large language models using the Puzzletron algorithm (based on Puzzle paper https://arxiv.org/abs/2411.19146)"
)
parser.add_argument(
"--config",
type=str,
required=True,
help="Path to the main config YAML file (e.g., ./configs/llama_3.2_1B_pruneffn_memory.yaml)",
)
parser.add_argument(
"--mip-only",
action="store_true",
help="Run only the MIP search and realize models phase (skip pruning and NAS scoring)",
)
return parser.parse_args()
def run_full_puzzletron(hydra_config_path: str):
"""Run the full puzzletron pipeline.
Args:
config_path: Path to the YAML configuration file
"""
mtpz.tools.mprint("Puzzletron Progress 1/8: starting puzzletron pipeline")
# Register Hydra custom resolvers (needed for config resolution)
mtpz.tools.register_hydra_resolvers()
hydra_config_path = Path(hydra_config_path).resolve()
hydra_config_dir = str(hydra_config_path.parent)
hydra_config_name = hydra_config_path.stem
# Load hydra config
hydra_cfg = mtpz.tools.initialize_hydra_config_for_dir(
config_dir=hydra_config_dir,
config_name=hydra_config_name,
overrides=[],
)
# Default timeout: 10 minutes, or extended to nccl_timeout_minutes if set in config
if hasattr(hydra_cfg, "nccl_timeout_minutes"):
timeout_minutes = hydra_cfg.nccl_timeout_minutes
else:
timeout_minutes = timedelta(minutes=10)
dist.setup(timeout=timeout_minutes)
# Convert model (convert from HF to DeciLM, score pruning activations,
# prune the model and save pruned checkpoints)
input_model = mtpz.puzzletron_nas_plugin.PuzzletronModel()
converted_model = mtn.convert(
input_model,
mode=[
(
"puzzletron",
{
"puzzle_dir": str(hydra_cfg.puzzle_dir),
"input_model_path": hydra_cfg.input_hf_model_path,
"hydra_config_dir": hydra_config_dir,
"hydra_config_name": hydra_config_name,
"dataset_path": str(hydra_cfg.dataset_path),
},
)
],
)
# Run NAS search (build replacement library and compute stats,
# compute one block scores, run MIP and realize models)
mtn.search(
converted_model,
constraints={}, # this is not used as the search space is defined in the hydra config
dummy_input=None, # Not used
config={}, # this is not used as the search space is defined in the hydra config
)
dist.cleanup()
mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
def run_mip_only(hydra_config_path: str):
"""Run only the MIP search and realize models phase.
This assumes that pruning, replacement library building, NAS scoring, and subblock stats calculation
have already been completed.
Args:
hydra_config_path: Path to the YAML configuration file
"""
dist.setup(timeout=timedelta(minutes=10))
# Register Hydra custom resolvers (needed for config resolution)
mtpz.tools.register_hydra_resolvers()
hydra_config_path = Path(hydra_config_path).resolve()
hydra_config_dir = str(hydra_config_path.parent)
hydra_config_name = hydra_config_path.stem
# Load hydra config
hydra_cfg = mtpz.tools.initialize_hydra_config_for_dir(
config_dir=hydra_config_dir,
config_name=hydra_config_name,
overrides=[],
)
# Check if sweep mode is enabled
if hasattr(hydra_cfg.mip, "sweep") and hydra_cfg.mip.sweep.get("enabled", False):
mtpz.tools.mprint(
"Puzzletron Progress 7/8: running MIP sweep for multiple compression rates (multi-gpu)"
)
mtpz.mip.run_mip_sweep(hydra_cfg)
else:
# mip_and_realize_models (distributed processing)
# TODO: How to make it part of mnt.search() api, similarly to run_full_puzzletron() API
mtpz.tools.mprint("Puzzletron Progress 7/8: running MIP and realizing models (multi-gpu)")
mtpz.mip.launch_mip_and_realize_model(hydra_cfg)
dist.cleanup()
mtpz.tools.mprint("Puzzletron Progress 8/8: puzzletron pipeline completed (multi-gpu)")
def main():
args = parse_args()
if args.mip_only:
run_mip_only(hydra_config_path=args.config)
else:
run_full_puzzletron(hydra_config_path=args.config)
if __name__ == "__main__":
main()