Commit 10c36be
add tests
1 parent a1e2766 commit 10c36be

61 files changed: +5463 −305 lines

albert_pytorch

+1
@@ -0,0 +1 @@
Subproject commit 644877786907f798ec034ffb077398d784512e6c
process_multiQA_pipline.py

+281
@@ -0,0 +1,281 @@
#!/usr/bin/env python
# encoding: utf-8
'''
@author: zessay
@license: (C) Copyright Sogou.

@file: process_multiQA_pipline.py
@time: 2019/12/25 12:22
@description: Build positive/negative multi-turn QA samples and split them into train/val sets.
'''
import json
import os
import re
import argparse
import random
import logging
import numpy as np
import pandas as pd
from pathlib import Path
import multiprocessing as mp
from tqdm import tqdm
import shutil

log_format = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                               datefmt='%m/%d/%Y %H:%M:%S')
logger = logging.getLogger("Process-MultiQA")
logger.setLevel(logging.INFO)
console_handler = logging.StreamHandler()
console_handler.setFormatter(log_format)
logger.handlers = [console_handler]


def seed_everything(seed=2020):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)


def read_and_get_QAlist_from_json(config):
    """Build (seq_num, question, answer) lists from a standard multi-turn dialogue json file."""
    with open(Path(config.from_path) / config.from_file, 'r', encoding='utf8') as f:
        data = json.load(f)

    seq_num = []
    question = []
    answer = []
    num = -1

    for dialogue in data:
        utterance = dialogue['utterance']
        reply = dialogue['responses'][0]['reply']

        uttrs = re.split(r"\[U=\d\]", utterance.replace(" ", ""))[1:]
        if len(uttrs) <= 1:
            num += 1

        q = "\t".join(uttrs)
        a = reply

        seq_num.append(f"D_{config.to_file_name}{num}")
        question.append(q.strip())
        answer.append(a)
    return seq_num, question, answer


def read_and_get_QAlist_from_txt(config):
    """One sentence per line; dialogues are separated by blank lines."""
    data = []
    with open(Path(config.from_path) / config.from_file, 'r', encoding='utf8') as f:
        dialogue = []
        for i, line in enumerate(f):
            line = line.strip()
            if line:
                dialogue.append(line)
            else:
                data.append(dialogue)
                dialogue = []
    seq_num, question, answer = [], [], []

    for index, dialogue in enumerate(data):
        q = ""
        for i in range(0, len(dialogue), 2):
            q += dialogue[i] + "\t"
            if (i + 1) >= len(dialogue):
                continue
            a = dialogue[i + 1]
            seq_num.append(f"D_{config.to_file_name}{index}")
            question.append(q.strip())
            answer.append(a)
    return seq_num, question, answer


def process_and_save(seq_num, question, answer):
    # Rows with the same seq_num belong to the same dialogue.
    # `utterances` holds the history leading up to the current response;
    # `response` is the current reply.
    pos_data = pd.DataFrame({"D_num": seq_num, "utterances": question, "response": answer})
    ## Split out the last utterance and the preceding utterances
    pos_data['prev_uttrs'] = pos_data['utterances'].apply(lambda s: "\t".join(s.split("\t")[:-1]))
    pos_data['last'] = pos_data['utterances'].apply(lambda s: s.split("\t")[-1])
    ## Total number of turns, including the current one
    pos_data['turns'] = pos_data['utterances'].apply(lambda s: len(s.split("\t")))

    pos_data = pos_data[['D_num', 'turns', 'utterances', 'prev_uttrs', 'last', 'response']]
    # Return the processed positive samples
    return pos_data


def sample_multiQA_neg_random_turn(pos_data, start, end, responses, length, prefix):
    """
    Pick a response from an arbitrary turn as the negative reply.
    :param pos_data:
    :return:
    """
    neg_data = pos_data[start:end].reset_index(drop=True).copy()
    to_file = prefix + f"_{start}_{end}" + '.csv'
    path = "./data"
    if not Path(path).exists():
        Path(path).mkdir(parents=True)
    t_file = Path(path) / to_file
    pid = os.getpid()
    logger.info(f"{mp.current_process()} started ...")

    for i, record in tqdm(neg_data.iterrows()):
        pos = record.response
        index = np.random.randint(0, length)
        low = max(0, index - 10)
        high = min(length, index + 10)
        candidates = responses[low:high]
        if pos in candidates:
            candidates.remove(pos)
        neg_index = np.random.randint(len(candidates))
        neg = candidates[neg_index]

        if (i + 1) % 5000 == 0:
            logger.info(f"{pid}\tprocessed {i + 1} records ...")
        record.response = neg
        neg_data.loc[i] = record
    neg_data.to_csv(t_file, index=False)


def sample_multiQA_neg_same_turn(pos_data, config):
    """
    Pick a sentence (other than the positive one) from the current dialogue as the negative reply.
    :param pos_data:
    :return:
    """
    neg_data = pos_data.copy()
    to_file = Path(config.to_path) / f"multi_{config.to_file_name}_neg_same.csv"

    for i, (tag, group) in tqdm(enumerate(neg_data.groupby(by=['D_num'],
                                                           as_index=False,
                                                           sort=False))):
        # Merge all replies in the dialogue with the previous questions
        candidates = group['response'].values.tolist() + group.iloc[-1].utterances.split("\t")
        for j, record in group.iterrows():
            pos = record.response
            candidates.remove(pos)
            if len(candidates) <= 0:
                neg = "嗯嗯,好吧!"  # fallback reply when no candidate is left
            else:
                neg = np.random.choice(candidates)

            record.response = neg
            candidates += [pos]
            group.loc[j, :] = record
        if len(group.shape) == 1:
            group = group.to_frame()
        if i == 0:
            group.to_csv(to_file, mode="w", encoding="utf8", header=True, index=False)
        else:
            group.to_csv(to_file, mode="a", encoding="utf8", header=False, index=False)
        if (i + 1) % 1000 == 0:
            logger.info(f"Processed {i + 1} dialogue groups ...")
    neg_data = pd.read_csv(to_file)
    os.remove(to_file)
    return neg_data


def gen_multiQA_train_val(p_data, n_data, config):
    """
    Generate the training and validation sets.
    :param p_data:
    :param n_data:
    :return:
    """
    p_data['label'] = 1
    n_data['label'] = 0
    # Shuffle the indices
    p_all_index = p_data.index.tolist()
    n_all_index = n_data.index.tolist()

    p_rand_index = np.random.permutation(p_all_index)
    n_rand_index = np.random.permutation(n_all_index)

    # Size of the training split (80%)
    p_train_len = int(len(p_rand_index) * 0.8)
    n_train_len = int(len(n_rand_index) * 0.8)

    p_train = p_data.iloc[p_rand_index[:p_train_len], :].reset_index(drop=True)
    p_valid = p_data.iloc[p_rand_index[p_train_len:], :].reset_index(drop=True)

    n_train = n_data.iloc[n_rand_index[:n_train_len], :].reset_index(drop=True)
    n_valid = n_data.iloc[n_rand_index[n_train_len:], :].reset_index(drop=True)

    multi_train = pd.concat([p_train, n_train], axis=0, sort=False, ignore_index=True)
    multi_valid = pd.concat([p_valid, n_valid], axis=0, sort=False, ignore_index=True)

    train_index = np.random.permutation(range(multi_train.shape[0]))
    valid_index = np.random.permutation(range(multi_valid.shape[0]))

    multi_train = multi_train.iloc[train_index, :]
    multi_valid = multi_valid.iloc[valid_index, :]

    multi_train.to_csv(Path(config.to_path) / f"multi_{config.to_file_name}_train.csv", index=False)
    multi_valid.to_csv(Path(config.to_path) / f"multi_{config.to_file_name}_val.csv", index=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # Command-line arguments
    parser.add_argument("--from_path", default=None, type=str, required=True,
                        help="The input data dir.")
    parser.add_argument("--from_file", default=None, type=str, required=True,
                        help="The input file, either a preprocessed .json file or a plain-text .data format file.")
    parser.add_argument("--to_path", default=None, type=str, required=True,
                        help="The save path of the train and val files.")
    parser.add_argument("--to_file_name", default=None, type=str, required=True,
                        help="The final file name, like 50w, persona, or pass; the full name is not needed.")
    parser.add_argument("--sample", default='random', type=str, choices=['random', 'same'],
                        help="Sample the negative from a random turn or from the same turn.")
    parser.add_argument("--n_jobs", default=5, type=int,
                        help="The number of worker processes to use for random sampling.")
    parser.add_argument("--seed", default=2020, type=int,
                        help="The sample seed.")

    config = parser.parse_args()
    seed_everything(config.seed)

    logger.info(f"Reading file {config.from_file}")
    # Read the raw data first
    if config.from_file.endswith(".json"):
        seq_num, question, answer = read_and_get_QAlist_from_json(config)
    else:
        seq_num, question, answer = read_and_get_QAlist_from_txt(config)

    # Light processing to obtain the positive samples
    logger.info("Processing the file")
    p_data = process_and_save(seq_num, question, answer)

    logger.info(f"Sampling negatives with the `{config.sample}` turn strategy")
    if config.sample == "random":
        # Random-turn sampling, parallelized over n_jobs worker processes
        responses = p_data['response'].values.tolist()
        length = p_data.shape[0]
        prefix = f"multi_{config.to_file_name}_neg_random"
        results = []
        pool = mp.Pool()
        for p in range(config.n_jobs):
            start = int(p / config.n_jobs * length)
            end = int((p + 1) / config.n_jobs * length)
            result = pool.apply_async(sample_multiQA_neg_random_turn,
                                      args=(p_data, start, end, responses, length, prefix,))
            results.append(result)
        pool.close()
        pool.join()
        if all([res.ready() for res in results]):
            logger.info("Negative sampling finished")
        files = os.listdir("./data")
        n_data = pd.DataFrame()
        for file in files:
            tmp = pd.read_csv(Path("./data") / file)
            n_data = pd.concat([n_data, tmp], axis=0, sort=False, ignore_index=True)
        assert n_data.shape == p_data.shape
        # Remove the intermediate files
        shutil.rmtree("./data")
    elif config.sample == "same":
        n_data = sample_multiQA_neg_same_turn(p_data, config)
    else:
        raise ValueError(f"The {config.sample} is invalid, only `random` and `same` are allowed.")

    # Build the final datasets and save them
    logger.info(f"Saving the train and val sets to {config.to_path}")
    gen_multiQA_train_val(p_data, n_data, config)
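
For orientation, here is a minimal, hypothetical sketch of the positive-sample layout that process_and_save produces. It uses toy data and assumes the script above is importable as a module from the same directory; it is not part of the commit.

from process_multiQA_pipline import process_and_save

# Toy input mirroring what read_and_get_QAlist_from_* return:
# utterance histories are tab-joined, one row per (history, reply) pair.
seq_num  = ["D_scene0", "D_scene0"]
question = ["你好", "你好\t在吗\t想问个问题"]
answer   = ["您好,请讲", "请说"]

pos_data = process_and_save(seq_num, question, answer)
print(pos_data.columns.tolist())
# ['D_num', 'turns', 'utterances', 'prev_uttrs', 'last', 'response']
print(pos_data['turns'].tolist())  # [1, 3] -- number of tab-separated utterances per row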
+7
@@ -0,0 +1,7 @@
python process_multiQA_pipline.py \
  --from_path=/home/speech/data \
  --from_file=jd_multiturn_1223.json \
  --to_path=/home/speech/data/multi_clean \
  --to_file_name=scene \
  --sample=random \
  --seed=2020

requirements.txt

+15
@@ -0,0 +1,15 @@
tensorflow-gpu
pytorch
dill
numpy
pandas
hyperopt
flask
fuzzywuzzy
jieba
pkuseg
boto3
nltk
sentencepiece
tokenizers
gensim

snlp/base/base_model.py

+7-2
@@ -15,6 +15,7 @@
 import numpy as np
 import torch
 import torch.nn as nn
+import torch.nn.init as init

 import snlp.callbacks as callbacks
 import snlp.preprocessors as preprocessors
@@ -198,6 +199,7 @@ def _make_embedding_layer(self,
                               embedding_dim: int=0,
                               freeze: bool=True,
                               embedding: typing.Optional[np.ndarray]=None,
+                              padding_idx: int=0,
                               **kwargs) -> nn.Module:
         if isinstance(embedding, np.ndarray):
             return nn.Embedding.from_pretrained(
@@ -207,7 +209,8 @@ def _make_embedding_layer(self,
         else:
             return nn.Embedding(
                 num_embeddings=num_embeddings,
-                embedding_dim=embedding_dim
+                embedding_dim=embedding_dim,
+                padding_idx=padding_idx
             )

     def _make_default_embedding_layer(self,
@@ -255,8 +258,10 @@ def _make_perceptron_layer(self,
                                in_features: int=0,
                                out_features: int=0,
                                activation: nn.Module=nn.ReLU()) -> nn.Module:
+        single_perceptron = nn.Linear(in_features, out_features)
+        init.xavier_normal_(single_perceptron.weight)
         return nn.Sequential(
-            nn.Linear(in_features, out_features),
+            single_perceptron,
             activation
         )
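
Taken together, the two changes above mean that (a) randomly initialized embeddings now reserve row padding_idx as an all-zero vector that receives no gradient, and (b) the perceptron's Linear weights are Xavier-normal initialized before being wrapped with the activation. A small standalone sketch of that behavior in plain PyTorch, not the repo's API:

import torch
import torch.nn as nn
import torch.nn.init as init

# padding_idx=0: row 0 of the embedding matrix is zeroed and never updated,
# so padded token positions contribute nothing downstream.
emb = nn.Embedding(num_embeddings=100, embedding_dim=8, padding_idx=0)
print(emb.weight[0].abs().sum().item())  # 0.0

# Perceptron layer as assembled in _make_perceptron_layer after the patch:
linear = nn.Linear(in_features=8, out_features=4)
init.xavier_normal_(linear.weight)
perceptron = nn.Sequential(linear, nn.ReLU())
print(perceptron(torch.randn(2, 8)).shape)  # torch.Size([2, 4])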

snlp/base/base_preprocessor.py

+3-3
@@ -12,7 +12,7 @@
 import pandas as pd
 import typing
 from pathlib import Path
-import pickle
+import dill
 from snlp.base import units


@@ -47,7 +47,7 @@ def save(self, dirpath: typing.Union[str, Path]):
         data_file_path = dirpath.joinpath(self.DATA_FILENAME)
         if not dirpath.exists():
             dirpath.mkdir(parents=True)
-        pickle.dump(self, open(data_file_path, mode='wb'))
+        dill.dump(self, open(data_file_path, mode='wb'))

     @classmethod
     def _default_units(cls) -> list:
@@ -62,4 +62,4 @@ def load_preprocessor(dirpath: typing.Union[str, Path]) -> 'BasePreprocessor':
     """Load the fitted context"""
     dirpath = Path(dirpath)
     data_file_path = dirpath.joinpath(BasePreprocessor.DATA_FILENAME)
-    return pickle.load(open(data_file_path, mode='rb'))
+    return dill.load(open(data_file_path, mode='rb'))
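
The pickle → dill swap matters because a fitted preprocessor may hold lambdas or locally defined callables in its units, which the standard-library pickle refuses to serialize; dill handles them. A minimal illustration with a toy function, not the repo's preprocessor:

import pickle
import dill

tokenize = lambda s: s.lower().split()  # stands in for a unit that stores a lambda

try:
    pickle.dumps(tokenize)
except Exception as err:                 # pickle cannot serialize lambdas by value
    print("pickle failed:", err)

restored = dill.loads(dill.dumps(tokenize))  # dill serializes the function itself
print(restored("Hello World"))               # ['hello', 'world']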

snlp/base/units/frequency_filter.py

+5-1
@@ -45,7 +45,11 @@ def fit(self, list_of_tokens: typing.List[typing.List[str]]):
     def transform(self, input_: list) -> list:
         """Transform a list of tokens"""
         valid_terms = self._context[self._mode]
-        return list(filter(lambda token: token in valid_terms, input_))
+        result = list(filter(lambda token: token in valid_terms, input_))
+        ## If filtering leaves nothing, skip the filtering
+        if len(result) <= 0:
+            result = input_
+        return result

     @classmethod
     def _tf(cls, list_of_tokens: list) -> dict:
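
This change keeps the frequency filter from returning an empty token list: if every token would be filtered out, the original input is passed through untouched. A standalone, hypothetical mirror of that logic as a plain function, not the unit's real interface:

def filter_with_fallback(tokens, valid_terms):
    """Keep only tokens in the vocabulary, but fall back to the
    original list if filtering would drop everything."""
    result = [t for t in tokens if t in valid_terms]
    return result if result else list(tokens)

valid = {"deep", "learning"}
print(filter_with_fallback(["deep", "learning", "rocks"], valid))  # ['deep', 'learning']
print(filter_with_fallback(["totally", "unseen"], valid))          # fallback: input unchanged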
