forked from lonePatient/albert_pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlcqmc_progressor.py
185 lines (162 loc) · 7.81 KB
/
lcqmc_progressor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import torch
import csv
from callback.progressbar import ProgressBar
from model.tokenization_bert import BertTokenizer
from common.tools import logger
from torch.utils.data import TensorDataset
class InputExample(object):
    """One raw, untokenized example for sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Store the fields of a single example.

        Args:
            guid: unique id for the example.
            text_a: untokenized text of the first sequence. For single-sequence
                tasks this is the only text that has to be given.
            text_b: optional untokenized text of the second sequence; only
                needed for sequence-pair tasks.
            label: optional label of the example. It should be given for train
                and dev examples, but not for test examples.
        """
        # Tuple assignment runs left to right, so attributes are created in
        # the same order as before (guid, text_a, text_b, label).
        self.guid, self.text_a, self.text_b, self.label = (
            guid, text_a, text_b, label)
class InputFeature(object):
    """Model-ready features for a single example.

    Holds the token ids, attention mask, segment (token type) ids, label id
    and the unpadded sequence length of one example.
    """

    def __init__(self, input_ids, input_mask, segment_ids, label_id, input_len):
        # Attributes are created in the original order: input_ids,
        # input_mask, segment_ids, label_id, input_len.
        (self.input_ids,
         self.input_mask,
         self.segment_ids,
         self.label_id,
         self.input_len) = (input_ids, input_mask, segment_ids, label_id, input_len)
class BertProcessor(object):
    """Base class for data converters for sequence classification data sets.

    Pipeline: ``read_data`` (raw TSV rows) -> ``create_examples``
    (``InputExample`` objects) -> ``create_features`` (fixed-length
    ``InputFeature`` objects) -> ``create_dataset`` (``TensorDataset``).
    Examples and features are cached on disk with ``torch.save``.
    """

    def __init__(self, vocab_path, do_lower_case):
        """Build the ``BertTokenizer`` used by ``create_features``."""
        self.tokenizer = BertTokenizer(vocab_path, do_lower_case)

    def get_train(self, data_file):
        """Return the raw rows of the training TSV file (see ``read_data``)."""
        return self.read_data(data_file)

    def get_dev(self, data_file):
        """Return the raw rows of the dev TSV file (see ``read_data``)."""
        return self.read_data(data_file)

    def get_test(self, lines):
        """Return ``lines`` unchanged; test rows are supplied by the caller."""
        return lines

    def get_labels(self):
        """Return the list of label strings for this data set."""
        return ["0", "1"]

    @classmethod
    def read_data(cls, input_file, quotechar=None):
        """Read a tab-separated file and return its rows.

        Args:
            input_file: path of the TSV file. It is decoded as UTF-8 with the
                ``utf-8-sig`` codec, so a leading BOM is dropped.
            quotechar: quote character passed to ``csv.reader``. With the
                default ``None`` the csv module disables quote handling, so
                quote characters are kept as ordinary text.

        Returns:
            list of rows, each a list of column strings.
        """
        # newline="" is the file mode the csv module documents for readers.
        with open(input_file, "r", encoding="utf-8-sig", newline="") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            return list(reader)

    @staticmethod
    def _load_cached(path):
        """Load a cache file previously written with ``torch.save``.

        The caches hold pickled Python objects (lists of InputExample /
        InputFeature). Recent PyTorch releases default ``torch.load`` to
        ``weights_only=True``, which refuses such objects. Full unpickling is
        therefore requested explicitly whenever ``torch.load`` supports the
        keyword; older releases without it unpickle fully by default.

        SECURITY: full unpickling can execute arbitrary code, so only load
        cache files produced by this code on a trusted machine.
        """
        import inspect

        if "weights_only" in inspect.signature(torch.load).parameters:
            return torch.load(path, weights_only=False)
        return torch.load(path)

    def truncate_seq_pair(self, tokens_a, tokens_b, max_length):
        """Truncate a token pair in place until its total length <= ``max_length``.

        This is a simple heuristic that always drops the last token of the
        currently longer sequence (of ``tokens_b`` on ties). That makes more
        sense than removing an equal percentage from each: if one sequence is
        very short, each of its tokens likely carries more information than a
        token of the longer one.
        """
        while len(tokens_a) + len(tokens_b) > max_length:
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

    def create_examples(self, lines, example_type, cached_examples_file):
        """Convert raw TSV rows into InputExample objects, with caching.

        Args:
            lines: rows as returned by ``read_data``. Columns 0, 1 and 2 are
                used as ``text_a``, ``text_b`` and the label, which is
                converted with ``int``.
            example_type: prefix of each guid, formatted as
                ``"<example_type>-<row index>"``.
            cached_examples_file: path object exposing ``.exists()``
                (e.g. ``pathlib.Path``). If the file exists, its contents are
                returned and ``lines`` is ignored. Otherwise the newly built
                examples are saved there.

        Returns:
            list of InputExample.
        """
        if cached_examples_file.exists():
            logger.info("Loading examples from cached file %s", cached_examples_file)
            return self._load_cached(cached_examples_file)

        pbar = ProgressBar(n_total=len(lines), desc='create examples')
        examples = []
        for i, line in enumerate(lines):
            guid = '%s-%d' % (example_type, i)
            text_a = line[0]
            text_b = line[1]
            label = int(line[2])
            examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
            pbar(step=i)
        logger.info("Saving examples into cached file %s", cached_examples_file)
        torch.save(examples, cached_examples_file)
        return examples

    def create_features(self, examples, max_seq_len, cached_features_file):
        """Tokenize examples into fixed-length InputFeature objects, with caching.

        The convention in BERT is:
        (a) For sequence pairs:
            tokens:   [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
            type_ids: 0     0  0    0    0     0       0 0     1  1  1  1   1 1
        (b) For single sequences:
            tokens:   [CLS] the dog is hairy . [SEP]
            type_ids: 0     0   0   0  0     0 0

        Tokens are truncated so that, including the special tokens, they fit
        in ``max_seq_len``. ``input_ids``, ``input_mask`` and ``segment_ids``
        are then right-padded with 0 to exactly ``max_seq_len``.
        ``input_len`` is the length before padding.

        Args:
            examples: InputExample objects. Each ``label`` is used as-is as
                the label id.
            max_seq_len: fixed length of every output sequence.
            cached_features_file: path object exposing ``.exists()``. If the
                file exists, its contents are returned and ``examples`` is not
                processed. Otherwise the new features are saved there.

        Returns:
            list of InputFeature.
        """
        if cached_features_file.exists():
            logger.info("Loading features from cached file %s", cached_features_file)
            return self._load_cached(cached_features_file)

        pbar = ProgressBar(n_total=len(examples), desc='create features')
        features = []
        for ex_id, example in enumerate(examples):
            tokens_a = self.tokenizer.tokenize(example.text_a)
            tokens_b = None
            label_id = example.label
            if example.text_b:
                tokens_b = self.tokenizer.tokenize(example.text_b)
                # Modifies tokens_a and tokens_b in place; account for
                # [CLS], [SEP], [SEP] with "- 3".
                self.truncate_seq_pair(tokens_a, tokens_b, max_length=max_seq_len - 3)
            elif len(tokens_a) > max_seq_len - 2:
                # Account for [CLS] and [SEP] with "- 2".
                tokens_a = tokens_a[:max_seq_len - 2]

            tokens = ['[CLS]'] + tokens_a + ['[SEP]']
            segment_ids = [0] * len(tokens)
            if tokens_b:
                tokens += tokens_b + ['[SEP]']
                segment_ids += [1] * (len(tokens_b) + 1)

            input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            input_len = len(input_ids)
            input_mask = [1] * input_len
            padding = [0] * (max_seq_len - input_len)
            input_ids += padding
            input_mask += padding
            segment_ids += padding
            assert len(input_ids) == max_seq_len
            assert len(input_mask) == max_seq_len
            assert len(segment_ids) == max_seq_len

            if ex_id < 2:
                logger.info("*** Example ***")
                # Plain f-string: the former trailing "% ()" re-formatted the
                # already-built text and raised if the guid contained "%".
                logger.info(f"guid: {example.guid}")
                logger.info(f"tokens: {' '.join([str(x) for x in tokens])}")
                logger.info(f"input_ids: {' '.join([str(x) for x in input_ids])}")
                logger.info(f"input_mask: {' '.join([str(x) for x in input_mask])}")
                logger.info(f"segment_ids: {' '.join([str(x) for x in segment_ids])}")
                logger.info(f"label id : {label_id}")

            features.append(InputFeature(input_ids=input_ids,
                                         input_mask=input_mask,
                                         segment_ids=segment_ids,
                                         label_id=label_id,
                                         input_len=input_len))
            pbar(step=ex_id)
        logger.info("Saving features into cached file %s", cached_features_file)
        torch.save(features, cached_features_file)
        return features

    def create_dataset(self, features):
        """Stack features into a ``TensorDataset`` of ``torch.long`` tensors.

        Each dataset item is ``(input_ids, input_mask, segment_ids, label_id)``.
        ``input_len`` is not included.
        """
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
        return TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)