-
Notifications
You must be signed in to change notification settings - Fork 815
/
Copy pathtorch-checkpoint-convert-to-bf16
executable file
·34 lines (27 loc) · 1.32 KB
/
torch-checkpoint-convert-to-bf16
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
#!/bin/bash
# this script converts torch's *bin weight files to bf16, creating a new checkpoint under a sub-dir bf16.
# if the checkpoint also contains *safetensors files, bf16 *safetensors files are generated
# from the freshly converted *bin files.
#
# usage:
# cd checkpoint
# bash torch-checkpoint-convert-to-bf16
#
# requires: python with torch (plus safetensors for the *safetensors step) and perl

# abort on the first failing step instead of reporting success over a half-written checkpoint
set -euo pipefail
# unmatched globs expand to nothing rather than being passed on as literal patterns
shopt -s nullglob

# set destination dir
target_dir=bf16

# python: argv[1] is the destination dir, the remaining args are the *bin files to convert.
# only floating-point tensors are cast to bf16 - integer/bool tensors (e.g. index buffers) and
# non-tensor entries are kept as is, since casting them would corrupt their values.
# map_location="cpu" lets the conversion run on a host without the device the weights were saved from.
convert_bin_py='
import sys
import torch

target_dir, files = sys.argv[1], sys.argv[2:]
for f in files:
    state = torch.load(f, map_location="cpu")
    state = {
        k: v.to(torch.bfloat16) if torch.is_tensor(v) and v.is_floating_point() else v
        for k, v in state.items()
    }
    torch.save(state, f"{target_dir}/{f}")
'

# python: every arg is a *bin file in the current dir, re-saved as safetensors.
# the name is derived by dropping any prefix before "model" and swapping the extension,
# e.g. pytorch_model-00001-of-00002.bin -> model-00001-of-00002.safetensors
convert_safetensors_py='
import os
import re
import sys
import torch
from safetensors.torch import save_file

for f in sys.argv[1:]:
    out = re.sub(r".*?(model.*?)\.bin", r"\1.safetensors", f)
    if out == f:
        # the pattern did not match - never write safetensors data over the source file
        out = os.path.splitext(f)[0] + ".safetensors"
    save_file(torch.load(f, map_location="cpu"), out, metadata={"format": "pt"})
'

echo "creating a new checkpoint under dir $target_dir"
mkdir -p -- "$target_dir"

# cp config and other files - adapt if needed - could also do `cp * $target_dir`
aux_files=(*json *model)
if (( ${#aux_files[@]} > 0 )); then
  cp -- "${aux_files[@]}" "$target_dir"
fi

# convert *bin
bin_files=(*bin)
if (( ${#bin_files[@]} == 0 )); then
  echo "error: no *bin files found in $PWD - nothing to convert" >&2
  exit 1
fi
echo "converting *bin torch files"
python -c "$convert_bin_py" "$target_dir" "${bin_files[@]}"

# convert *safetensors (from original *bin files)
if compgen -G "*.safetensors" > /dev/null; then
  echo "converting *safetensors files"
  # run inside a subshell so the caller's working dir is never changed
  (
    cd -- "$target_dir"
    python -c "$convert_safetensors_py" "${bin_files[@]}"
    if [[ -e "pytorch_model.bin.index.json" ]]; then
      cp pytorch_model.bin.index.json model.safetensors.index.json
      # /g so that every shard name is rewritten even when the json is not one-entry-per-line
      perl -pi -e 's|pytorch_||g; s|\.bin|.safetensors|g' model.safetensors.index.json
    fi
  )
fi

echo "the dir $target_dir now contains a copy of the original checkpoint with bf16 weights"