# trainer.py
import time
import random

import torch
import torch.nn as nn
import torch.optim as optim

from utils import epoch_time
from model.optim import ScheduledAdam
from model.transformer import Transformer

random.seed(32)
torch.manual_seed(32)
torch.backends.cudnn.deterministic = True


class Trainer:
    def __init__(self, params, mode, train_iter=None, valid_iter=None, test_iter=None):
        self.params = params

        # Train mode
        if mode == 'train':
            self.train_iter = train_iter
            self.valid_iter = valid_iter
        # Test mode
        else:
            self.test_iter = test_iter

        self.model = Transformer(self.params)
        self.model.to(self.params.device)

        # Scheduling optimizer (Adam with learning-rate warm-up)
        self.optimizer = ScheduledAdam(
            optim.Adam(self.model.parameters(), betas=(0.9, 0.98), eps=1e-9),
            hidden_dim=params.hidden_dim,
            warm_steps=params.warm_steps
        )
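
        # Note: ScheduledAdam (model/optim.py) is assumed to implement the Noam schedule
        # from "Attention Is All You Need", i.e. roughly
        #     lr = hidden_dim ** -0.5 * min(step ** -0.5, step * warm_steps ** -1.5)
        # This formula is an assumption for documentation purposes and has not been
        # verified against this repository's ScheduledAdam implementation.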

        # Cross-entropy over target tokens; padding positions are excluded from the loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.params.pad_idx)
        self.criterion.to(self.params.device)

    def train(self):
        print(self.model)
        print(f'The model has {self.model.count_params():,} trainable parameters')
        best_valid_loss = float('inf')

        for epoch in range(self.params.num_epoch):
            self.model.train()
            epoch_loss = 0
            start_time = time.time()

            for batch in self.train_iter:
                # For each batch, first zero the gradients
                self.optimizer.zero_grad()
                source = batch.kor
                target = batch.eng

                # The decoder input is <sos> plus the following tokens (everything except <eos>)
                output = self.model(source, target[:, :-1])[0]

                # The ground truth is the tokens plus <eos> (everything except <sos>)
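                # Illustrative example (hypothetical tokens, not from this dataset):
                #   target        = [<sos>, i, like, cats, <eos>]
                #   decoder input = target[:, :-1] -> [<sos>, i, like, cats]
                #   ground truth  = target[:, 1:]  -> [i, like, cats, <eos>]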
                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)

                # output = [batch size * (target length - 1), output dim]
                # target = [batch size * (target length - 1)]
                loss = self.criterion(output, target)
                loss.backward()

                # Clip the gradients to prevent exploding gradients
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.params.clip)
                self.optimizer.step()

                # 'item' extracts the Python scalar from a tensor that holds a single value
                epoch_loss += loss.item()

            train_loss = epoch_loss / len(self.train_iter)
            valid_loss = self.evaluate()

            end_time = time.time()
            epoch_mins, epoch_secs = epoch_time(start_time, end_time)

            if valid_loss < best_valid_loss:
                best_valid_loss = valid_loss
                torch.save(self.model.state_dict(), self.params.save_model)

            print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
            print(f'\tTrain Loss: {train_loss:.3f} | Val. Loss: {valid_loss:.3f}')

    def evaluate(self):
        # Compute the average loss over the validation set (no gradient updates)
        self.model.eval()
        epoch_loss = 0

        with torch.no_grad():
            for batch in self.valid_iter:
                source = batch.kor
                target = batch.eng

                output = self.model(source, target[:, :-1])[0]
                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)

                loss = self.criterion(output, target)
                epoch_loss += loss.item()

        return epoch_loss / len(self.valid_iter)

    def inference(self):
        # Reload the best checkpoint saved during training and report the test-set loss
        self.model.load_state_dict(torch.load(self.params.save_model))
        self.model.eval()
        epoch_loss = 0

        with torch.no_grad():
            for batch in self.test_iter:
                source = batch.kor
                target = batch.eng

                output = self.model(source, target[:, :-1])[0]
                output = output.contiguous().view(-1, output.shape[-1])
                target = target[:, 1:].contiguous().view(-1)

                loss = self.criterion(output, target)
                epoch_loss += loss.item()

        test_loss = epoch_loss / len(self.test_iter)
        print(f'Test Loss: {test_loss:.3f}')
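

# Minimal usage sketch (assumption): this mirrors how a run script might drive the
# Trainer. `Params` and `build_iterators` are hypothetical placeholders for this
# repository's actual config loader and data pipeline, so the snippet is kept as
# comments rather than executable code.
#
#   params = Params('config/params.json')                        # hypothetical
#   train_iter, valid_iter, test_iter = build_iterators(params)  # hypothetical
#
#   trainer = Trainer(params, mode='train', train_iter=train_iter, valid_iter=valid_iter)
#   trainer.train()
#
#   tester = Trainer(params, mode='test', test_iter=test_iter)
#   tester.inference()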