model.py
from torch import nn

import cfg
from modules.discriminator import Discriminator
from modules.generator import Generator


class LMGan(nn.Module):
    def __init__(self, opt):
        super(LMGan, self).__init__()
        self.opt = opt
        self.generator = Generator(
            opt.vocab_size,
            opt.embedding_size,
            opt.hidden_size,
            opt.device
        )
        # The discriminator is only instantiated for adversarial training;
        # in plain language-model mode it stays None.
        self.discriminator = Discriminator(
            opt.hidden_size,
            opt.d_hidden_size,
            opt.d_linear_size,
            opt.d_dropout,
            opt.device
        ) if opt.adversarial else None

    def forward(self, input, adversarial=True):
        batch_size = input.size(0)
        if not adversarial:
            # Vanilla negative log-likelihood training, no sampling.
            start_hidden = self.generator.init_hidden(batch_size, strategy=cfg.inits.zeros)
            loss, _, _ = self.generator.consume(input, start_hidden, sampling=False)
            # Pad the return tuple so both branches have the same arity.
            return loss, None, None, None, None
        else:
            # Teacher-forced pass (no sampling): hidden states come from
            # consuming the ground-truth tokens.
            start_hidden_nll = self.generator.init_hidden(batch_size, strategy=cfg.inits.zeros)
            loss_nll, gen_hidden_states_nll, _ = self.generator.consume(
                input, start_hidden_nll, sampling=False)
            # Autoregressive pass: hidden states come from feeding back the
            # generator's own sampled tokens.
            start_hidden_adv = self.generator.init_hidden(batch_size, strategy=cfg.inits.zeros)
            loss_adv, gen_hidden_states_adv, _ = self.generator.consume(
                input,
                start_hidden_adv,
                sampling=True,
                method=self.opt.sampling_strategy,
                temperature=self.opt.temperature
            )
            # The two passes build separate computational graphs, so
            # backward() can later be called on each loss independently.
            # Score both sets of hidden states with the discriminator.
            teacher_forcing_scores = self.discriminator(gen_hidden_states_nll)
            autoregressive_scores = self.discriminator(gen_hidden_states_adv)
            return (loss_nll + loss_adv, teacher_forcing_scores, autoregressive_scores,
                    gen_hidden_states_nll, gen_hidden_states_adv)

    def view_rnn_grad_norms(self):
        # Gradient norms of the recurrent parameters, keyed by parameter name.
        # Parameters whose .grad is still None (no backward pass yet) are
        # skipped rather than raising an AttributeError.
        norms_dict = {
            k: v.grad.norm().item()
            for k, v in self.named_parameters()
            if 'rnn' in k and v.grad is not None
        }
        return norms_dict
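

# A minimal usage sketch (not part of the original file): how LMGan might be
# driven for one adversarial training step, assuming the repository's
# `modules` package is importable. The `opt` fields mirror the attributes
# read in __init__/forward above; the concrete values, the batch shape, and
# `sampling_strategy='multinomial'` are assumptions for illustration, not
# the repository's actual training configuration.
if __name__ == '__main__':
    import torch
    from argparse import Namespace

    opt = Namespace(
        vocab_size=10000, embedding_size=128, hidden_size=256,
        d_hidden_size=256, d_linear_size=128, d_dropout=0.1,
        adversarial=True, sampling_strategy='multinomial', temperature=1.0,
        device=torch.device('cpu'),
    )
    model = LMGan(opt)

    # (batch, seq_len) tensor of token ids.
    batch = torch.randint(0, opt.vocab_size, (4, 30))
    loss_sum, tf_scores, ar_scores, h_tf, h_ar = model(batch, adversarial=True)
    loss_sum.backward()
    print(model.view_rnn_grad_norms())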