This repository was archived by the owner on Aug 18, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathtrain.py
81 lines (65 loc) · 2.52 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# https://github.com/spro/practical-pytorch
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
from torch.autograd import Variable
import argparse
import os
from helpers import *
from model import *
from generate import *
# Command-line interface: one positional corpus path plus optional
# hyperparameters, each with a default.
argparser = argparse.ArgumentParser()
argparser.add_argument('filename', type=str)
for _flag, _kind, _default in (
        ('--n_epochs', int, 2000),
        ('--print_every', int, 100),
        ('--hidden_size', int, 50),
        ('--n_layers', int, 2),
        ('--learning_rate', float, 0.01),
        ('--chunk_len', int, 200),
):
    argparser.add_argument(_flag, type=_kind, default=_default)
args = argparser.parse_args()

# read_file comes from a star import; its four return values are unpacked into
# the module-level names the rest of the script relies on.
file, file_len, all_characters, n_characters = read_file(args.filename)
def random_training_set(chunk_len):
    """Pick a random window of the corpus and return it as (input, target).

    A window of ``chunk_len + 1`` characters is sliced from the module-level
    ``file``. The input is the window without its last character and the
    target is the window without its first, so ``target[i]`` is the character
    that follows ``inp[i]``.

    Args:
        chunk_len: Number of characters in each of the two returned sequences.

    Returns:
        Tuple ``(inp, target)``: the results of ``char_tensor`` applied to the
        two ``chunk_len``-character strings.

    Raises:
        ValueError: If the corpus is too short to hold ``chunk_len + 1``
            characters.
    """
    # The window spans chunk_len + 1 characters and random.randint includes
    # both bounds, so the last usable start is file_len - chunk_len - 1.
    # Starting one position later makes the slice run off the end of the text
    # and silently produces sequences of only chunk_len - 1 characters.
    # (Assumes file_len == len(file), as returned by read_file -- confirm.)
    last_start = file_len - chunk_len - 1
    if last_start < 0:
        raise ValueError(
            'corpus has %d characters; at least %d are needed for chunk_len=%d'
            % (file_len, chunk_len + 1, chunk_len)
        )
    start_index = random.randint(0, last_start)
    end_index = start_index + chunk_len + 1
    chunk = file[start_index:end_index]
    inp = char_tensor(chunk[:-1], all_characters)
    target = char_tensor(chunk[1:], all_characters)
    return inp, target
# Model under training. RNN comes from a star import; n_characters is passed as
# both the first and third argument (presumably input and output size --
# confirm against the model definition).
decoder = RNN(n_characters, args.hidden_size, n_characters, args.n_layers)
# Adam over every decoder parameter, at the learning rate given on the CLI.
decoder_optimizer = torch.optim.Adam(decoder.parameters(), lr=args.learning_rate)
criterion = nn.CrossEntropyLoss()
# Reference timestamp handed to time_since() in the progress line.
start = time.time()
# Not appended to anywhere in the visible part of the script.
all_losses = []
# Running sum of the per-epoch values returned by train().
loss_avg = 0
def train(inp, target):
    """Run one optimisation step on a single chunk and return its mean loss.

    The decoder is fed ``inp`` one position at a time, carrying the hidden
    state forward; the criterion is summed over all ``args.chunk_len``
    positions before a single backward pass and optimizer step.

    Args:
        inp: Indexable sequence of inputs; ``inp[c]`` is fed to the decoder
            at step ``c``.
        target: Indexable sequence of targets; ``target[c]`` is compared with
            the decoder output at step ``c``.

    Returns:
        The summed loss divided by ``args.chunk_len``, as a Python number.
    """
    hidden = decoder.init_hidden()
    decoder.zero_grad()
    loss = 0

    for c in range(args.chunk_len):
        output, hidden = decoder(inp[c], hidden)
        loss += criterion(output, target[c])

    loss.backward()
    decoder_optimizer.step()

    # Indexing a 0-dim loss with `.data[0]` was deprecated in PyTorch 0.4 and
    # raises in later releases; `.item()` is the replacement. Fall back to the
    # old form so the script still runs on releases that predate `.item()`.
    if hasattr(loss, 'item'):
        total_loss = loss.item()
    else:
        total_loss = loss.data[0]
    return total_loss / args.chunk_len
def save():
    """Write the trained decoder and its character set to the working directory.

    The decoder is saved with ``torch.save`` to ``<stem>.pt``, where ``<stem>``
    is the base name of ``args.filename`` without its extension. The
    character set is pickled to ``charset.pickle``.
    """
    import pickle

    save_filename = os.path.splitext(os.path.basename(args.filename))[0] + '.pt'
    torch.save(decoder, save_filename)
    # pickle emits bytes: text mode ("w") raises TypeError on Python 3 and can
    # corrupt the stream on Windows under Python 2, so open in binary mode.
    with open("charset.pickle", "wb") as fd:
        pickle.dump(all_characters, fd)
    print('Saved as %s' % save_filename)
# Training driver: one random chunk per epoch, a progress line plus a generated
# sample every `print_every` epochs, and a save at the end. Ctrl-C also saves,
# so an interrupted run still leaves a model on disk.
try:
    print("Training for %d epochs..." % args.n_epochs)
    for epoch in range(1, args.n_epochs + 1):
        loss = train(*random_training_set(args.chunk_len))
        loss_avg += loss

        if epoch % args.print_every == 0:
            # Integer arithmetic keeps the percentage exact on both Python 2
            # (where epoch / n_epochs truncates to 0) and Python 3 (where the
            # float product can land just below the whole number).
            percent_done = epoch * 100 // args.n_epochs
            print('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, percent_done, loss))
            # Call form of print so the file parses on Python 3 as well; with a
            # single argument it prints the same thing under Python 2.
            print(generate(decoder, all_characters=all_characters, prime_str='अध्याय', predict_len=500))

    print("Saving...")
    save()

except KeyboardInterrupt:
    print("Saving before quit...")
    save()