Skip to content

Commit 6ea1133

Browse files
eyalcohen308Eyal Cohen
authored andcommitted
Optimize merge_tokens method (#3615)
Optimizes merge_tokens method as discussed in #3614 Co-authored-by: Eyal Cohen <[email protected]>
1 parent 420d9ac commit 6ea1133

File tree

1 file changed

+10
-13
lines changed

1 file changed

+10
-13
lines changed

torchaudio/functional/_alignment.py

+10-13
Original file line numberDiff line numberDiff line change
@@ -115,17 +115,14 @@ def merge_tokens(tokens: Tensor, scores: Tensor, blank: int = 0) -> List[TokenSp
115115
if len(tokens) != len(scores):
116116
raise ValueError("`tokens` and `scores` must be the same length.")
117117

118-
t_prev = blank
119-
i = start = -1
120-
spans = []
121-
for t, token in enumerate(tokens):
122-
if token != t_prev:
123-
if t_prev != blank:
124-
spans.append(TokenSpan(t_prev.item(), start, t, scores[start:t].mean().item()))
125-
if token != blank:
126-
i += 1
127-
start = t
128-
t_prev = token
129-
if t_prev != blank:
130-
spans.append(TokenSpan(t_prev.item(), start, len(tokens), scores[start:].mean().item()))
118+
diff = torch.diff(
119+
tokens, prepend=torch.tensor([-1], device=tokens.device), append=torch.tensor([-1], device=tokens.device)
120+
)
121+
changes_wo_blank = torch.nonzero((diff != 0)).squeeze().tolist()
122+
tokens = tokens.tolist()
123+
spans = [
124+
TokenSpan(token=token, start=start, end=end, score=scores[start:end].mean().item())
125+
for start, end in zip(changes_wo_blank[:-1], changes_wo_blank[1:])
126+
if (token := tokens[start]) != blank
127+
]
131128
return spans

0 commit comments

Comments
 (0)