-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path99_export.py
75 lines (65 loc) · 2.35 KB
/
99_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import shutil
import argparse
import torch
from rich import print
from transformers import AutoConfig
from transformers import PreTrainedModel
from transformers import AutoModelForTokenClassification
# Load the model
def load_model(target: str, output_path: str) -> PreTrainedModel:
    """Load the token-classification model stored at ``target``.

    The weight precision is derived from ``output_path``. A precision tag
    (``bf16``, ``fp16`` or ``fp32``) that terminates the path takes priority,
    so a tag that merely appears earlier in the path (for example in the
    source model's directory name) cannot override the requested precision.
    If the path does not end in a known tag, the tag is searched for anywhere
    in the path, and ``torch.float32`` is used when none is found.

    Args:
        target: Local directory of the checkpoint to load. Only local files
            are used (``local_files_only = True``).
        output_path: Destination path whose precision tag selects the dtype.

    Returns:
        The model loaded with the selected ``torch_dtype``.
    """
    config = AutoConfig.from_pretrained(
        target,
        local_files_only = True,
        trust_remote_code = True,
    )
    # NOTE(review): presumably switches off a compile path read by the model
    # implementation (e.g. ModernBERT's `reference_compile`) — confirm.
    config.reference_compile = None

    # Order matters for the substring fallback: bf16 is checked before fp16.
    dtype_by_tag = {
        "bf16": torch.bfloat16,
        "fp16": torch.float16,
        "fp32": torch.float32,
    }

    # A tag at the very end of the path (ignoring trailing separators) wins.
    separators = os.sep + (os.altsep or "")
    trimmed = output_path.rstrip(separators)
    torch_dtype = next(
        (dt for tag, dt in dtype_by_tag.items() if trimmed.endswith(tag)),
        None,
    )

    if torch_dtype is None:
        # Fallback for paths that carry the tag somewhere other than the end.
        torch_dtype = next(
            (dt for tag, dt in dtype_by_tag.items() if tag in output_path),
            torch.float32,
        )

    return AutoModelForTokenClassification.from_pretrained(
        target,
        config = config,
        torch_dtype = torch_dtype,
        attn_implementation = "sdpa",
        local_files_only = True,
        trust_remote_code = True,
        ignore_mismatched_sizes = True,
    )
# Export the model
def export(input_path: str, dtype: str) -> None:
    """Copy a checkpoint directory and re-save its weights in another precision.

    The export is written to ``{input_path}_{dtype}`` (trailing path
    separators on ``input_path`` are ignored when building that name). Any
    existing directory at that location is removed first. All files are
    copied over, the copied weight files are deleted, and the model is then
    loaded via ``load_model`` and saved into the output directory.

    Args:
        input_path: Directory of the source checkpoint.
        dtype: Precision tag appended to the output directory name;
            ``load_model`` derives the torch dtype from the resulting path.
    """
    # Strip trailing separators so "model/" exports to "model_bf16" instead of
    # "model/_bf16", which would land inside the source directory itself.
    separators = os.sep + (os.altsep or "")
    source_dir = input_path.rstrip(separators) or input_path
    output_path = f"{source_dir}_{dtype}"

    print("")
    print(f"正在导出 [green]{output_path}[/] ...")

    # Start from a clean copy of the source directory.
    shutil.rmtree(output_path, ignore_errors = True)
    shutil.copytree(input_path, output_path, dirs_exist_ok = True)

    # Drop the copied weight files; save_pretrained writes fresh ones below.
    for weights_name in ("model.safetensors", "pytorch_model.bin"):
        try:
            os.remove(os.path.join(output_path, weights_name))
        except FileNotFoundError:
            pass

    load_model(input_path, output_path).save_pretrained(output_path)
# Main function
def main(target: str) -> None:
    """Export the checkpoint at ``target`` with the ``bf16`` precision tag."""
    export(input_path = target, dtype = "bf16")
# Script entry point
if __name__ == "__main__":
    # Read the single positional argument and hand it straight to main().
    cli = argparse.ArgumentParser()
    cli.add_argument("target", help = "目标路径", type = str)
    main(cli.parse_args().target)