forked from Zessay/NLP-Pytorch-Template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_trainer.py
114 lines (96 loc) · 3.52 KB
/
test_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python
# encoding: utf-8
'''
@author: zessay
@license: (C) Copyright Sogou.
@contact: [email protected]
@file: test_trainer.py
@time: 2019/12/23 15:58
@description: Tests the Trainer class
'''
import sys
import os
import time
import pandas as pd
import torch
import torch.nn as nn
from pathlib import Path
import warnings
warnings.filterwarnings("ignore")
sys.path.append(os.path.dirname(os.getcwd()))
from snlp import tasks
from snlp.embedding import load_from_file
# Wrap the data in a Dataset and a DataLoader
from snlp.datagen.dataset.pair_dataset import PairDataset
from snlp.callbacks.padding import MultiQAPadding
from snlp.datagen.dataloader.dict_dataloader import DictDataLoader
from snlp.models.retrieval.dam import DAM
from snlp.preprocessors.chinese_preprocessor import CNPreprocessorForMultiQA
from snlp.optimizer import RAdam
from snlp.trainers import Trainer
from snlp.tools.log import logger
# Select the GPU for this run. ``setdefault`` keeps the historical default of
# device "1" when nothing is configured, but lets a CUDA_VISIBLE_DEVICES value
# supplied by the caller win, so the script no longer hides the only GPU on a
# single-GPU machine.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', "1")

# Fixed sizes used further down by both MultiQAPadding and the DAM model.
fixed_length_uttr = 20  # passed as fixed_length_uttr / uttr_len
fixed_length_resp = 20  # passed as fixed_length_resp / resp_len
fixed_length_turn = 5   # passed as fixed_length_turn / turns

# Wall-clock timestamp taken before any data is loaded.
# NOTE(review): nothing in the visible script reads `start` afterwards.
start = time.time()
# Two-class classification task: cross-entropy as the loss, accuracy as the
# only metric.
cls_task = tasks.Classification(
    num_classes=2,
    losses=nn.CrossEntropyLoss(),
)
cls_task.metrics = ['accuracy']
# Read the sample multi-turn QA data.
data_path = "../sample_data/multi_qa.csv"
logger.info("读取数据 %s" % os.path.basename(data_path))
data = pd.read_csv(data_path)
# Every row gets label 1 before preprocessing.
data['label'] = 1

# Preprocess the data.
## NOTE: the preprocessor is fitted on the full data set, i.e. before the
## train/validation split below, which is a form of leakage.
logger.info("使用Preprocessor处理数据")
preprocessor = CNPreprocessorForMultiQA(stopwords=['\t'])
preprocessor = preprocessor.fit(data, columns=['utterances', 'response'])
data = preprocessor.transform(data)
# Take an explicit copy of the column subset so that the label assignment
# below writes to an independent frame rather than to a selection of the
# transformed data (pandas' chained-assignment hazard; the warning is
# otherwise hidden by the global warnings filter).
data = data[['D_num', 'turns', 'utterances', 'response',
             'utterances_len', 'response_len', 'label']].copy()
data['label'] = data['label'].astype(int)

# Split into training and validation sets: the first 90 rows train, the
# remaining rows validate.
train = data[:90]
valid = data[90:]
# Directory holding the pre-trained word vectors, and the vector file in it.
basename = "/home/speech/models"
embedding_file = Path(basename) / "500000-small.txt"

# Load the vectors, then build the embedding matrix from the term index that
# the fitted preprocessor stores in its context.
logger.info("读取词向量文件")
word_embedding = load_from_file(embedding_file)
term_index = preprocessor.context['term_index']
embedding_matrix = word_embedding.build_matrix(term_index)
# Wrap the training and validation splits in datasets and data loaders.
logger.info("使用Dataset和DataLoader对数据进行封装")
train_dataset, valid_dataset = (
    PairDataset(split, num_neg=0) for split in (train, valid)
)

# One padding callback instance is shared by both loaders.
padding = MultiQAPadding(
    fixed_length_uttr=fixed_length_uttr,
    fixed_length_resp=fixed_length_resp,
    fixed_length_turn=fixed_length_turn,
)

# Both loaders are configured identically; only the dataset differs.
loader_options = {
    'batch_size': 16,
    'turns': fixed_length_turn,
    'shuffle': False,
    'sort': False,
    'callback': padding,
}
train_dataloader = DictDataLoader(train_dataset, **loader_options)
valid_dataloader = DictDataLoader(valid_dataset, **loader_options)
# Define the model.
# CUDA is selected when torch reports it as available, CPU otherwise.
# NOTE(review): `device` is not handed to the model or the Trainer anywhere in
# this script -- confirm whether Trainer picks a device on its own.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

logger.info("定义模型和参数")
model = DAM(
    uttr_len=fixed_length_uttr,
    resp_len=fixed_length_resp,
    turns=fixed_length_turn,
)

# Start from the model's default parameters, then plug in the task and the
# pre-built embedding matrix before building the network.
params = model.get_default_params()
params['task'] = cls_task
params['embedding'] = embedding_matrix
model.params = params
model.build()
# Presumably torch.nn.Module.float(), i.e. a cast to float32 -- assumes DAM is
# an nn.Module; TODO confirm.
model = model.float()

optimizer = RAdam(model.parameters(), lr=1e-5)

# Train for two epochs on the loaders built above.
trainer_kwargs = {
    'model': model,
    'optimizer': optimizer,
    'trainloader': train_dataloader,
    'validloader': valid_dataloader,
    'epochs': 2,
}
trainer = Trainer(**trainer_kwargs)

logger.info("开始训练模型")
trainer.run()