
Commit 8d72d2f

Commit message: update
1 parent 39f869f · commit 8d72d2f

33 files changed, +41992 −13780 lines

Diff for: .gitignore (+1)

@@ -7,3 +7,4 @@ model_vars
 play_ground.py
 *.pkl
 tri_*
+test/*.txt

Diff for: README.md (+25 −1)

@@ -70,12 +70,36 @@ All token categories: `['PLAIN', 'PUNCT', 'DATE', 'LETTERS', 'CARDINAL', 'VERBA

## Competition log

-###2017-11-15
+### 2017-11-15

Day one of the competition; the new test data was released. Running the previous model as-is scored 0.9933 and jumped straight to third place. No idea how long that will hold.

Also ran a baseline with the big-dictionary method: 0.9893, kept as the reference point for later improvements.

### 2017-11-17

Error hunting goes in roughly two directions: 1. find the classes the classifier gets wrong; 2. find the cases the normalizer converts incorrectly.

Misclassified tokens can be located through the classifier's output confidence, for example by training xgboost with `'objective': 'multi:softprob'` and then inspecting the samples whose predicted probability is especially low.
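A minimal sketch of that inspection on toy data; in the competition this would be the token feature matrix and the encoded class labels, and none of the names below come from the repo:

```python
import numpy as np
import xgboost as xgb

# toy stand-in data; in the competition this would be the token feature
# matrix and the class labels ('PLAIN', 'DATE', ... encoded as integers)
rng = np.random.RandomState(0)
X = rng.rand(500, 10)
y = rng.randint(0, 4, 500)

params = {'objective': 'multi:softprob', 'num_class': 4, 'eta': 0.3, 'max_depth': 4}
bst = xgb.train(params, xgb.DMatrix(X, label=y), num_boost_round=20)

prob = bst.predict(xgb.DMatrix(X))        # shape (n_samples, num_class)
pred = prob.argmax(axis=1)
conf = prob.max(axis=1)

# rows with the smallest top probability are the first place to look for errors
for i in np.argsort(conf)[:20]:
    print(i, 'true:', y[i], 'pred:', pred[i], 'prob: %.3f' % conf[i])
```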
Another way to search the final output is to diff it against the big-dictionary baseline. This is not fully reliable either: a difference is not necessarily wrong and a match is not necessarily right, but some useful patterns do show up in the disagreements.
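A sketch of that comparison, assuming both files are submissions in the `"id","after"` format; `output/baseline.csv` matches the baseline script in this commit, while the model path is made up:

```python
import pandas as pd

model = pd.read_csv('output/model_submission.csv')   # hypothetical model output
base = pd.read_csv('output/baseline.csv')            # big-dictionary baseline

merged = model.merge(base, on='id', suffixes=('_model', '_base'))
diff = merged[merged['after_model'] != merged['after_base']]

# neither side is ground truth, but the disagreements are worth eyeballing
print(len(diff), 'tokens differ out of', len(merged))
print(diff.head(20))
```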
You can also scan the final output directly for digits and special symbols; those are usually cases where normalization failed, and this check catches them.
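A sketch of that scan (the submission path and the exact symbol set are assumptions):

```python
import pandas as pd

sub = pd.read_csv('output/model_submission.csv')   # hypothetical submission

# "after" values that still contain digits or special symbols usually mean
# the normalizer failed and the token was copied through unchanged
mask = sub['after'].astype(str).str.contains(r'[0-9#%$€£°]')
print(mask.sum(), 'suspicious rows')
print(sub[mask].head(20))
```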
Currently at 0.9947.
### 2017-11-20

Today is the last day of the competition. Found a problem with an earlier xgboost parameter setting, `'nthread': -1`; deleted it outright, since the default already uses the maximum number of threads. Sure enough, after rerunning, all 8 CPU cores run at full load and training is much faster.
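For reference, the corrected parameter dict would look roughly like this; everything except the objective and the dropped `nthread` key is an illustrative guess, not taken from the repo:

```python
params = {
    'objective': 'multi:softprob',   # per the 2017-11-17 entry
    'num_class': 16,                 # illustrative: one id per token class
    'eta': 0.3,
    'max_depth': 10,
    # no 'nthread' key: xgboost then defaults to using all available cores
}
```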
Also corrected the earlier context-based feature method. Previously every token's features included the previous and the next token, so the last token of a sentence picked up the first token of the next sentence as context, and the first token picked up the last token of the previous sentence. That is clearly wrong, so an all-zero vector is now inserted between sentences; the zero vector itself is ignored and only serves as context for its neighbours.
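A minimal sketch of that fixed context construction, assuming each sentence is already an array of per-token feature vectors (all names here are illustrative, not from the repo):

```python
import numpy as np

def build_context_features(sentences, dim):
    """sentences: list of (n_tokens, dim) arrays of per-token features.
    Returns one row per real token with [prev, current, next] features;
    an all-zero vector stands in for positions beyond a sentence boundary,
    so context never leaks between sentences."""
    zero = np.zeros(dim)
    rows = []
    for sent in sentences:
        padded = [zero] + [np.asarray(t, dtype=float) for t in sent] + [zero]
        for i in range(1, len(padded) - 1):
            rows.append(np.concatenate([padded[i - 1], padded[i], padded[i + 1]]))
    return np.vstack(rows)

# toy usage: two sentences with 3-dimensional token features
s1 = np.arange(6).reshape(2, 3)
s2 = np.arange(9).reshape(3, 3)
X = build_context_features([s1, s2], dim=3)
print(X.shape)   # (5, 9): 5 tokens, prev + current + next features
```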
On the last day, re-checked the rule functions for every class and found decent room for improvement.

Still worth improving (one possible rule tweak is sketched after this list):
- `ORDINAL`: whether a leading "the" should be produced.
- `VERBATIM`: whether `#` should be read as "hash-tag" or as "number".
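As an illustration of the kind of decision left open for `ORDINAL`, a hypothetical rule sketch (this is not the repo's rule function, and the condition for adding "the" is exactly what remains unresolved):

```python
import re
from num2words import num2words

def normalize_ordinal(token, prev_word=''):
    """Hypothetical ORDINAL rule: '21st' -> 'twenty-first', optionally
    prefixed with 'the'. Whether and when to add 'the' is the open question."""
    m = re.match(r'^(\d+)(st|nd|rd|th)$', token, re.IGNORECASE)
    if not m:
        return token
    words = num2words(int(m.group(1)), to='ordinal')
    if prev_word.lower() != 'the':
        words = 'the ' + words   # assumption: add 'the' only if not already present
    return words

print(normalize_ordinal('21st'))           # the twenty-first
print(normalize_ordinal('3rd', 'the'))     # third
```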
## Other information

### Third-party packages used

Diff for: baseline_class.py (+171, new file)

@@ -0,0 +1,171 @@
# coding=utf-8
# @author: cer
# use python3
from __future__ import print_function
import os
import operator
from num2words import num2words  # this package does not support Chinese
import gc
import pandas as pd
import numpy as np
import time
import pickle as pkl

train_file_name = "input/en_train.csv"
test_file = 'input/en_test_2.csv'
baseline_file = 'output/baseline_class.csv'
pkl_name = "output/class_dict.pkl"
train_df = pd.read_csv(train_file_name)
test_df = pd.read_csv(test_file)

# translation tables for normalising sub/superscript and other digit glyphs
SUB = str.maketrans("₀₁₂₃₄₅₆₇₈₉", "0123456789")
SUP = str.maketrans("⁰¹²³⁴⁵⁶⁷⁸⁹", "0123456789")
OTH = str.maketrans("፬", "4")


def train():
    print('Train start...')
    if os.path.exists(pkl_name):
        with open(pkl_name, "rb") as f:
            res = pkl.load(f)
    else:
        # Work with primary dataset
        train_file = open(train_file_name, encoding='UTF8')
        train_file.readline()
        res = dict()
        total = 0
        not_same = 0
        while 1:
            line = train_file.readline().strip()
            if line == '':
                break
            total += 1
            pos = line.find('","')
            text = line[pos + 2:]
            if text[:3] == '","':
                continue
            text = text[1:-1]
            arr = text.split('","')
            if arr[0] != arr[1]:
                not_same += 1
            if arr[0] not in res:
                res[arr[0]] = dict()
                res[arr[0]][arr[1]] = 1
            else:
                if arr[1] in res[arr[0]]:
                    res[arr[0]][arr[1]] += 1
                else:
                    res[arr[0]][arr[1]] = 1
        train_file.close()
        print(train_file_name + ':\tTotal: {} Have diff value: {}'.format(total, not_same))

        # Work with additional dataset from https://www.kaggle.com/google-nlu/text-normalization
        files = ['output_1.csv', 'output_6.csv', 'output_11.csv', 'output_16.csv',
                 'output_21.csv', 'output_91.csv', 'output_96.csv']

        for add_file_name in files:
            train_file = open(os.path.join("input", 'tn', add_file_name), encoding='UTF8')
            train_file.readline()
            while 1:
                line = train_file.readline().strip()
                if line == '':
                    break
                line = line.replace(',NA,', ',"NA",')
                total += 1
                pos = line.find('","')
                text = line[pos + 2:]
                if text[:3] == '","':
                    continue
                text = text[1:-1]
                arr = text.split('","')
                if arr[0] == '<eos>':
                    continue
                if arr[1] != '<self>':
                    not_same += 1

                if arr[1] == '<self>' or arr[1] == 'sil':
                    arr[1] = arr[0]

                if arr[0] not in res:
                    res[arr[0]] = dict()
                    res[arr[0]][arr[1]] = 1
                else:
                    if arr[1] in res[arr[0]]:
                        res[arr[0]][arr[1]] += 1
                    else:
                        res[arr[0]][arr[1]] = 1
            train_file.close()
            print(add_file_name + ':\tTotal: {} Have diff value: {}'.format(total, not_same))

        # cache the before->after counts so later runs can load them directly
        with open(pkl_name, "wb") as f:
            pkl.dump(res, f)

    return res


def solve(res):
    # hand-written expansions for a few common abbreviations
    sdict = {}
    sdict['km2'] = 'square kilometers'
    sdict['km'] = 'kilometers'
    sdict['kg'] = 'kilograms'
    sdict['lb'] = 'pounds'
    sdict['dr'] = 'doctor'
    sdict['m²'] = 'square meters'

    total = 0
    changes = 0
    out = open(baseline_file, "w", encoding='UTF8')
    out.write('"id","after"\n')
    test = open(test_file, encoding='UTF8')
    test.readline().strip()
    while 1:
        line = test.readline().strip()
        if line == '':
            break

        pos = line.find(',')
        i1 = line[:pos]
        line = line[pos + 1:]

        pos = line.find(',')
        i2 = line[:pos]
        line = line[pos + 1:]

        line = line[1:-1]
        out.write('"' + i1 + '_' + i2 + '",')
        if line in res:
            # pick the most frequent "after" form seen in training
            srtd = sorted(res[line].items(), key=operator.itemgetter(1), reverse=True)
            out.write('"' + srtd[0][0] + '"')
            changes += 1
        else:
            # join digit groups such as "1,000" before spelling them out
            if len(line) > 1:
                val = line.split(',')
                if len(val) == 2 and val[0].isdigit() and val[1].isdigit():
                    line = ''.join(val)

            if line.isdigit():
                srtd = line.translate(SUB)
                srtd = srtd.translate(SUP)
                srtd = srtd.translate(OTH)
                out.write('"' + num2words(float(srtd)) + '"')
                changes += 1
            elif len(line.split(' ')) > 1:
                val = line.split(' ')
                for i, v in enumerate(val):
                    if v.isdigit():
                        srtd = v.translate(SUB)
                        srtd = srtd.translate(SUP)
                        srtd = srtd.translate(OTH)
                        val[i] = num2words(float(srtd))
                    elif v in sdict:
                        val[i] = sdict[v]

                out.write('"' + ' '.join(val) + '"')
                changes += 1
            else:
                out.write('"' + line + '"')

        out.write('\n')
        total += 1

    print('Total: {} Changed: {}'.format(total, changes))
    test.close()
    out.close()


if __name__ == '__main__':
    res = train()
    solve(res)

Diff for: compare_with_big_data.py (+164, new file)

@@ -0,0 +1,164 @@
# coding=utf-8
# @author: cer
# use python3
from __future__ import print_function
import os
import operator
from num2words import num2words  # this package does not support Chinese
import gc
import pandas as pd
import numpy as np
import time

train_file_name = "input/en_train.csv"
test_file = 'input/en_test_2.csv'
baseline_file = 'output/baseline.csv'
train_df = pd.read_csv(train_file_name)
test_df = pd.read_csv(test_file)

# translation tables for normalising sub/superscript and other digit glyphs
SUB = str.maketrans("₀₁₂₃₄₅₆₇₈₉", "0123456789")
SUP = str.maketrans("⁰¹²³⁴⁵⁶⁷⁸⁹", "0123456789")
OTH = str.maketrans("፬", "4")

# hand-written expansions for a few common abbreviations (as in baseline_class.py)
sdict = {
    'km2': 'square kilometers',
    'km': 'kilometers',
    'kg': 'kilograms',
    'lb': 'pounds',
    'dr': 'doctor',
    'm²': 'square meters',
}


def train():
    print('Train start...')

    # Work with primary dataset
    train_file = open(train_file_name, encoding='UTF8')
    train_file.readline()
    res = dict()
    total = 0
    not_same = 0
    while 1:
        line = train_file.readline().strip()
        if line == '':
            break
        total += 1
        pos = line.find('","')
        text = line[pos + 2:]
        if text[:3] == '","':
            continue
        text = text[1:-1]
        arr = text.split('","')
        if arr[0] != arr[1]:
            not_same += 1
        if arr[0] not in res:
            res[arr[0]] = dict()
            res[arr[0]][arr[1]] = 1
        else:
            if arr[1] in res[arr[0]]:
                res[arr[0]][arr[1]] += 1
            else:
                res[arr[0]][arr[1]] = 1
    train_file.close()
    print(train_file_name + ':\tTotal: {} Have diff value: {}'.format(total, not_same))

    # Work with additional dataset from https://www.kaggle.com/google-nlu/text-normalization
    files = ['output_1.csv', 'output_6.csv', 'output_11.csv', 'output_16.csv',
             'output_21.csv', 'output_91.csv', 'output_96.csv']

    for add_file_name in files:
        train_file = open(os.path.join("input", 'tn', add_file_name), encoding='UTF8')
        train_file.readline()
        while 1:
            line = train_file.readline().strip()
            if line == '':
                break
            line = line.replace(',NA,', ',"NA",')
            total += 1
            pos = line.find('","')
            text = line[pos + 2:]
            if text[:3] == '","':
                continue
            text = text[1:-1]
            arr = text.split('","')
            if arr[0] == '<eos>':
                continue
            if arr[1] != '<self>':
                not_same += 1

            if arr[1] == '<self>' or arr[1] == 'sil':
                arr[1] = arr[0]

            if arr[0] not in res:
                res[arr[0]] = dict()
                res[arr[0]][arr[1]] = 1
            else:
                if arr[1] in res[arr[0]]:
                    res[arr[0]][arr[1]] += 1
                else:
                    res[arr[0]][arr[1]] = 1
        train_file.close()
        print(add_file_name + ':\tTotal: {} Have diff value: {}'.format(total, not_same))

    return res


def solve(res):
    total = 0
    changes = 0
    out = open(baseline_file, "w", encoding='UTF8')
    out.write('"id","after"\n')
    test = open(test_file, encoding='UTF8')
    test.readline().strip()
    while 1:
        line = test.readline().strip()
        if line == '':
            break

        pos = line.find(',')
        i1 = line[:pos]
        line = line[pos + 1:]

        pos = line.find(',')
        i2 = line[:pos]
        line = line[pos + 1:]

        line = line[1:-1]
        out.write('"' + i1 + '_' + i2 + '",')
        if line in res:
            # pick the most frequent "after" form seen in training
            srtd = sorted(res[line].items(), key=operator.itemgetter(1), reverse=True)
            out.write('"' + srtd[0][0] + '"')
            changes += 1
        else:
            # join digit groups such as "1,000" before spelling them out
            if len(line) > 1:
                val = line.split(',')
                if len(val) == 2 and val[0].isdigit() and val[1].isdigit():
                    line = ''.join(val)

            if line.isdigit():
                srtd = line.translate(SUB)
                srtd = srtd.translate(SUP)
                srtd = srtd.translate(OTH)
                out.write('"' + num2words(float(srtd)) + '"')
                changes += 1
            elif len(line.split(' ')) > 1:
                val = line.split(' ')
                for i, v in enumerate(val):
                    if v.isdigit():
                        srtd = v.translate(SUB)
                        srtd = srtd.translate(SUP)
                        srtd = srtd.translate(OTH)
                        val[i] = num2words(float(srtd))
                    elif v in sdict:
                        val[i] = sdict[v]

                out.write('"' + ' '.join(val) + '"')
                changes += 1
            else:
                out.write('"' + line + '"')

        out.write('\n')
        total += 1

    print('Total: {} Changed: {}'.format(total, changes))
    test.close()
    out.close()


if __name__ == '__main__':
    res = train()
    solve(res)
