# utils.py
from typing import Dict, List, Union
from pathlib import Path
import json
import re
import pandas as pd
import torch
from pandarallel import pandarallel
from transformers import get_scheduler

# pandarallel provides the `parallel_map` used below; it must be initialized
# once before use (initializing here at import time is one reasonable place).
pandarallel.initialize()


# Helper functions
def load_json(file_path: Union[Path, str]) -> List[dict]:
    """load_json reads a jsonl file and returns its rows as a list of dicts.

    Args:
        file_path (Union[Path, str]): The jsonl file path.

    Returns:
        List[dict]: The jsonl file content, one dict per line.

    Example:
        >>> pd.DataFrame(load_json("data/train.jsonl"))
               id     label  ... predicted_label                                      evidence_list
        0    3984   refutes  ...         REFUTES  [城市規劃是城市建設及管理的依據 , 位於城市管理之規劃 、 建設 、 運作三個階段之首 ,...
        ..    ...       ...  ...             ...                                                ...
        945  3042  supports  ...         REFUTES  [北歐人相傳每當雷雨交加時就是索爾乘坐馬車出來巡視 , 因此稱呼索爾為 “ 雷神 ” 。, ...

        [946 rows x 10 columns]
    """
    with open(file_path, "r", encoding="utf8") as json_file:
        json_list = list(json_file)

    return [json.loads(json_str) for json_str in json_list]
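

# A minimal usage sketch ("data/train.jsonl" is the example path from the
# docstring above):
#
#     records = load_json("data/train.jsonl")
#     train_df = pd.DataFrame(records)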


def jsonl_dir_to_df(dir_path: Union[Path, str]) -> pd.DataFrame:
    """jsonl_dir_to_df reads a directory of jsonl files and returns a pandas DataFrame.

    This function reads all jsonl files in dir_path and concatenates them.

    Args:
        dir_path (Union[Path, str]): The jsonl dir path.

    Returns:
        pd.DataFrame: The concatenated content of the jsonl files.

    Example:
        >>> jsonl_dir_to_df("data/extracted_dir/")
               id     label  ... predicted_label                                      evidence_list
        0    3984   refutes  ...         REFUTES  [城市規劃是城市建設及管理的依據 , 位於城市管理之規劃 、 建設 、 運作三個階段之首 ,...
        ..    ...       ...  ...             ...                                                ...
        945  3042  supports  ...         REFUTES  [北歐人相傳每當雷雨交加時就是索爾乘坐馬車出來巡視 , 因此稱呼索爾為 “ 雷神 ” 。, ...

        [946 rows x 10 columns]
    """
    print(f"Reading and concatenating jsonl files in {dir_path}")

    return pd.concat(
        [pd.DataFrame(load_json(file)) for file in Path(dir_path).glob("*.jsonl")]
    )


def generate_evidence_to_wiki_pages_mapping(
    wiki_pages: pd.DataFrame,
) -> Dict[str, Dict[str, str]]:
    """generate_evidence_to_wiki_pages_mapping builds a mapping from page id to the page's numbered lines.

    Args:
        wiki_pages (pd.DataFrame): The wiki pages dataframe, with an "id"
            column (page id) and a "lines" column (the page's numbered
            sentences, one per line).

    Returns:
        Dict[str, Dict[str, str]]: A dict keyed by page id; each value maps
            a line index (as a string) to the sentence at that index.
    """
    def make_dict(x):
        result = {}
        sentences = re.split(r"\n(?=[0-9])", x)
        for sent in sentences:
            splitted = sent.split("\t")
            if len(splitted) < 2:
                # Skip empty articles and malformed lines instead of
                # aborting the whole page.
                continue
            result[splitted[0]] = splitted[1]
        return result
    # Work on a copy so the caller's dataframe is not mutated.
    wiki_pages = wiki_pages.copy()

    # Parse each page's "lines" field into a {line index: sentence} dict.
    print("Generating parse mapping")
    wiki_pages["evidence_map"] = wiki_pages["lines"].parallel_map(make_dict)

    # Zip page ids with their parsed line dicts.
    print("Transforming to id-to-evidence_map mapping")
    mapping = dict(
        zip(
            wiki_pages["id"].to_list(),
            wiki_pages["evidence_map"].to_list(),
        )
    )

    # Release the intermediate dataframe's memory.
    del wiki_pages

    return mapping
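

# A minimal usage sketch, assuming the wiki pages live in a jsonl directory
# with "id" and "lines" columns; the directory path and page id below are
# hypothetical:
#
#     wiki_pages = jsonl_dir_to_df("data/wiki-pages/")
#     mapping = generate_evidence_to_wiki_pages_mapping(wiki_pages)
#     sentence = mapping["some_page_id"]["0"]  # line indices are string keys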


def set_lr_scheduler(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    warmup_ratio: float = 0.1,
):
    """Create a linear scheduler that warms up for warmup_ratio * num_training_steps steps."""
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=int(num_training_steps * warmup_ratio),
        num_training_steps=num_training_steps,
    )

    return lr_scheduler
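

# A minimal usage sketch with a hypothetical model, dataloader, and step count:
#
#     optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
#     scheduler = set_lr_scheduler(optimizer, num_training_steps=1000)
#     for batch in dataloader:
#         ...  # forward / backward / optimizer.step()
#         scheduler.step()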


def save_checkpoint(model, ckpt_dir: str, current_step: int, mark: str = ""):
    """Save the model weights under ckpt_dir, tagged with the step and an optional mark."""
    if mark != "":
        mark += "_"

    torch.save(model.state_dict(), f"{ckpt_dir}/{mark}model.{current_step}.pt")


def load_model(model, ckpt_name: str, ckpt_dir: str):
    """Load the weights stored at {ckpt_dir}/{ckpt_name} into model and return it."""
    model.load_state_dict(torch.load(f"{ckpt_dir}/{ckpt_name}"))

    return model
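

# A minimal round-trip sketch; the directory, step, and mark values are
# hypothetical:
#
#     save_checkpoint(model, "checkpoints", current_step=500, mark="best")
#     model = load_model(model, "best_model.500.pt", "checkpoints")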