Skip to content

Commit a3db5c8

Browse files
authored
support vit full llm lora (#3575)
1 parent c3fb9db commit a3db5c8

File tree

12 files changed

+157
-27
lines changed

12 files changed

+157
-27
lines changed

docs/source/Customization/插件化.md

+5-4
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ tuner定制也是swift中有特色的能力之一,开发者可以无视复杂
145145
class IA3(Tuner):
146146

147147
@staticmethod
148-
def prepare_model(args: 'TrainArguments', model: torch.nn.Module):
148+
def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
149149
model_arch: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
150150
ia3_config = IA3Config(
151151
target_modules=find_all_linears(model), feedforward_modules='.*' + model_arch.mlp.split('{}.')[1] + '.*')
@@ -155,14 +155,15 @@ class IA3(Tuner):
155155
def save_pretrained(
156156
model: torch.nn.Module,
157157
save_directory: str,
158+
state_dict: Optional[dict] = None,
158159
safe_serialization: bool = True,
159160
**kwargs,
160-
):
161+
) -> None:
161162
model: PeftModel
162-
model.save_pretrained(save_directory, safe_serialization=safe_serialization, **kwargs)
163+
model.save_pretrained(save_directory, state_dict=state_dict, safe_serialization=safe_serialization, **kwargs)
163164

164165
@staticmethod
165-
def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs):
166+
def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
166167
return PeftModel.from_pretrained(model, model_id, **kwargs)
167168
```
168169

docs/source/Instruction/预训练与微调.md

+1
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ ms-swift使用了分层式的设计思想,用户可以使用命令行界面、
5555
- 多机多卡训练: 我们书写了使用swift、torchrun、dlc、deepspeed、accelerate启动多节点运行的shell脚本示例。除了dlc和deepspeed,其他启动脚本都需要在所有节点中启动才可运行。具体参考[这里](https://github.com/modelscope/swift/blob/main/examples/train/multi-node)
5656
- 量化训练:支持使用GPTQ、AWQ、AQLM、BNB、HQQ、EETQ量化技术的QLoRA训练。微调7B模型只需要9GB显存资源。具体参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora)
5757
- 多模态训练:SWIFT支持多模态模型的预训练、微调和RLHF。支持Caption、VQA、OCR、[Grounding](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-vl-grounding/zh.ipynb)任务。支持图像、视频和音频三种模态。具体参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal)。多模态自定义数据集格式参考[自定义数据集文档](../Customization/自定义数据集.md)
58+
- 对ViT/Aligner使用全参数训练,LLM使用LoRA训练,并采用不同学习率的例子参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/custom_tuner)
5859
- RLHF训练:参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf)。多模态模型参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/rlhf)。GRPO训练参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/grpo_zero2.sh)。强化微调查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft)
5960
- Megatron训练:支持使用Megatron的并行技术来加速大模型的训练,包括数据并行、张量并行、流水线并行、序列并行,上下文并行。参考[Megatron-SWIFT训练文档](./Megatron-SWIFT训练.md)
6061
- 序列分类模型训练:参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/seq_cls)

docs/source_en/Customization/Pluginization.md

+5-4
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ Tuner customization is another unique feature of SWIFT. Developers can bypass th
163163
class IA3(Tuner):
164164

165165
@staticmethod
166-
def prepare_model(args: 'TrainArguments', model: torch.nn.Module):
166+
def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
167167
model_arch: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
168168
ia3_config = IA3Config(
169169
target_modules=find_all_linears(model), feedforward_modules='.*' + model_arch.mlp.split('{}.')[1] + '.*')
@@ -173,14 +173,15 @@ class IA3(Tuner):
173173
def save_pretrained(
174174
model: torch.nn.Module,
175175
save_directory: str,
176+
state_dict: Optional[dict] = None,
176177
safe_serialization: bool = True,
177178
**kwargs,
178-
):
179+
) -> None:
179180
model: PeftModel
180-
model.save_pretrained(save_directory, safe_serialization=safe_serialization, **kwargs)
181+
model.save_pretrained(save_directory, state_dict=state_dict, safe_serialization=safe_serialization, **kwargs)
181182

182183
@staticmethod
183-
def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs):
184+
def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
184185
return PeftModel.from_pretrained(model, model_id, **kwargs)
185186
```
186187

docs/source_en/Instruction/Pre-training-and-Fine-tuning.md

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ Additionally, we offer a series of scripts to help you understand the training c
5858
- Multi-node Multi-GPU Training: We have provided example shell scripts for launching multi-node runs using swift, torchrun, dlc, deepspeed, and accelerate. Except for dlc and deepspeed, the other launch scripts need to be started on all nodes to run properly. Please refer to [here](https://github.com/modelscope/swift/blob/main/examples/train/multi-node) for details.
5959
- Quantization Training: Supports QLoRA training using quantization techniques such as GPTQ, AWQ, AQLM, BNB, HQQ, and EETQ. Fine-tuning a 7B model only requires 9GB of memory. For more details, refer to [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora).
6060
- Multi-modal Training: SWIFT supports pre-training, fine-tuning, and RLHF for multi-modal models. It supports tasks such as Captioning, VQA, OCR, and [Grounding](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-vl-grounding/zh.ipynb). It supports three modalities: images, videos, and audio. For more details, refer to [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal). The format for custom multi-modal datasets can be found in the [Custom Dataset Documentation](../Customization/Custom-dataset.md).
61+
- For an example that trains the ViT/Aligner with full parameters and the LLM with LoRA, using a different learning rate for each, refer to [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/custom_tuner).
6162
- RLHF Training: Refer to [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf). For multi-modal models, refer to [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/rlhf). For GRPO training, refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/grpo_zero2.sh). For reinforcement fine-tuning, see [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft).
6263
- Megatron Training: Supports the use of Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, and context parallelism. Refer to the [Megatron-SWIFT Training Documentation](./Megatron-SWIFT-Training.md).
6364
- Sequence Classification Model Training: Refer to [here](https://github.com/modelscope/ms-swift/tree/main/examples/train/seq_cls).
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
from typing import Optional
3+
4+
import safetensors.torch
5+
import torch
6+
from transformers import Trainer
7+
8+
from swift.plugin import Tuner, extra_tuners, optimizers_map
9+
from swift.tuners import LoraConfig, Swift
10+
11+
12+
class CustomTuner(Tuner):
    """Tuner that applies LoRA to the LLM layers and fully trains the `visual` submodule.

    The adapter itself is saved/loaded through `Swift`. In addition, every saved
    tensor whose name contains ``'.visual.'`` is written to ``vit.safetensors``
    inside the checkpoint directory and loaded back from there.
    """

    @staticmethod
    def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
        """Restore the adapter from `model_id`, then overlay the tensors in `vit.safetensors`.

        Args:
            model: Model to wrap with the saved adapter.
            model_id: Checkpoint directory; it must contain `vit.safetensors`.
            **kwargs: Forwarded unchanged to `Swift.from_pretrained`.

        Returns:
            The model returned by `Swift.from_pretrained` with the `vit.safetensors`
            tensors loaded into it.
        """
        import warnings

        model = Swift.from_pretrained(model, model_id, **kwargs)
        vit_state_dict = safetensors.torch.load_file(os.path.join(model_id, 'vit.safetensors'))
        # strict=False: the file only holds a subset of the model's tensors, so
        # missing keys are expected. Unexpected keys are a different matter: they
        # are saved tensors that were NOT applied to the model.
        load_result = model.load_state_dict(vit_state_dict, strict=False)
        # getattr guards against a wrapper whose `load_state_dict` override returns
        # None instead of torch's (missing_keys, unexpected_keys) result.
        unexpected_keys = getattr(load_result, 'unexpected_keys', None)
        if unexpected_keys:
            warnings.warn(f'{len(unexpected_keys)} tensor(s) in vit.safetensors did not match any model '
                          f'parameter and were not loaded: {list(unexpected_keys)}')
        return model

    @staticmethod
    def save_pretrained(
        model: torch.nn.Module,
        save_directory: str,
        state_dict: Optional[dict] = None,
        safe_serialization: bool = True,
        **kwargs,
    ) -> None:
        """Save the adapter and write the `.visual.` tensors to `vit.safetensors`.

        Args:
            model: The prepared model to save.
            save_directory: Target directory for the checkpoint.
            state_dict: Tensors to save. When None, it is built from every
                parameter of `model` that has `requires_grad=True`.
            safe_serialization: Forwarded to `model.save_pretrained`.
            **kwargs: Forwarded to `model.save_pretrained`.
        """
        if state_dict is None:
            # Only trainable parameters are collected; tensors are detached and moved to CPU.
            state_dict = {
                name: param.detach().cpu()
                for name, param in model.named_parameters() if param.requires_grad
            }
        model.save_pretrained(save_directory, state_dict=state_dict, safe_serialization=safe_serialization, **kwargs)
        # vit: kept in a separate file, which `from_pretrained` reads back.
        vit_state_dict = {k: v for k, v in state_dict.items() if '.visual.' in k}
        safetensors.torch.save_file(
            vit_state_dict, os.path.join(save_directory, 'vit.safetensors'), metadata={'format': 'pt'})

    @staticmethod
    def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
        """Wrap `model` with LoRA on `model.layers` and make `model.visual` trainable.

        Args:
            args: Training arguments; `lora_rank` and `lora_alpha` are read.
            model: The model to prepare.

        Returns:
            The model returned by `Swift.prepare_model`, with `requires_grad`
            enabled on its `visual` submodule.
        """
        target_regex = r'^model.layers.*'
        lora_config = LoraConfig(
            task_type='CAUSAL_LM', r=args.lora_rank, lora_alpha=args.lora_alpha, target_modules=target_regex)
        model = Swift.prepare_model(model, lora_config)
        model.visual.requires_grad_(True)  # vit & merger
        return model
48+
49+
50+
def create_custom_optimizer(args, model, dataset):
    """Create an optimizer that trains `.visual.` parameters at 0.1x the base learning rate.

    Trainable parameters are bucketed by two criteria: whether the parameter
    name contains ``'.visual.'`` and whether the name is reported by
    `Trainer.get_decay_parameter_names`. The non-visual groups carry no explicit
    ``'lr'`` entry.

    Args:
        args: Training arguments; `learning_rate` and `weight_decay` are read, and
            the object is passed to `Trainer.get_optimizer_cls_and_kwargs`.
        model: Model whose parameters with `requires_grad=True` are optimized.
        dataset: Not used by this function.

    Returns:
        A tuple `(optimizer, None)`.
    """
    decay_names = set(Trainer.get_decay_parameter_names(None, model))
    vit_lr = 0.1 * args.learning_rate  # e.g. 1e-5 for a base learning rate of 1e-4

    vit_decay, vit_no_decay = [], []
    llm_decay, llm_no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        uses_decay = name in decay_names
        if '.visual.' in name:
            bucket = vit_decay if uses_decay else vit_no_decay
        else:
            bucket = llm_decay if uses_decay else llm_no_decay
        bucket.append(param)

    # vit & merger
    vit_groups = [
        {
            'params': vit_decay,
            'weight_decay': args.weight_decay,
            'lr': vit_lr,
        },
        {
            'params': vit_no_decay,
            'weight_decay': 0.0,
            'lr': vit_lr,
        },
    ]
    # llm
    llm_groups = [
        {
            'params': llm_decay,
            'weight_decay': args.weight_decay,
        },
        {
            'params': llm_no_decay,
            'weight_decay': 0.0,
        },
    ]

    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(args, model)
    optimizer = optimizer_cls(vit_groups + llm_groups, **optimizer_kwargs)
    return optimizer, None
78+
79+
80+
extra_tuners['custom'] = CustomTuner
81+
optimizers_map['custom'] = create_custom_optimizer
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# If the weights have been merged, please use `--model`.
2+
CUDA_VISIBLE_DEVICES=0 \
3+
swift infer \
4+
--adapters output/vx-xxx/checkpoint-xxx \
5+
--stream true \
6+
--load_data_args true \
7+
--temperature 0 \
8+
--max_new_tokens 2048
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# 4 * 22GiB
2+
# vit/merger lr 1e-5; llm lora lr 1e-4
3+
NPROC_PER_NODE=4 \
4+
CUDA_VISIBLE_DEVICES=0,1,2,3 \
5+
MAX_PIXELS=1003520 \
6+
swift sft \
7+
--model Qwen/Qwen2.5-VL-7B-Instruct \
8+
--dataset 'AI-ModelScope/coco#20000' \
9+
--train_type custom \
10+
--optimizer custom \
11+
--external_plugins 'examples/train/multimodal/custom_tuner/custom_plugin.py' \
12+
--torch_dtype bfloat16 \
13+
--num_train_epochs 1 \
14+
--per_device_train_batch_size 1 \
15+
--per_device_eval_batch_size 1 \
16+
--learning_rate 1e-4 \
17+
--lora_rank 16 \
18+
--lora_alpha 32 \
19+
--gradient_accumulation_steps 4 \
20+
--eval_steps 100 \
21+
--save_steps 100 \
22+
--save_total_limit 5 \
23+
--logging_steps 5 \
24+
--max_length 8192 \
25+
--output_dir output \
26+
--warmup_ratio 0.05 \
27+
--dataloader_num_workers 4 \
28+
--dataset_num_proc 4 \
29+
--deepspeed zero2 \
30+
--save_only_model true
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
swift export \
2+
--adapters output/vx-xxx/checkpoint-xxx \
3+
--merge_lora true

swift/llm/argument/base_args/base_args.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,6 @@ def _init_adapters(self):
131131
self.adapters = [
132132
safe_snapshot_download(adapter, use_hf=self.use_hf, hub_token=self.hub_token) for adapter in self.adapters
133133
]
134-
for adapter in self.adapters:
135-
assert self._check_is_adapter(adapter), (
136-
f'`{adapter}` is not an adapter, please try using `--model` to pass it.')
137134

138135
def __post_init__(self):
139136
if self.use_hf or use_hf_hub():
@@ -149,6 +146,10 @@ def __post_init__(self):
149146
self.rank, self.local_rank, self.global_world_size, self.local_world_size = get_dist_setting()
150147
logger.info(f'rank: {self.rank}, local_rank: {self.local_rank}, '
151148
f'world_size: {self.global_world_size}, local_world_size: {self.local_world_size}')
149+
if self.train_type not in extra_tuners:
150+
for adapter in self.adapters:
151+
assert self._check_is_adapter(adapter), (
152+
f'`{adapter}` is not an adapter, please try using `--model` to pass it.')
152153
ModelArguments.__post_init__(self)
153154
QuantizeArguments.__post_init__(self)
154155
TemplateArguments.__post_init__(self)

swift/plugin/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from .metric import InferStats, MeanMetric, Metric, compute_acc, get_metric, compute_rouge_bleu
1111
from .optimizer import optimizers_map
1212
from .tools import get_tools_prompt, get_tools_keyword
13-
from .tuner import Tuner, extra_tuners
13+
from .tuner import Tuner, extra_tuners, PeftTuner
1414
from .prm import prms, PRM
1515
from .orm import orms, ORM
1616

@@ -22,7 +22,7 @@
2222
'metric': ['InferStats', 'MeanMetric', 'Metric', 'compute_acc', 'get_metric', 'compute_rouge_bleu'],
2323
'optimizer': ['optimizers_map'],
2424
'tools': ['get_tools_prompt', 'get_tools_keyword'],
25-
'tuner': ['Tuner', 'extra_tuners'],
25+
'tuner': ['Tuner', 'extra_tuners', 'PeftTuner'],
2626
'prm': ['prms', 'PRM'],
2727
'orm': ['orms', 'ORM']
2828
}

swift/plugin/optimizer.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def calculate_max_steps(args: 'TrainArguments', dataset) -> int:
2222
return max_steps
2323

2424

25-
def create_galore_optimizers(args, model, dataset):
25+
def create_galore_optimizer(args, model, dataset):
2626
training_steps = calculate_max_steps(args, dataset)
2727
optimizer, lr_scheduler = create_optimizer_and_scheduler(
2828
model, args, args.galore_config, training_steps, lr=args.learning_rate, weight_decay=args.weight_decay)
@@ -31,7 +31,7 @@ def create_galore_optimizers(args, model, dataset):
3131
return optimizer, lr_scheduler
3232

3333

34-
def create_lorap_optimizers(args, model, dataset):
34+
def create_lorap_optimizer(args, model, dataset):
3535
optimizer_grouped_parameters = None
3636
if hasattr(model, 'create_optimizer_param_groups'):
3737
# Lora+ parameter groups
@@ -55,7 +55,7 @@ def create_lorap_optimizers(args, model, dataset):
5555
return optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs), None
5656

5757

58-
def create_muon_optimizers(args, model, dataset):
58+
def create_muon_optimizer(args, model, dataset):
5959
from swift.llm import git_clone_github, get_model_arch
6060
if not args.local_repo_path:
6161
args.local_repo_path = git_clone_github('https://github.com/MoonshotAI/Moonlight.git')
@@ -94,7 +94,7 @@ def create_muon_optimizers(args, model, dataset):
9494

9595
# Add your own optimizers here, then pass `--optimizer xxx` to train with them.
9696
optimizers_map = {
97-
'galore': create_galore_optimizers,
98-
'lorap': create_lorap_optimizers,
99-
'muon': create_muon_optimizers,
97+
'galore': create_galore_optimizer,
98+
'lorap': create_lorap_optimizer,
99+
'muon': create_muon_optimizer,
100100
}

swift/plugin/tuner.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
from typing import Optional
3+
24
import torch
35
from peft import IA3Config, PeftModel, get_peft_model
46

@@ -9,7 +11,7 @@
911
class Tuner:
1012

1113
@staticmethod
12-
def prepare_model(args: 'TrainArguments', model: torch.nn.Module):
14+
def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
1315
"""Prepare a new model with a tuner
1416
1517
Args:
@@ -25,9 +27,10 @@ def prepare_model(args: 'TrainArguments', model: torch.nn.Module):
2527
def save_pretrained(
2628
model: torch.nn.Module,
2729
save_directory: str,
30+
state_dict: Optional[dict] = None,
2831
safe_serialization: bool = True,
2932
**kwargs,
30-
):
33+
) -> None:
3134
"""Save when save_steps reaches
3235
3336
Args:
@@ -38,7 +41,7 @@ def save_pretrained(
3841
raise NotImplementedError
3942

4043
@staticmethod
41-
def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs):
44+
def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
4245
"""Load the ckpt_dir
4346
4447
Args:
@@ -56,22 +59,22 @@ class PeftTuner(Tuner):
5659
def save_pretrained(
5760
model: torch.nn.Module,
5861
save_directory: str,
62+
state_dict: Optional[dict] = None,
5963
safe_serialization: bool = True,
6064
**kwargs,
61-
):
62-
model: PeftModel
65+
) -> None:
6366
model.save_pretrained(save_directory, safe_serialization=safe_serialization, **kwargs)
6467

6568
@staticmethod
66-
def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs):
69+
def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
6770
return PeftModel.from_pretrained(model, model_id, **kwargs)
6871

6972

7073
# A simple example of an IA3 tuner
7174
class IA3(PeftTuner):
7275

7376
@staticmethod
74-
def prepare_model(args: 'TrainArguments', model: torch.nn.Module):
77+
def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
7578
model_arch: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
7679
ia3_config = IA3Config(
7780
target_modules=find_all_linears(model), feedforward_modules='.*' + model_arch.mlp.split('{}.')[1] + '.*')
@@ -81,7 +84,7 @@ def prepare_model(args: 'TrainArguments', model: torch.nn.Module):
8184
class DummyTuner(PeftTuner):
8285

8386
@staticmethod
84-
def prepare_model(args: 'TrainArguments', model: torch.nn.Module):
87+
def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
8588
return model
8689

8790

0 commit comments

Comments
 (0)