
Commit 36a31ed

PeftModel supports saving and loading safetensors (#2025)
1 parent bd6ba69 commit 36a31ed

File tree

3 files changed, +27 -13 lines changed

mindnlp/peft/peft_model.py (+14 -6)

@@ -23,6 +23,7 @@
 import mindspore
 from mindspore import Tensor
 from mindspore.train.serialization import _exec_save
+from mindnlp.core.serialization import safe_save_file

 from mindnlp.core import nn, ops
 from mindnlp.core.nn import functional as F
@@ -45,7 +46,7 @@
     LNTuningModel,
 )
 from .utils import (
-    # SAFETENSORS_WEIGHTS_NAME,
+    SAFETENSORS_WEIGHTS_NAME,
     TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
     WEIGHTS_NAME,
     PeftType,
@@ -124,7 +125,7 @@ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
         # if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
         #     self.base_model.config.pretraining_tp = 1

-    def save_pretrained(self, save_directory, **kwargs):
+    def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):
         r"""
         This function saves the adapter model and the adapter configuration files to a directory, so that it can be
         reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`]
@@ -144,10 +145,17 @@ def save_pretrained(self, save_directory, **kwargs):
             output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory
             os.makedirs(output_dir, exist_ok=True)

-            _exec_save(
-                ckpt_file_name=os.path.join(output_dir, WEIGHTS_NAME),
-                data_list=output_state_dict,
-            )
+            if safe_serialization:
+                safe_output_state_dict = {k : Tensor(v[2]).reshape(v[0]) for k, v in output_state_dict.items()}
+                safe_save_file(
+                    safe_output_state_dict,
+                    os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),
+                )
+            else:
+                _exec_save(
+                    ckpt_file_name=os.path.join(output_dir, WEIGHTS_NAME),
+                    data_list=output_state_dict,
+                )

             # save the config and change the inference mode to `True`
             if peft_config.base_model_name_or_path is None:
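The new safe_serialization flag selects between the existing MindSpore checkpoint path (_exec_save writing WEIGHTS_NAME) and the new safetensors path (safe_save_file writing SAFETENSORS_WEIGHTS_NAME). A minimal usage sketch follows; the base model name, adapter directories, and LoRA settings are placeholders, and it assumes mindnlp's get_peft_model and LoraConfig behave as in upstream PEFT:

    from mindnlp.transformers import AutoModelForCausalLM
    from mindnlp.peft import LoraConfig, get_peft_model

    # Placeholder base model and LoRA settings, for illustration only.
    base_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
    config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
    peft_model = get_peft_model(base_model, config)

    # Default: MindSpore checkpoint written via _exec_save under WEIGHTS_NAME.
    peft_model.save_pretrained("./qwen2_lora_adapter")

    # New: safetensors file written via safe_save_file under SAFETENSORS_WEIGHTS_NAME.
    peft_model.save_pretrained("./qwen2_lora_adapter_st", safe_serialization=True)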

mindnlp/peft/utils/save_and_load.py (+12 -6)

@@ -18,8 +18,10 @@
 
 import mindspore
 
+from mindnlp.core.serialization import safe_load_file
+
 from .peft_types import PeftType
-from .constants import WEIGHTS_NAME
+from .constants import WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME

 def get_data_list(param_dict):
     """Get state dict of the Peft model for saving."""
@@ -198,11 +200,15 @@ def load_peft_weights(model_id: str,) -> dict:
     """
     path = model_id

-    filename = os.path.join(path, WEIGHTS_NAME)
-    if not os.path.exists(filename):
-        # TODO: add download logic later
-        raise ValueError(f"load peft model failed, peft model file: {filename} not exists.")
+    safe_filename = os.path.join(path, SAFETENSORS_WEIGHTS_NAME)
+    ckpt_filename = os.path.join(path, WEIGHTS_NAME)

-    adapters_weights = mindspore.load_checkpoint(filename)
+    if os.path.exists(safe_filename):
+        adapters_weights = safe_load_file(safe_filename)
+    elif os.path.exists(ckpt_filename):
+        adapters_weights = mindspore.load_checkpoint(ckpt_filename)
+    else:
+        # TODO: add download logic later
+        raise ValueError(f"load peft model failed, peft model file: neither {ckpt_filename} nor {safe_filename} was found.")

     return adapters_weights
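On the loading side, load_peft_weights now checks for the safetensors file first and only falls back to the MindSpore checkpoint; if neither exists it raises ValueError. A small round-trip sketch, assuming the adapter directory written in the previous example and that the function is imported directly from its defining module (it may also be re-exported elsewhere in mindnlp.peft):

    from mindnlp.peft.utils.save_and_load import load_peft_weights

    # Prefers SAFETENSORS_WEIGHTS_NAME if present, otherwise falls back to WEIGHTS_NAME.
    adapter_weights = load_peft_weights("./qwen2_lora_adapter_st")
    for name, tensor in adapter_weights.items():
        print(name, tuple(tensor.shape))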

mindnlp/transformers/models/qwen2/modeling_qwen2.py (+1 -1)

@@ -801,7 +801,7 @@ def prepare_inputs_for_generation(
 
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
-            position_ids = attention_mask.int().cumsum(-1) - 1
+            position_ids = ops.cumsum(attention_mask.int(), -1) - 1
             position_ids = position_ids.masked_fill(attention_mask == 0, 1)
         if past_key_values:
             position_ids = position_ids[:, -input_ids.shape[1] :]
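The qwen2 change only swaps the Tensor.cumsum method for the functional ops.cumsum; the computed position ids are unchanged. A tiny sketch of what the expression produces for a left-padded batch, assuming mindnlp.core.ops and MindSpore tensors follow the usual cumsum and masked_fill semantics:

    import mindspore
    from mindspore import Tensor
    from mindnlp.core import ops

    # Left-padded batch: 0 marks padding, 1 marks real tokens.
    attention_mask = Tensor([[0, 0, 1, 1, 1],
                             [1, 1, 1, 1, 1]], dtype=mindspore.int32)

    # Running count of real tokens minus one, then padding positions pinned to 1.
    position_ids = ops.cumsum(attention_mask.int(), -1) - 1
    position_ids = position_ids.masked_fill(attention_mask == 0, 1)
    # Expected values: [[1, 1, 0, 1, 2], [0, 1, 2, 3, 4]]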
