-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_storage.py
52 lines (33 loc) · 1.42 KB
/
data_storage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import json
# Directory (relative to the current working directory) that holds the
# dataset and cache files named by the constants in this module.
DATA_DIR = "data"

# JSON files read and written by the load_or_create_* / save_* helpers.
PARALLEL_DATA_DICT_FILE = os.path.join(DATA_DIR, "parallel_dataset.json")
MULT_GENERATIONS_FILENAME = os.path.join(
    DATA_DIR, "multiple_generations_all_keys.json"
)
STATS_AND_DATASET_PATH = os.path.join(DATA_DIR, "stats_dataset.json")
def load_or_create_dict(filepath: str):
    """Load a JSON file, or start with an empty dict when it is missing.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        The decoded JSON content of ``filepath`` (whatever ``json.load``
        produces; it is not validated to be a dict), or a new empty dict
        when the file does not exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    # EAFP: open directly rather than checking os.path.exists() first, so a
    # file that disappears between the check and the open cannot crash us.
    try:
        fp = open(filepath, "r", encoding="utf-8")
    except (FileNotFoundError, NotADirectoryError):
        # Nothing cached yet: hand back a fresh dict on every call.
        return {}

    # JSON is UTF-8 by specification; do not depend on the platform's
    # default text encoding when decoding it.
    with fp:
        print(f"Loading from cached file: {filepath}")
        return json.load(fp)
def save_dict(data_dict, filepath: str):
    """Serialize ``data_dict`` to JSON and write it to ``filepath``.

    Args:
        data_dict: JSON-serializable object to store.
        filepath: Destination path. Its parent directory is created when it
            does not exist yet; an existing file is overwritten.

    Raises:
        TypeError: If ``data_dict`` contains a value ``json`` cannot encode.
            In that case an existing file at ``filepath`` is left untouched.
    """
    # Serialize BEFORE opening the file: open(..., "w") truncates at once,
    # so an encoding failure would otherwise wipe the previously saved data.
    payload = json.dumps(data_dict)

    # Make sure the destination directory exists so the first save into a
    # fresh checkout does not fail with FileNotFoundError.
    parent_dir = os.path.dirname(filepath)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(filepath, "w", encoding="utf-8") as fp:
        fp.write(payload)
    print(f"saved to {filepath}")
def get_original_data_filename():
    """Return the path to ``truthfulQA-alldata.csv`` inside ``DATA_DIR``."""
    csv_name = "truthfulQA-alldata.csv"
    return os.path.join(DATA_DIR, csv_name)
def load_or_create_parallel_data_dict():
    """Return what ``load_or_create_dict`` yields for ``PARALLEL_DATA_DICT_FILE``."""
    parallel_data = load_or_create_dict(PARALLEL_DATA_DICT_FILE)
    return parallel_data
def save_parallel_data_dict(data_dict):
    """Write ``data_dict`` to ``PARALLEL_DATA_DICT_FILE`` through ``save_dict``."""
    save_dict(data_dict, filepath=PARALLEL_DATA_DICT_FILE)
def load_or_create_multi_generations():
    """Return what ``load_or_create_dict`` yields for ``MULT_GENERATIONS_FILENAME``."""
    generations = load_or_create_dict(MULT_GENERATIONS_FILENAME)
    return generations
def save_multi_generations(generations_dict):
    """Write ``generations_dict`` to ``MULT_GENERATIONS_FILENAME`` through ``save_dict``."""
    save_dict(generations_dict, filepath=MULT_GENERATIONS_FILENAME)
def load_or_create_stats():
    """Return what ``load_or_create_dict`` yields for ``STATS_AND_DATASET_PATH``."""
    stats_and_dataset = load_or_create_dict(STATS_AND_DATASET_PATH)
    return stats_and_dataset
def save_stats(stats_and_dataset_dict):
    """Write ``stats_and_dataset_dict`` to ``STATS_AND_DATASET_PATH`` through ``save_dict``."""
    save_dict(stats_and_dataset_dict, filepath=STATS_AND_DATASET_PATH)
# Path to the JSON file with 10 paraphrases of each original example.
# NOTE: despite the name, this constant is a file path, not a loaded dict.
PARAPHRASE_DICT = os.path.join(DATA_DIR, "paraphrases.json")