@@ -23,6 +23,7 @@
 import mindspore
 from mindspore import Tensor
 from mindspore.train.serialization import _exec_save
+from mindnlp.core.serialization import safe_save_file

 from mindnlp.core import nn, ops
 from mindnlp.core.nn import functional as F
@@ -45,7 +46,7 @@
     LNTuningModel,
 )
 from .utils import (
-    # SAFETENSORS_WEIGHTS_NAME,
+    SAFETENSORS_WEIGHTS_NAME,
     TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING,
     WEIGHTS_NAME,
     PeftType,
@@ -124,7 +125,7 @@ def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
         # if hasattr(self.base_model, "config") and hasattr(self.base_model.config, "pretraining_tp"):
         #     self.base_model.config.pretraining_tp = 1

-    def save_pretrained(self, save_directory, **kwargs):
+    def save_pretrained(self, save_directory, safe_serialization=False, **kwargs):
         r"""
         This function saves the adapter model and the adapter configuration files to a directory, so that it can be
         reloaded using the [`LoraModel.from_pretrained`] class method, and also used by the [`LoraModel.push_to_hub`]
@@ -144,10 +145,17 @@ def save_pretrained(self, save_directory, **kwargs):
         output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory
         os.makedirs(output_dir, exist_ok=True)

-        _exec_save(
-            ckpt_file_name=os.path.join(output_dir, WEIGHTS_NAME),
-            data_list=output_state_dict,
-        )
+        if safe_serialization:
+            safe_output_state_dict = {k: Tensor(v[2]).reshape(v[0]) for k, v in output_state_dict.items()}
+            safe_save_file(
+                safe_output_state_dict,
+                os.path.join(output_dir, SAFETENSORS_WEIGHTS_NAME),
+            )
+        else:
+            _exec_save(
+                ckpt_file_name=os.path.join(output_dir, WEIGHTS_NAME),
+                data_list=output_state_dict,
+            )

         # save the config and change the inference mode to `True`
         if peft_config.base_model_name_or_path is None:
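For reference, here is a minimal caller-side sketch of the new flag. It is not part of the commit: the base model, config values, and import paths (`mindnlp.transformers`, `mindnlp.peft`, `get_peft_model`, `LoraConfig`) are assumptions based on mindnlp mirroring the Hugging Face PEFT API; only `save_pretrained(..., safe_serialization=True)` comes from this change.

from mindnlp.transformers import AutoModelForCausalLM  # assumed import path
from mindnlp.peft import LoraConfig, get_peft_model    # assumed import path

# Wrap a base model with a LoRA adapter (illustrative model and config values).
base_model = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base_model, LoraConfig(r=8, lora_alpha=16))

# With safe_serialization=True, the adapter weights are written to the
# safetensors file (SAFETENSORS_WEIGHTS_NAME) via safe_save_file instead of
# the MindSpore checkpoint (WEIGHTS_NAME) produced by _exec_save.
peft_model.save_pretrained("./lora-adapter", safe_serialization=True)

# The default (safe_serialization=False) keeps the previous .ckpt behavior.
peft_model.save_pretrained("./lora-adapter-ckpt")

On the conversion step inside the diff: `output_state_dict` as consumed by `_exec_save` appears to hold per-parameter triples, with the shape at `v[0]` and the raw data at `v[2]`, which is why the safetensors branch first rebuilds each entry as `Tensor(v[2]).reshape(v[0])` before handing the dict to `safe_save_file`.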