|
27 | 27 | # parser.add_argument('--pprint', type=bool, default=True,
|
28 | 28 | # help='Print PDF or display image only')
|
# Model-size options.
parser.add_argument(
    '--hidden_size',
    type=int,
    default=64,
)
parser.add_argument(
    '--n_layers',
    type=int,
    default=1,
    help='GRU layers',
)
# Length handed to random_piece_text / random_training_set when priming the
# periodic text sample.
parser.add_argument(
    '--pint',
    type=int,
    default=16,
    help='CNN/Att past samples to integrate',
)
# The progress line and text sample are emitted when
# `epoch % args.print_every == 0`.
parser.add_argument(
    '--print_every',
    type=int,
    default=25,
)
parser.add_argument(
    '--learning_rate',
    type=float,
    default=0.01,
)
|
|
# Load the training text named by args.filename. `file` is the sequence that
# random_piece_text slices, and `file_len` bounds its random start index.
# NOTE(review): file_len is presumably len(file) — confirm in read_file.
file, file_len = read_file(args.filename)
|
46 | 46 |
|
47 | 47 |
|
48 |
| -def random_training_set(chunk_len): |
def random_piece_text(chunk_len):
    """Return a random slice of ``chunk_len + 1`` consecutive items of ``file``.

    The extra item lets a caller split the slice into an input and a
    one-step-shifted target of ``chunk_len`` items each.

    Args:
        chunk_len: Number of items in each half of such an input/target pair.

    Returns:
        ``file[start:start + chunk_len + 1]`` for a uniformly drawn ``start``.

    Raises:
        ValueError: If the text holds fewer than ``chunk_len + 1`` items.
    """
    # random.randint includes both endpoints, so the largest start must still
    # leave room for chunk_len + 1 items; otherwise the slice would be
    # silently truncated to chunk_len items.
    # NOTE(review): assumes file_len == len(file) — confirm in read_file.
    last_start = file_len - chunk_len - 1
    if last_start < 0:
        raise ValueError(
            'text of length %d is too short for chunk_len=%d'
            % (file_len, chunk_len)
        )
    start_index = random.randint(0, last_start)
    return file[start_index:start_index + chunk_len + 1]
| 53 | + |
| 54 | + |
def random_training_set(chunk_len):
    """Build one (input, target) training pair from a random piece of text.

    The target is the input shifted forward by one position: the input drops
    the last item of the piece and the target drops the first.

    Args:
        chunk_len: Length passed through to ``random_piece_text``.

    Returns:
        A tuple ``(inp, target)`` of ``char_tensor`` results.
    """
    text = random_piece_text(chunk_len)
    source_part, shifted_part = text[:-1], text[1:]
    return char_tensor(source_part), char_tensor(shifted_part)
|
@@ -155,17 +160,15 @@ def save():
|
155 | 160 | loss = train(*random_training_set(args.chunk_len))
|
156 | 161 | loss_avg += loss
|
157 | 162 |
|
# Every `args.print_every` epochs: print a progress line, then run the
# generator for the selected sequencer to see how the model is performing.
if epoch % args.print_every == 0:
    print('[elapsed time: %s, epoch: %d, percent complete: %d%%, loss: %.4f]' % (time_since(start), epoch, epoch / args.epochs * 100, loss))
    if args.sequencer == 'GRU':
        # The GRU generator is primed with a raw slice of the training text.
        init_str = random_piece_text(args.pint)
        # The return value is not printed here; presumably generate_GRU
        # emits its own output — confirm.
        generate_GRU(model, init_str, 100)
    elif args.sequencer in ('CNN', 'Att'):
        # Membership test instead of `== 'CNN' or 'Att'`: the old condition
        # parsed as `(args.sequencer == 'CNN') or 'Att'`, which is always
        # truthy, so any non-GRU value fell into this branch.
        # This generator is primed with the input half of a training pair,
        # i.e. a char_tensor result rather than raw text.
        init_str, _ = random_training_set(args.pint)
        generate_CNN(model, init_str, 100)
169 | 172 |
|
170 | 173 | print("Saving...")
|
171 | 174 | save()
|
|
0 commit comments