|
| 1 | +#!/usr/bin/env python |
| 2 | +# encoding: utf-8 |
| 3 | +''' |
| 4 | +@author: zessay |
| 5 | +@license: (C) Copyright Sogou. |
| 6 | + |
| 7 | +@file: process_multiQA_pipline.py |
| 8 | +@time: 2019/12/25 12:22 |
| 9 | +@description: |
| 10 | +''' |
| 11 | +import json |
| 12 | +import os |
| 13 | +import re |
| 14 | +import argparse |
| 15 | +import random |
| 16 | +import logging |
| 17 | +import numpy as np |
| 18 | +import pandas as pd |
| 19 | +from pathlib import Path |
| 20 | +import multiprocessing as mp |
| 21 | +from tqdm import tqdm |
| 22 | +import shutil |
| 23 | + |
| 24 | +log_format = logging.Formatter(fmt='%(asctime)s - %(levelname)s - %(name)s - %(message)s', |
| 25 | + datefmt='%m/%d/%Y %H:%M:%S') |
| 26 | +logger = logging.getLogger("Process-MultiQA") |
| 27 | +logger.setLevel(logging.INFO) |
| 28 | +console_handler = logging.StreamHandler() |
| 29 | +console_handler.setFormatter(log_format) |
| 30 | +logger.handlers = [console_handler] |
| 31 | + |
| 32 | +def seed_everything(seed=2020): |
| 33 | + os.environ['PYTHONASHSEED'] = str(seed) |
| 34 | + random.seed(seed) |
| 35 | + np.random.seed(seed) |
| 36 | + |
| 37 | + |
| 38 | +def read_and_get_QAlist_from_json(config): |
| 39 | + """根据标准的多轮对话json文件生成""" |
| 40 | + with open(Path(config.from_path) / config.from_file, 'r', encoding='utf8') as f: |
| 41 | + data = json.load(f) |
| 42 | + |
| 43 | + seq_num = [] |
| 44 | + question = [] |
| 45 | + answer = [] |
| 46 | + num = -1 |
| 47 | + |
| 48 | + for dialogue in data: |
| 49 | + utterance = dialogue['utterance'] |
| 50 | + reply = dialogue['responses'][0]['reply'] |
| 51 | + |
| 52 | + uttrs = re.split(r"\[U=\d\]", utterance.replace(" ", ""))[1:] |
| 53 | + if len(uttrs) <= 1: |
| 54 | + num += 1 |
| 55 | + |
| 56 | + q = "\t".join(uttrs) |
| 57 | + a = reply |
| 58 | + |
| 59 | + seq_num.append(f"D_{config.to_file_name}{num}") |
| 60 | + question.append(q.strip()) |
| 61 | + answer.append(a) |
| 62 | + return seq_num, question, answer |
| 63 | + |
| 64 | +def read_and_get_QAlist_from_txt(config): |
| 65 | + """一行一句话,每一个dialogue用换行符分隔""" |
| 66 | + data = [] |
| 67 | + with open(Path(config.from_path)/ config.from_file, 'r', encoding='utf8') as f: |
| 68 | + dialogue = [] |
| 69 | + for i, line in enumerate(f): |
| 70 | + line = line.strip() |
| 71 | + if line: |
| 72 | + dialogue.append(line) |
| 73 | + else: |
| 74 | + data.append(dialogue) |
| 75 | + dialogue = [] |
| 76 | + seq_num, question, answer = [], [], [] |
| 77 | + |
| 78 | + for index, dialogue in enumerate(data): |
| 79 | + q = "" |
| 80 | + for i in range(0, len(dialogue), 2): |
| 81 | + q += dialogue[i] + "\t" |
| 82 | + if (i + 1) >= len(dialogue): |
| 83 | + continue |
| 84 | + a = dialogue[i + 1] |
| 85 | + seq_num.append(f"D_{config.to_file_name}{index}") |
| 86 | + question.append(q.strip()) |
| 87 | + answer.append(a) |
| 88 | + return seq_num, question, answer |
| 89 | + |
| 90 | +def process_and_save(seq_num, question, answer): |
| 91 | + # seq_num相同表示是同一组对话 |
| 92 | + # questions表示是当前response的历史utterances |
| 93 | + # answer表示当前回复 |
| 94 | + pos_data = pd.DataFrame({"D_num": seq_num, "utterances": question, "response": answer}) |
| 95 | + ## 得到最后一个utterance以及之前的utterances |
| 96 | + pos_data['prev_uttrs'] = pos_data['utterances'].apply(lambda s: "\t".join(s.split("\t")[:-1])) |
| 97 | + pos_data['last'] = pos_data['utterances'].apply(lambda s: s.split("\t")[-1]) |
| 98 | + ## 获取包含当前轮总共的轮数 |
| 99 | + pos_data['turns'] = pos_data['utterances'].apply(lambda s: len(s.split("\t"))) |
| 100 | + |
| 101 | + pos_data = pos_data[['D_num', 'turns', 'utterances', 'prev_uttrs', 'last', 'response']] |
| 102 | + # 返回处理好的数据 |
| 103 | + return pos_data |
| 104 | + |
| 105 | +def sample_multiQA_neg_random_turn(pos_data, start, end, responses, length, prefix): |
| 106 | + """ |
| 107 | + 从任意轮的回复中选择一次response作为neg回复 |
| 108 | + :param pos_data: |
| 109 | + :return: |
| 110 | + """ |
| 111 | + neg_data = pos_data[start:end].reset_index(drop=True).copy() |
| 112 | + to_file = prefix + f"_{start}_{end}" + '.csv' |
| 113 | + path = "./data" |
| 114 | + if not Path(path).exists(): |
| 115 | + Path(path).mkdir(parents=True) |
| 116 | + t_file = Path(path) / to_file |
| 117 | + pid = os.getpid() |
| 118 | + logger.info(f"{mp.current_process()} 已启动...") |
| 119 | + |
| 120 | + for i, record in tqdm(neg_data.iterrows()): |
| 121 | + pos = record.response |
| 122 | + index = np.random.randint(0, length) |
| 123 | + low = max(0, index-10) |
| 124 | + high = min(length, index+10) |
| 125 | + candidates = responses[low:high] |
| 126 | + if pos in candidates: |
| 127 | + candidates.remove(pos) |
| 128 | + neg_index = np.random.randint(len(candidates)) |
| 129 | + neg = candidates[neg_index] |
| 130 | + |
| 131 | + if (i+1) % 5000 == 0: |
| 132 | + logger.info(f"{pid}\t已经处理了 {i+1} 条数据...") |
| 133 | + record.response = neg |
| 134 | + neg_data.loc[i] = record |
| 135 | + neg_data.to_csv(t_file, index=False) |
| 136 | + |
| 137 | +def sample_multiQA_neg_same_turn(pos_data, config): |
| 138 | + """ |
| 139 | + 从当前轮次的对话中选择一句(非pos的)作为neg回复 |
| 140 | + :param pos_data: |
| 141 | + :return: |
| 142 | + """ |
| 143 | + neg_data = pos_data.copy() |
| 144 | + to_file = Path(config.to_path) / f"multi_{config.to_file_name}_neg_same.csv" |
| 145 | + |
| 146 | + for i, (tag, group) in tqdm(enumerate(neg_data.groupby(by=['D_num'], |
| 147 | + as_index=False, |
| 148 | + sort=False))): |
| 149 | + # 将所有的回复和之前的问题合并 |
| 150 | + candidates = group['response'].values.tolist() + group.iloc[-1].utterances.split("\t") |
| 151 | + for j, record in group.iterrows(): |
| 152 | + pos = record.response |
| 153 | + candidates.remove(pos) |
| 154 | + if len(candidates) <= 0: |
| 155 | + neg = "嗯嗯,好吧!" |
| 156 | + else: |
| 157 | + neg = np.random.choice(candidates) |
| 158 | + |
| 159 | + record.response = neg |
| 160 | + candidates += [pos] |
| 161 | + group.loc[j, :] = record |
| 162 | + if len(group.shape) == 1: |
| 163 | + group = group.to_frame() |
| 164 | + if i == 0: |
| 165 | + group.to_csv(to_file, mode="w", encoding="utf8", header=True, index=False) |
| 166 | + else: |
| 167 | + group.to_csv(to_file, mode="a", encoding="utf8", header=False, index=False) |
| 168 | + if (i+1) % 1000 == 0: |
| 169 | + logger.info(f"已经处理了 {i+1} 组数据 ...") |
| 170 | + neg_data = pd.read_csv(to_file) |
| 171 | + os.remove(to_file) |
| 172 | + return neg_data |
| 173 | + |
| 174 | +def gen_multiQA_train_val(p_data, n_data, config): |
| 175 | + """ |
| 176 | + 生成训练集和验证集数据 |
| 177 | + :param p_data: |
| 178 | + :param n_data: |
| 179 | + :return: |
| 180 | + """ |
| 181 | + p_data['label'] = 1 |
| 182 | + n_data['label'] = 0 |
| 183 | + # 将索引打乱 |
| 184 | + p_all_index = p_data.index.tolist() |
| 185 | + n_all_index = n_data.index.tolist() |
| 186 | + |
| 187 | + p_rand_index = np.random.permutation(p_all_index) |
| 188 | + n_rand_index = np.random.permutation(n_all_index) |
| 189 | + |
| 190 | + # 定义训练集的长度 |
| 191 | + p_train_len = int(len(p_rand_index) * 0.8) |
| 192 | + n_train_len = int(len(n_rand_index) * 0.8) |
| 193 | + |
| 194 | + p_train = p_data.iloc[p_rand_index[:p_train_len], :].reset_index(drop=True) |
| 195 | + p_valid = p_data.iloc[p_rand_index[p_train_len:], :].reset_index(drop=True) |
| 196 | + |
| 197 | + n_train = n_data.iloc[n_rand_index[:n_train_len], :].reset_index(drop=True) |
| 198 | + n_valid = n_data.iloc[n_rand_index[n_train_len:], :].reset_index(drop=True) |
| 199 | + |
| 200 | + multi_train = pd.concat([p_train, n_train], axis=0, sort=False, ignore_index=True) |
| 201 | + multi_valid = pd.concat([p_valid, n_valid], axis=0, sort=False, ignore_index=True) |
| 202 | + |
| 203 | + train_index = np.random.permutation(range(multi_train.shape[0])) |
| 204 | + valid_index = np.random.permutation(range(multi_valid.shape[0])) |
| 205 | + |
| 206 | + multi_train = multi_train.iloc[train_index, :] |
| 207 | + multi_valid = multi_valid.iloc[valid_index, :] |
| 208 | + |
| 209 | + multi_train.to_csv(Path(config.to_path) / f"multi_{config.to_file_name}_train.csv", index=False) |
| 210 | + multi_valid.to_csv(Path(config.to_path) / f"multi_{config.to_file_name}_val.csv", index=False) |
| 211 | + |
| 212 | + |
| 213 | +if __name__ == "__main__": |
| 214 | + parser = argparse.ArgumentParser() |
| 215 | + |
| 216 | + # 添加参数 |
| 217 | + parser.add_argument("--from_path", default=None, type=str, required=True, |
| 218 | + help="The input data dir.") |
| 219 | + parser.add_argument("--from_file", default=None, type=str, required=True, |
| 220 | + help="The input file, should be .json preprocessed or .data.format.") |
| 221 | + parser.add_argument("--to_path", default=None, type=str, required=True, |
| 222 | + help="The save path of train and val file.") |
| 223 | + parser.add_argument("--to_file_name", default=None, type=str, required=True, |
| 224 | + help="The final file name, like 50w, persona, or pass, don't need the full name.") |
| 225 | + parser.add_argument("--sample", default='random', type=str, choices=['random', 'same'], |
| 226 | + help="Sample the neg in the random turn or the same turn.") |
| 227 | + parser.add_argument("--n_jobs", default=5, type=int, |
| 228 | + help="The number of jobs to use when random sample.") |
| 229 | + parser.add_argument("--seed", default=2020, type=int, |
| 230 | + help="The sample seed.") |
| 231 | + |
| 232 | + config = parser.parse_args() |
| 233 | + seed_everything(config.seed) |
| 234 | + |
| 235 | + logger.info(f"读取文件 {config.from_file}") |
| 236 | + # 首先读取数据 |
| 237 | + if config.from_file.endswith(".json"): |
| 238 | + seq_num, question, answer = read_and_get_QAlist_from_json(config) |
| 239 | + else: |
| 240 | + seq_num, question, answer = read_and_get_QAlist_from_txt(config) |
| 241 | + |
| 242 | + # 简单处理得到正样本 |
| 243 | + logger.info("对文件进行处理") |
| 244 | + p_data = process_and_save(seq_num, question, answer) |
| 245 | + |
| 246 | + logger.info(f"在 {config.sample} turn 中进行采样") |
| 247 | + if config.sample == "random": |
| 248 | + # 随机采样 |
| 249 | + |
| 250 | + responses = p_data['response'].values.tolist() |
| 251 | + length = p_data.shape[0] |
| 252 | + prefix = f"multi_{config.to_file_name}_neg_random" |
| 253 | + results = [] |
| 254 | + pool = mp.Pool() |
| 255 | + for p in range(config.n_jobs): |
| 256 | + start = int(p / config.n_jobs * length) |
| 257 | + end = int((p+1) / config.n_jobs * length) |
| 258 | + result = pool.apply_async(sample_multiQA_neg_random_turn, |
| 259 | + args=(p_data,start,end,responses,length,prefix,)) |
| 260 | + results.append(result) |
| 261 | + pool.close() |
| 262 | + pool.join() |
| 263 | + if all([res.ready() for res in results]): |
| 264 | + logger.info("负采样完成") |
| 265 | + files = os.listdir("./data") |
| 266 | + n_data = pd.DataFrame() |
| 267 | + for file in files: |
| 268 | + tmp = pd.read_csv(Path("./data") / file) |
| 269 | + n_data = pd.concat([n_data, tmp], axis=0, sort=False, ignore_index=True) |
| 270 | + assert n_data.shape == p_data.shape |
| 271 | + # 删除中间生成的数据 |
| 272 | + shutil.rmtree("./data") |
| 273 | + elif config.sample == "same": |
| 274 | + n_data = sample_multiQA_neg_same_turn(p_data, config) |
| 275 | + else: |
| 276 | + raise ValueError(f"The {config.sample} is invalid, only `random` and `same` allow.") |
| 277 | + |
| 278 | + # 得到最终的数据并保存 |
| 279 | + logger.info(f"将数据保存为训练集和验证集,路径为 {config.to_path}") |
| 280 | + gen_multiQA_train_val(p_data, n_data, config) |
| 281 | + |
0 commit comments