@@ -3,7 +3,6 @@
 
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
-import glob
 import json
 import os
 import re
@@ -42,12 +41,7 @@ def convert_hf_checkpoint(
     print(f"Model config {config.__dict__}")
 
     # Load the json file containing weight mapping
-    model_map_json_matches = [Path(m) for m in glob.glob(str(model_dir / "*.index.json"))]
-    assert len(model_map_json_matches) <= 1, "Found multiple weight mapping files"
-    if len(model_map_json_matches):
-        model_map_json = model_map_json_matches[0]
-    else:
-        model_map_json = model_dir / "pytorch_model.bin.index.json"
+    model_map_json = model_dir / "pytorch_model.bin.index.json"
 
     # If there is no weight mapping, check for a consolidated model and
     # tokenizer we can move. Llama 2 and Mistral have weight mappings, while
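Note: the simplified path assumes the standard Hugging Face shard index, a JSON file whose "weight_map" maps each parameter name to the shard that stores it (the converter reads it as bin_index["weight_map"] further down). A minimal sketch of that structure in Python, with made-up shard names and a hypothetical model_dir:

    from pathlib import Path

    # Illustrative shard index; real indices are produced by Hugging Face's
    # save_pretrained and list every parameter in the model.
    index = {
        "metadata": {"total_size": 13476839424},
        "weight_map": {
            "model.embed_tokens.weight": "pytorch_model-00001-of-00002.bin",
            "model.norm.weight": "pytorch_model-00002-of-00002.bin",
        },
    }
    model_dir = Path("checkpoints/meta-llama/Llama-2-7b-hf")  # hypothetical path
    bin_files = {model_dir / b for b in index["weight_map"].values()}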
@@ -62,9 +56,10 @@ def convert_hf_checkpoint(
                 str(consolidated_pth), map_location="cpu", mmap=True, weights_only=True
             )
             del loaded_result  # No longer needed
-            print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
-            os.rename(consolidated_pth, model_dir / "model.pth")
-            os.rename(tokenizer_pth, model_dir / "tokenizer.model")
+            print(f"Symlinking checkpoint to {model_dir / 'model.pth'}.")
+            consolidated_pth = os.path.realpath(consolidated_pth)
+            os.symlink(consolidated_pth, model_dir / "model.pth")
+            os.symlink(tokenizer_pth, model_dir / "tokenizer.model")
             print("Done.")
             return
         else:
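Note: symlinking instead of renaming leaves the Hugging Face snapshot intact, and the os.path.realpath call resolves the snapshot's own link (HF caches usually store files as symlinks into a blob store), so model.pth points at the real blob rather than at another link. One caveat: os.symlink raises FileExistsError when the destination already exists, so a re-run needs the old link removed first. A minimal sketch of an idempotent variant, where relink is a hypothetical helper, not part of this file:

    import os
    from pathlib import Path

    def relink(src: Path, dst: Path) -> None:
        # Hypothetical helper: resolve the real target, drop any stale link,
        # then create the symlink.
        target = os.path.realpath(src)
        if os.path.lexists(dst):
            os.remove(dst)
        os.symlink(target, dst)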
@@ -81,17 +76,10 @@ def convert_hf_checkpoint(
         "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight",
         "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight",
         "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight",
-        "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias",
-        "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias",
-        "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias",
-        "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias",
         "model.layers.{}.self_attn.rotary_emb.inv_freq": None,
         "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight",
         "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight",
         "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight",
-        "model.layers.{}.mlp.gate_proj.bias": "layers.{}.feed_forward.w1.bias",
-        "model.layers.{}.mlp.up_proj.bias": "layers.{}.feed_forward.w3.bias",
-        "model.layers.{}.mlp.down_proj.bias": "layers.{}.feed_forward.w2.bias",
         "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight",
         "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight",
         "model.norm.weight": "norm.weight",
@@ -100,43 +88,19 @@ def convert_hf_checkpoint(
     bin_files = {model_dir / bin for bin in bin_index["weight_map"].values()}
 
     def permute(w, n_heads):
+        dim = config.dim
         return (
-            w.view(n_heads, 2, config.head_dim // 2, *w.shape[1:])
+            w.view(n_heads, 2, config.head_dim // 2, dim)
             .transpose(1, 2)
-            .reshape(w.shape)
+            .reshape(config.head_dim * n_heads, dim)
         )
 
     merged_result = {}
     for file in sorted(bin_files):
-
-        # The state_dict can be loaded from either a torch zip file or
-        # safetensors. We take our best guess from the name and try all
-        # possibilities
-        load_pt_mmap = lambda: torch.load(
+        state_dict = torch.load(
             str(file), map_location="cpu", mmap=True, weights_only=True
         )
-        load_pt_no_mmap = lambda: torch.load(
-            str(file), map_location="cpu", mmap=False, weights_only=True
-        )
-        def load_safetensors():
-            import safetensors.torch
-            with open(file, "rb") as handle:
-                return safetensors.torch.load(handle.read())
-        if "safetensors" in str(file):
-            loaders = [load_safetensors, load_pt_mmap, load_pt_no_mmap]
-        else:
-            loaders = [load_pt_mmap, load_pt_no_mmap, load_safetensors]
-
-        state_dict = None
-        for loader in loaders:
-            try:
-                state_dict = loader()
-                break
-            except Exception:
-                continue
-        assert state_dict is not None, f"Unable to load tensors from {file}"
         merged_result.update(state_dict)
-
     final_result = {}
     for key, value in merged_result.items():
         if "layers" in key:
@@ -152,18 +116,16 @@ def load_safetensors():
         final_result[new_key] = value
 
     for key in tuple(final_result.keys()):
-        if "wq.weight" in key or "wq.bias" in key:
-            wk_key = key.replace("wq", "wk")
-            wv_key = key.replace("wq", "wv")
+        if "wq" in key:
             q = final_result[key]
-            k = final_result[wk_key]
-            v = final_result[wv_key]
+            k = final_result[key.replace("wq", "wk")]
+            v = final_result[key.replace("wq", "wv")]
             q = permute(q, config.n_heads)
             k = permute(k, config.n_local_heads)
             final_result[key.replace("wq", "wqkv")] = torch.cat([q, k, v])
             del final_result[key]
-            del final_result[wk_key]
-            del final_result[wv_key]
+            del final_result[key.replace("wq", "wk")]
+            del final_result[key.replace("wq", "wv")]
     print(f"Saving checkpoint to {model_dir / 'model.pth'}. This may take a while.")
     torch.save(final_result, model_dir / "model.pth")
     print("Done.")
@@ -184,10 +146,10 @@ def convert_hf_checkpoint_to_tune(
     consolidated_pth = model_dir / "original" / "consolidated.pth"
     tokenizer_pth = model_dir / "original" / "tokenizer.model"
     if consolidated_pth.is_file() and tokenizer_pth.is_file():
-        print(f"Moving checkpoint to {model_dir / 'model.pth'}.")
-        os.rename(consolidated_pth, model_dir / "model.pth")
-        print(f"Moving tokenizer to {model_dir / 'tokenizer.model'}.")
-        os.rename(tokenizer_pth, model_dir / "tokenizer.model")
+        print(f"Creating symlink from {consolidated_pth} to {model_dir / 'model.pth'}.")
+        os.symlink(consolidated_pth, model_dir / "model.pth")
+        print(f"Creating symlink from {tokenizer_pth} to {model_dir / 'tokenizer.model'}.")
+        os.symlink(tokenizer_pth, model_dir / "tokenizer.model")
         print("Done.")
     else:
         raise RuntimeError(f"Could not find {consolidated_pth}")