-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathsearch.py
86 lines (61 loc) · 2.29 KB
/
search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# search.py
# author: Playinf
# email: [email protected]
import numpy as np
from tensorflow.python.util import nest
def find_nbest(score, n, threshold=None):
    """Select the `n` smallest entries of a 2-D score matrix.

    Args:
        score: 2-D array of shape (beam_size, num_vars); entry [i, j] is the
            score of extending hypothesis i with variable j.
        n: maximum number of elements to select. Values larger than
            ``score.size`` are clamped to it; values <= 0 select nothing.
        threshold: optional pruning margin. When truthy, a selected entry is
            pruned if ``score < best + threshold``, where ``best`` is the
            maximum of the selected scores (so entries equal to
            ``best + threshold`` are kept).

    Returns:
        (nbest_score, beam_indices, var_indices): parallel 1-D arrays holding
        the selected scores and their integer row / column positions in
        `score`. The order is unspecified (argpartition does not sort).
    """
    num_vars = score.shape[1]
    flat = score.flatten()

    n = min(n, flat.size)
    if n <= 0:
        nbest = np.zeros(0, dtype=np.intp)
    else:
        # kth = n - 1 places the n-th smallest at index n - 1 with everything
        # smaller before it; unlike kth = n it stays valid when n == flat.size.
        nbest = np.argpartition(flat, n - 1)[:n]

    # Integer division: `/` would yield float indices under Python 3, which
    # cannot be used to index Python lists.
    beam_indices, var_indices = np.divmod(nbest, num_vars)
    nbest_score = flat[nbest]

    if threshold and nbest_score.size:
        # NOTE(review): the selection keeps the *smallest* scores, but the
        # threshold is measured from the *largest* selected score; confirm
        # which direction the callers treat as "better".
        best = np.max(nbest_score)
        cond = nbest_score >= best + threshold
        nbest_score = nbest_score[cond]
        beam_indices = beam_indices[cond]
        var_indices = var_indices[cond]

    return nbest_score, beam_indices, var_indices
def select(value, condition):
    """Return the items of `value` whose corresponding `condition` is truthy.

    `value` and `condition` are walked in lockstep; iteration stops at the
    end of the shorter sequence.

    Returns:
        A new list containing the selected items, in their original order.
    """
    return [item for item, keep in zip(value, condition) if keep]
def select_nbest(nested, indices):
    """Index every leaf of a (possibly nested) structure with `indices`.

    Args:
        nested: either a single indexable object, or a list/tuple structure
            (flattened and rebuilt with ``nest``) whose leaves are indexable.
            Leaves are presumably arrays of shape (batch, dim) — confirm
            against callers.
        indices: index or index array applied as ``leaf[indices]``.

    Returns:
        ``nested[indices]`` for a non-list/tuple input; otherwise a structure
        of the same layout as `nested` with each leaf replaced by
        ``leaf[indices]``.
    """
    if isinstance(nested, (list, tuple)):
        leaves = nest.flatten(nested)
        gathered = []
        for leaf in leaves:
            gathered.append(leaf[indices])
        return nest.pack_sequence_as(nested, gathered)
    return nested[indices]
class beam:
    """Container of partial hypotheses and their scores for beam search.

    Attributes:
        size: maximum number of entries requested from ``find_nbest`` in
            :meth:`prune`.
        threshold: optional pruning margin forwarded to ``find_nbest``.
        score: list of scores, parallel to ``candidate``.
        candidate: list of hypotheses; each hypothesis is a list of variable
            (column) indices.
    """

    def __init__(self, beamsize, threshold=None):
        self.size = beamsize
        self.threshold = threshold
        self.score = []
        self.candidate = []

    def prune(self, dist, cond, prev_beam):
        """Extend the hypotheses of `prev_beam` and keep the selected ones.

        The combined score ``prev_beam.score[i] - dist[i, j]`` is computed for
        every hypothesis i and variable j, and ``find_nbest`` picks up to
        ``self.size`` of them. Each picked extension whose candidate satisfies
        `cond` is reported as finished; the others are appended to
        ``self.candidate`` / ``self.score``. Since those lists are only
        appended to, this is presumably called once on a fresh beam per
        search step — TODO confirm with callers.

        Args:
            dist: array broadcastable against (len(prev_beam.score), 1),
                presumably (num_hyps, num_vars) — confirm.
            cond: predicate taking a candidate (list of indices) and
                returning True when that hypothesis is complete.
            prev_beam: beam holding the hypotheses being extended.

        Returns:
            (finished, beam_indices, var_indices): ``finished`` is a list of
            ``[candidate, score]`` pairs; the two integer arrays give, for
            each hypothesis kept on this beam (in the same order), the row in
            `prev_beam` it extends and the variable it was extended with.
        """
        prev_score = np.array(prev_beam.score, dist.dtype)
        score = prev_score[:, None] - dist
        nbest_score, beam_indices, var_indices = find_nbest(
            score, self.size, self.threshold)

        # find_nbest may produce float indices (true division under Python 3);
        # list indexing and downstream gathers need integers.
        beam_indices = np.asarray(beam_indices).astype(np.int64)
        var_indices = np.asarray(var_indices).astype(np.int64)

        prev_candidate = prev_beam.candidate
        finished = []
        remained = []
        for i, (bid, vid) in enumerate(zip(beam_indices, var_indices)):
            candidate = prev_candidate[bid] + [vid]
            if cond(candidate):
                finished.append([candidate, nbest_score[i]])
            else:
                remained.append(i)
                self.candidate.append(candidate)
                self.score.append(nbest_score[i])

        return finished, beam_indices[remained], var_indices[remained]