2323import mindspore
2424from mindspore import Tensor
2525from mindspore .train .serialization import _exec_save
26+ from mindnlp .core .serialization import safe_save_file
2627
2728from mindnlp .core import nn , ops
2829from mindnlp .core .nn import functional as F
4546 LNTuningModel ,
4647)
4748from .utils import (
48- # SAFETENSORS_WEIGHTS_NAME,
49+ SAFETENSORS_WEIGHTS_NAME ,
4950 TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING ,
5051 WEIGHTS_NAME ,
5152 PeftType ,
@@ -124,7 +125,7 @@ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
124125 # if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
125126 # self.base_model.config.pretraining_tp = 1
126127
127- def save_pretrained (self , save_directory , ** kwargs ):
128+ def save_pretrained (self , save_directory , safe_serialization = False , ** kwargs ):
128129 r"""
129130 This function saves the adapter model and the adapter configuration files to a directory, so that it can be
130131 reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`]
@@ -144,10 +145,17 @@ def save_pretrained(self, save_directory, **kwargs):
144145 output_dir = os .path .join (save_directory , adapter_name ) if adapter_name != "default" else save_directory
145146 os .makedirs (output_dir , exist_ok = True )
146147
147- _exec_save (
148- ckpt_file_name = os .path .join (output_dir , WEIGHTS_NAME ),
149- data_list = output_state_dict ,
150- )
148+ if safe_serialization :
149+ safe_output_state_dict = {k : Tensor (v [2 ]).reshape (v [0 ]) for k , v in output_state_dict .items ()}
150+ safe_save_file (
151+ safe_output_state_dict ,
152+ os .path .join (output_dir , SAFETENSORS_WEIGHTS_NAME ),
153+ )
154+ else :
155+ _exec_save (
156+ ckpt_file_name = os .path .join (output_dir , WEIGHTS_NAME ),
157+ data_list = output_state_dict ,
158+ )
151159
152160 # save the config and change the inference mode to `True`
153161 if peft_config .base_model_name_or_path is None :
0 commit comments