Skip to content

Commit e6c61e6

Browse files
committed
visualization complete
1 parent 57ffa5e commit e6c61e6

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

seq-learning-char/generate.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
# https://github.com/spro/practical-pytorch
22

33
import torch
4-
import colorama
5-
from colorama import Fore, Style
4+
import os
65

76
from helpers import *
87
from model import *
98

9+
# Color Palette
10+
CP_R = '\033[31m'
11+
CP_G = '\033[32m'
12+
CP_B = '\033[34m'
13+
CP_Y = '\033[33m'
14+
CP_C = '\033[0m'
15+
1016
def generate_GRU(decoder, prime_str='A', predict_len=100, temperature=0.8):
1117
hidden = decoder.init_hidden()
1218
prime_input = char_tensor(prime_str)
@@ -29,9 +35,8 @@ def generate_GRU(decoder, prime_str='A', predict_len=100, temperature=0.8):
2935
predicted_char = all_characters[top_i]
3036
predicted += predicted_char
3137
inp = char_tensor(predicted_char)
32-
33-
print(Fore.BLUE + predicted[:len(prime_str)] + Fore.GREEN + predicted[len(prime_str):])
34-
print(Style.RESET_ALL)
38+
39+
print(CP_B + predicted[:len(prime_str)] + CP_G + predicted[len(prime_str):] + CP_C)
3540
# return predicted
3641

3742

@@ -63,6 +68,5 @@ def generate_CNN(decoder, prime_str, predict_len=100, temperature=0.8):
6368
inp[:-1] = inp[1:]
6469
inp[-1] = top_i
6570

66-
print(Fore.BLUE + predicted[:len(prime_str)] + Fore.GREEN + predicted[len(prime_str):])
67-
print(Style.RESET_ALL)
71+
print(CP_B + predicted[:len(prime_str)] + CP_G + predicted[len(prime_str):] + CP_C)
6872
# return predicted

seq-learning-char/train.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -161,14 +161,15 @@ def save():
161161
loss_avg += loss
162162

163163
# generate some text to see performance:
164+
test_len = 100
164165
if epoch % args.print_every == 0:
165166
print('[elapsed time: %s, epoch: %d, percent complete: %d%%, loss: %.4f]' % (time_since(start), epoch, epoch / args.epochs * 100, loss))
166167
if args.sequencer == 'GRU':
167-
init_str = random_piece_text(args.pint)
168-
generate_GRU(model, init_str, 100)
168+
init_str = random_piece_text(args.pint+test_len)
169+
generate_GRU(model, init_str[:args.pint], test_len)
169170
elif args.sequencer == 'CNN' or 'Att':
170-
init_str,_ = random_training_set(args.pint)
171-
generate_CNN(model, init_str, 100)
171+
init_str,_ = random_training_set(args.pint+test_len)
172+
generate_CNN(model, init_str[:args.pint], test_len)
172173

173174
print("Saving...")
174175
save()

0 commit comments

Comments
 (0)