forked from leviswind/pytorch-transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_load.py
88 lines (71 loc) · 3.35 KB
/
data_load.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# -*- coding: utf-8 -*-
#/usr/bin/python2
'''
June 2017 by kyubyong park.
https://www.github.com/kyubyong/transformer
'''
from __future__ import print_function
from hyperparams import Hyperparams as hp
import numpy as np
import codecs
import regex
import random
import torch
# Python 2/3 compatibility shim: make the name `xrange` usable on both.
try:
    xrange  # Python 2: the builtin already exists, nothing to do
except NameError:
    xrange = range  # Python 3: no `xrange` builtin, so alias it to `range`
def load_de_vocab():
    '''Build the lookup tables for the 'de' vocabulary.

    Reads 'preprocessed/de.vocab.tsv', in which every non-blank line holds a
    token followed by its count (whitespace separated), and keeps the tokens
    whose count is at least hp.min_cnt. Indices are assigned in file order.

    Returns:
      (word2idx, idx2word): dict mapping token -> index and dict mapping
      index -> token.
    '''
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with codecs.open('preprocessed/de.vocab.tsv', 'r', 'utf-8') as fin:
        lines = fin.read().splitlines()

    vocab = []
    for line in lines:
        fields = line.split()
        if not fields:
            # Blank line (e.g. a stray empty line in the file): nothing to index.
            continue
        if int(fields[1]) >= hp.min_cnt:
            vocab.append(fields[0])

    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word
def load_en_vocab():
    '''Build the lookup tables for the 'en' vocabulary.

    Reads 'preprocessed/en.vocab.tsv', in which every non-blank line holds a
    token followed by its count (whitespace separated), and keeps the tokens
    whose count is at least hp.min_cnt. Indices are assigned in file order.

    Returns:
      (word2idx, idx2word): dict mapping token -> index and dict mapping
      index -> token.
    '''
    # Use a context manager so the file handle is closed deterministically
    # instead of being left for the garbage collector.
    with codecs.open('preprocessed/en.vocab.tsv', 'r', 'utf-8') as fin:
        lines = fin.read().splitlines()

    vocab = []
    for line in lines:
        fields = line.split()
        if not fields:
            # Blank line (e.g. a stray empty line in the file): nothing to index.
            continue
        if int(fields[1]) >= hp.min_cnt:
            vocab.append(fields[0])

    word2idx = {word: idx for idx, word in enumerate(vocab)}
    idx2word = {idx: word for idx, word in enumerate(vocab)}
    return word2idx, idx2word
def create_data(source_sents, target_sents):
    '''Convert parallel sentences into zero-padded index matrices.

    Every sentence gets " </S>" appended, is split on whitespace, and each
    token is mapped through the 'de' vocabulary (source side) or the 'en'
    vocabulary (target side). Tokens missing from the vocabulary map to
    index 1 (OOV). A sentence pair is kept only when both index sequences
    have at most hp.maxlen entries.

    Args:
      source_sents: sequence of source sentences (unicode strings).
      target_sents: sequence of target sentences, aligned with source_sents.
        The two are paired with zip(), so extra items in the longer one are
        ignored.

    Returns:
      X: int32 array of shape (n_kept, hp.maxlen) holding the source indices,
        right-padded with 0.
      Y: int32 array of shape (n_kept, hp.maxlen) holding the target indices,
        right-padded with 0.
      Sources: list of the kept source sentences (original text).
      Targets: list of the kept target sentences (original text).
    '''
    # Only the word -> index direction is needed here.
    de2idx, _ = load_de_vocab()
    en2idx, _ = load_en_vocab()

    # Index
    x_list, y_list, Sources, Targets = [], [], [], []
    for source_sent, target_sent in zip(source_sents, target_sents):
        # 1: OOV, </S>: End of Text
        x = [de2idx.get(word, 1) for word in (source_sent + u" </S>").split()]
        y = [en2idx.get(word, 1) for word in (target_sent + u" </S>").split()]
        if max(len(x), len(y)) <= hp.maxlen:
            x_list.append(x)
            y_list.append(y)
            Sources.append(source_sent)
            Targets.append(target_sent)

    # Pad
    # The matrices start out all zeros, so writing each sequence into the
    # leading columns leaves the tail as 0-padding. This avoids the removed
    # np.lib.pad alias and a temporary array per row.
    X = np.zeros([len(x_list), hp.maxlen], np.int32)
    Y = np.zeros([len(y_list), hp.maxlen], np.int32)
    for i, (x, y) in enumerate(zip(x_list, y_list)):
        X[i, :len(x)] = x
        Y[i, :len(y)] = y

    return X, Y, Sources, Targets
def load_train_data():
    '''Load the training corpus as padded index matrices.

    Reads hp.source_train and hp.target_train as UTF-8, drops empty lines and
    lines whose first character is "<", removes every character that is not
    whitespace, a Latin-script character or an apostrophe, and hands the
    resulting sentences to create_data().

    Returns:
      (X, Y): the source and target index matrices produced by create_data().
    '''
    def _read_sents(path):
        # Read the whole file inside a context manager so the handle is closed.
        with codecs.open(path, 'r', 'utf-8') as fin:
            lines = fin.read().split("\n")
        # Raw string: \s and \p are regex escapes, not Python string escapes.
        return [regex.sub(r"[^\s\p{Latin}']", "", line)
                for line in lines if line and line[0] != "<"]

    # NOTE(review): the line filter runs on each file independently; if the
    # two files drop different lines the sentence pairs would misalign.
    # Presumably the corpora are line-aligned including markup lines - confirm.
    de_sents = _read_sents(hp.source_train)
    en_sents = _read_sents(hp.target_train)

    X, Y, _, _ = create_data(de_sents, en_sents)
    return X, Y
def load_test_data():
    '''Load the test corpus as a padded source matrix plus the raw sentences.

    Reads hp.source_test and hp.target_test as UTF-8, keeps only the lines
    that start with "<seg", strips the markup tags and every character that
    is not whitespace, a Latin-script character or an apostrophe, and hands
    the resulting sentences to create_data().

    Returns:
      (X, Sources, Targets): the source index matrix and the kept source and
      target sentences as text. The target index matrix is not returned.
    '''
    def _refine(line):
        # Remove <...> tags first, then everything outside the allowed set.
        # Raw string: \s and \p are regex escapes, not Python string escapes.
        line = regex.sub(r"<[^>]+>", "", line)
        line = regex.sub(r"[^\s\p{Latin}']", "", line)
        return line.strip()

    def _read_segments(path):
        # Read the whole file inside a context manager so the handle is closed.
        with codecs.open(path, 'r', 'utf-8') as fin:
            lines = fin.read().split("\n")
        return [_refine(line) for line in lines if line and line[:4] == "<seg"]

    de_sents = _read_segments(hp.source_test)
    en_sents = _read_segments(hp.target_test)

    X, _, Sources, Targets = create_data(de_sents, en_sents)
    # Original author's note: (1064, 150) - presumably the shape of X for the
    # author's test set and hp.maxlen; not verifiable from this file.
    return X, Sources, Targets
def get_batch_indices(total_length, batch_size):
    '''Yield shuffled mini-batches of sample indices covering one full pass.

    The indices 0 .. total_length-1 are shuffled once with random.shuffle and
    then handed out in consecutive slices of batch_size. Every index is
    yielded exactly once; the final batch is shorter when total_length is not
    a multiple of batch_size.

    Args:
      total_length: number of samples to draw indices for.
      batch_size: maximum number of indices per batch; must be positive.

    Yields:
      (indices, start): a list of at most batch_size sample indices, and the
      offset of that batch within the shuffled order (0, batch_size, ...).

    Raises:
      ValueError: if batch_size is not positive (raised when iteration starts,
        since this is a generator).
    '''
    if batch_size <= 0:
        # A zero step would otherwise never advance through the indices.
        raise ValueError("batch_size must be positive, got %r" % (batch_size,))

    indexs = list(range(total_length))
    random.shuffle(indexs)

    # Start at offset 0 so the first slice is not skipped, and step through
    # every offset so the trailing (possibly partial) batch is included too.
    for current_index in range(0, total_length, batch_size):
        yield indexs[current_index: current_index + batch_size], current_index