Skip to content
This repository was archived by the owner on Aug 18, 2021. It is now read-only.

added support for arbitrary encoding/unicode #44

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 17 additions & 11 deletions char-rnn-generation/generate.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,52 @@
# https://github.com/spro/practical-pytorch
# -*- coding: utf-8 -*-

import torch

from helpers import *
from model import *

def generate(decoder, all_characters, prime_str='A', predict_len=100, temperature=0.8):
    """Sample text from a trained character-level model.

    Args:
        decoder: Model exposing ``init_hidden()`` and callable as
            ``decoder(input, hidden)``, returning ``(output, hidden)``.
        all_characters: Indexable vocabulary; position ``i`` holds the
            character for class index ``i``.
        prime_str: Seed text run through the model before sampling. Must be
            non-empty and consist only of characters in ``all_characters``.
        predict_len: Number of characters to sample after the seed.
        temperature: Divisor applied to the model output before it is
            exponentiated; smaller values sharpen the sampling distribution.

    Returns:
        ``prime_str`` followed by ``predict_len`` sampled characters.
    """
    hidden = decoder.init_hidden()
    prime_input = char_tensor(prime_str, all_characters)

    # Use priming string to "build up" hidden state
    for p in range(len(prime_str) - 1):
        _, hidden = decoder(prime_input[p], hidden)

    # The last seed character is the first input of the sampling loop.
    inp = prime_input[-1]
    pieces = [prime_str]

    for _ in range(predict_len):
        output, hidden = decoder(inp, hidden)

        # Sample from the network as a multinomial distribution
        output_dist = output.data.view(-1).div(temperature).exp()
        top_i = torch.multinomial(output_dist, 1)[0]

        # Add predicted character to string and use as next input
        predicted_char = all_characters[top_i]
        pieces.append(predicted_char)
        inp = char_tensor(predicted_char, all_characters)

    return ''.join(pieces)

if __name__ == '__main__':
    # Parse command line arguments
    import argparse
    import pickle

    argparser = argparse.ArgumentParser()
    argparser.add_argument('filename', type=str)
    argparser.add_argument('-p', '--prime_str', type=str, default='A')
    argparser.add_argument('-l', '--predict_len', type=int, default=100)
    argparser.add_argument('-t', '--temperature', type=float, default=0.8)
    argparser.add_argument('-f', '--charset-file', type=str, default='charset.pickle')
    args = argparser.parse_args()

    # Pickle data is binary, so the charset file must be opened in 'rb' mode.
    # NOTE: pickle.load can execute arbitrary code; only load charset files
    # that you produced yourself.
    with open(args.charset_file, 'rb') as fd:
        all_characters = pickle.load(fd)
    decoder = torch.load(args.filename)

    # Remove the options that are not keyword arguments of generate(), so the
    # remaining ones (prime_str, predict_len, temperature) can be forwarded.
    del args.filename
    del args.charset_file
    print(generate(decoder, all_characters=all_characters, **vars(args)))
23 changes: 14 additions & 9 deletions char-rnn-generation/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@

# Reading and un-unicode-encoding data

# NOTE(review): these two assignments look like the pre-change side of the diff.
# The commented-out copies just below, and the character list that read_file()
# returns, suggest the fixed string.printable alphabet is being replaced by a
# per-file character set — confirm before relying on these module-level names.
all_characters = string.printable
n_characters = len(all_characters)
#all_characters = string.printable
#n_characters = len(all_characters)
#all_characters = ['\x83', '\x87', '\x8b', '\x8f', '\x93', '\x97', '\x9b', '\x9f', ' ', '\xa3', '\xa7', '(', '\xab', ',', '\xaf', '\xb7', '\xbb', '\xbf', '\xc3', 'H', 'L', 'P', 'T', 'd', 'h', 'l', '\xef', 'p', 't', '|', '\x80', '\x88', '\x8c', '\x90', '\x94', '\x98', '\x9c', '\xa0', '\xa4', "'", '\xa8', '\xac', '/', '\xb0', '\xb8', '\xbc', '?', 'C', 'S', 'W', '\xe0', 'c', 'g', 'k', 'o', 's', 'w', '\x81', '\x85', '\x89', '\n', '\x8d', '\x95', '\x99', '\x9d', '\xa1', '\xa5', '*', '\xad', '.', '\xb1', '\xb5', '\xb9', ':', '\xbd', 'B', 'F', 'J', 'V', 'b', 'f', 'n', 'r', 'v', 'z', '~', '\x82', '\x86', '\x8a', '\x96', '\x9a', '\x9e', '!', '\xa2', '\xa6', ')', '\xaa', '-', '\xae', '\xb2', '\xb6', '\xbe', 'A', '\xc2', 'E', 'I', 'M', 'U', 'Y', 'a', '\xe2', 'e', 'i', 'm', 'q', 'u', 'y']
#n_characters = len(all_characters)

def read_file(filename):
    """Read a text file and derive its character vocabulary.

    Args:
        filename: Path of the text file to read.

    Returns:
        A 4-tuple ``(text, text_len, all_characters, n_characters)``: the file
        contents, their length, the sorted list of distinct characters found
        in the contents, and the number of distinct characters.
    """
    # Context manager so the file handle is closed as soon as the read is done.
    with open(filename) as fh:
        s = fh.read()

    # Sort the distinct characters: raw set iteration order is arbitrary, and a
    # stable order keeps each character's index reproducible between runs.
    all_characters = sorted(set(s))
    n_characters = len(all_characters)
    return s, len(s), all_characters, n_characters

# Turning a string into a tensor

def char_tensor(s, all_characters):
    """Encode a string as a tensor of vocabulary indices.

    Args:
        s: Text to encode.
        all_characters: Vocabulary sequence (list or string); each character
            of ``s`` is mapped to its first position in this sequence.

    Returns:
        A ``Variable`` wrapping a long tensor of length ``len(s)``.

    Raises:
        ValueError: If a character of ``s`` is not in ``all_characters``.
    """
    tensor = torch.zeros(len(s)).long()
    for position, char in enumerate(s):
        tensor[position] = all_characters.index(char)
    return Variable(tensor)

# Readable time elapsed
Expand All @@ -32,4 +38,3 @@ def time_since(since):
m = math.floor(s / 60)
s -= m * 60
return '%dm %ds' % (m, s)

14 changes: 9 additions & 5 deletions char-rnn-generation/train.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# https://github.com/spro/practical-pytorch
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
Expand All @@ -21,14 +22,14 @@
argparser.add_argument('--chunk_len', type=int, default=200)
args = argparser.parse_args()

# read_file returns the text, its length, the list of distinct characters and
# the number of distinct characters, so all four values are unpacked here.
file, file_len, all_characters, n_characters = read_file(args.filename)

def random_training_set(chunk_len):
    """Pick a random input/target training pair from the loaded text.

    Args:
        chunk_len: Number of characters in each of the two returned sequences.

    Returns:
        ``(inp, target)`` index tensors built by ``char_tensor``; ``target``
        is ``inp`` shifted one character ahead.

    Raises:
        ValueError: If the text has fewer than ``chunk_len + 1`` characters.
    """
    # A pair needs chunk_len + 1 characters and randint's upper bound is
    # inclusive, so the last valid start index is file_len - chunk_len - 1.
    start_index = random.randint(0, file_len - chunk_len - 1)
    end_index = start_index + chunk_len + 1
    chunk = file[start_index:end_index]
    inp = char_tensor(chunk[:-1], all_characters)
    target = char_tensor(chunk[1:], all_characters)
    return inp, target

# Build the model; n_characters is passed as both the first and third RNN
# argument (presumably input and output size — confirm against model.py).
decoder = RNN(n_characters, args.hidden_size, n_characters, args.n_layers)
Expand Down Expand Up @@ -56,6 +57,9 @@ def train(inp, target):
def save():
    """Write the trained model and its character set to the working directory.

    The model is saved as ``<basename of the training file>.pt`` and the
    character list is pickled to ``charset.pickle``.
    """
    import pickle

    save_filename = os.path.splitext(os.path.basename(args.filename))[0] + '.pt'
    torch.save(decoder, save_filename)
    # Pickle streams are binary: 'wb' is required on Python 3 and avoids
    # newline translation on Windows.
    with open("charset.pickle", "wb") as fd:
        pickle.dump(all_characters, fd)
    print('Saved as %s' % save_filename)

try:
Expand All @@ -66,12 +70,12 @@ def save():

if epoch % args.print_every == 0:
print('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, epoch / args.n_epochs * 100, loss))
print(generate(decoder, 'Wh', 100), '\n')
#print(generate(decoder, 'Wh', 100), '\n')
print generate(decoder, all_characters=all_characters, prime_str='अध्याय', predict_len=500)

print("Saving...")
save()

except KeyboardInterrupt:
print("Saving before quit...")
save()