
Commit 57ffa5e

better visualize
1 parent 00a7cd2 commit 57ffa5e


2 files changed (+20 -11 lines)


seq-learning-char/generate.py (+9 -3)

@@ -1,6 +1,8 @@
 # https://github.com/spro/practical-pytorch
 
 import torch
+import colorama
+from colorama import Fore, Style
 
 from helpers import *
 from model import *
@@ -27,8 +29,10 @@ def generate_GRU(decoder, prime_str='A', predict_len=100, temperature=0.8):
         predicted_char = all_characters[top_i]
         predicted += predicted_char
         inp = char_tensor(predicted_char)
-
-    return predicted
+
+    print(Fore.BLUE + predicted[:len(prime_str)] + Fore.GREEN + predicted[len(prime_str):])
+    print(Style.RESET_ALL)
+    # return predicted
 
 
 def generate_CNN(decoder, prime_str, predict_len=100, temperature=0.8):
@@ -59,4 +63,6 @@ def generate_CNN(decoder, prime_str, predict_len=100, temperature=0.8):
         inp[:-1] = inp[1:]
         inp[-1] = top_i
 
-    return predicted
+    print(Fore.BLUE + predicted[:len(prime_str)] + Fore.GREEN + predicted[len(prime_str):])
+    print(Style.RESET_ALL)
+    # return predicted
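
The new output path colors the seed text and the generated continuation differently instead of returning the string. A minimal standalone sketch of that pattern, assuming only the colorama package (prime_str and generated are made-up stand-ins for the real function argument and model output); colorama.init() is included here because Windows consoles need it to interpret the ANSI codes:

import colorama
from colorama import Fore, Style

colorama.init()  # enable ANSI colors on Windows terminals; harmless elsewhere

prime_str = "Wh"                       # stand-in for the priming string
generated = prime_str + "ere we left"  # stand-in for prime + generated characters

# seed shown in blue, generated continuation in green, then reset terminal styling
print(Fore.BLUE + generated[:len(prime_str)] + Fore.GREEN + generated[len(prime_str):])
print(Style.RESET_ALL)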

seq-learning-char/train.py (+11 -8)

@@ -27,7 +27,7 @@
 # parser.add_argument('--pprint', type=bool, default=True,
 #                     help='Print PDF or display image only')
 parser.add_argument('--hidden_size', type=int, default=64)
-parser.add_argument('--n_layers', type=int, default=2, help='GRU layers')
+parser.add_argument('--n_layers', type=int, default=1, help='GRU layers')
 parser.add_argument('--pint', type=int, default=16, help='CNN/Att past samples to integrate')
 parser.add_argument('--print_every', type=int, default=25)
 parser.add_argument('--learning_rate', type=float, default=0.01)
@@ -45,10 +45,15 @@
 file, file_len = read_file(args.filename)
 
 
-def random_training_set(chunk_len):
+def random_piece_text(chunk_len):
     start_index = random.randint(0, file_len - chunk_len)
     end_index = start_index + chunk_len + 1
     chunk = file[start_index:end_index]
+    return chunk
+
+
+def random_training_set(chunk_len):
+    chunk = random_piece_text(chunk_len)
     inp = char_tensor(chunk[:-1])
     target = char_tensor(chunk[1:])
     return inp, target
@@ -155,17 +160,15 @@ def save():
     loss = train(*random_training_set(args.chunk_len))
     loss_avg += loss
 
+    # generate some text to see performance:
     if epoch % args.print_every == 0:
         print('[elapsed time: %s, epoch: %d, percent complete: %d%%, loss: %.4f]' % (time_since(start), epoch, epoch / args.epochs * 100, loss))
         if args.sequencer == 'GRU':
-            chunk_len = args.pint
-            start_index = random.randint(0, file_len - chunk_len)
-            end_index = start_index + chunk_len + 1
-            init_str = file[start_index:end_index]
-            print(generate_GRU(model, init_str, 100), '\n')
+            init_str = random_piece_text(args.pint)
+            generate_GRU(model, init_str, 100)
         elif args.sequencer == 'CNN' or 'Att':
             init_str,_ = random_training_set(args.pint)
-            print(generate_CNN(model, init_str, 100), '\n')
+            generate_CNN(model, init_str, 100)
 
 print("Saving...")
 save()
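
The refactor above lets one chunk sampler feed both the training pairs and the GRU priming string. A minimal sketch of that shape, with a toy corpus standing in for read_file() and a stub char_tensor (both are assumptions, not the repo's helpers):

import random
import torch

file = "the quick brown fox jumps over the lazy dog"
file_len = len(file)
all_characters = sorted(set(file))

def char_tensor(s):
    # stub: map each character to its index in the toy vocabulary
    return torch.tensor([all_characters.index(c) for c in s], dtype=torch.long)

def random_piece_text(chunk_len):
    # shared helper: pick chunk_len + 1 consecutive characters at random
    start_index = random.randint(0, file_len - chunk_len)
    return file[start_index:start_index + chunk_len + 1]

def random_training_set(chunk_len):
    # training pair: input shifted one character behind the target
    chunk = random_piece_text(chunk_len)
    return char_tensor(chunk[:-1]), char_tensor(chunk[1:])

inp, target = random_training_set(16)  # tensors for a training step
init_str = random_piece_text(16)       # raw text to prime the GRU generator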
