Commit 8edf91a

Update forced alignment tutorial (#2544)
Summary:
1. Fix initialization. Previously, the SOS token score was initialized to 0 across the time axis, which biased the alignment toward delaying the start. The proper way to delay the SOS is via the blank token, so the new initialization takes the cumulative sum of the blank scores.
2. Fill the end of the trellis with Inf. Similar to the start, at the end, where the number of remaining time frames is less than the number of tokens, it is no longer possible to align the text, so we fill that region with Inf for better visualization.
3. Clean up the asset management code.

Pull Request resolved: #2544
Reviewed By: nateanl
Differential Revision: D38276478
Pulled By: mthrok
fbshipit-source-id: 6d934cc850a0790b8c463a4f69f8f1143633d299
1 parent: 7b0def8 · commit: 8edf91a
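To illustrate item 1 of the summary, here is a minimal sketch (not part of the commit) contrasting the old and new scores for staying in <SoS>. It assumes, as in the tutorial, that emission is a (num_frame, num_labels) tensor of frame-wise log-probabilities with the blank token at index 0.

import torch

torch.manual_seed(0)
num_frame, num_labels = 5, 3
# Toy frame-wise log-probabilities; column 0 is the blank token.
emission = torch.log_softmax(torch.randn(num_frame, num_labels), dim=-1)

# Old initialization: staying in <SoS> costs nothing at any time step,
# which biases the alignment toward a late start.
old_sos_scores = torch.zeros(num_frame + 1)

# New initialization: staying in <SoS> for t frames costs the cumulative
# blank score over those frames.
new_sos_scores = torch.empty(num_frame + 1)
new_sos_scores[0] = 0
new_sos_scores[1:] = torch.cumsum(emission[:, 0], 0)

print(old_sos_scores)
print(new_sos_scores)

With the old zero column, lingering in <SoS> was free, so the best path tended to start late; with the cumulative blank scores, the start is delayed only when the model itself predicts blank.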

File tree

1 file changed: +23, -21 lines


examples/tutorials/forced_alignment_tutorial.py

Lines changed: 23 additions & 21 deletions
@@ -11,6 +11,16 @@
 
 """
 
+import torch
+import torchaudio
+
+print(torch.__version__)
+print(torchaudio.__version__)
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(device)
+
 
 ######################################################################
 # Overview
@@ -37,32 +47,18 @@
 
 # %matplotlib inline
 
-import os
 from dataclasses import dataclass
 
 import IPython
 import matplotlib
 import matplotlib.pyplot as plt
-import requests
-import torch
-import torchaudio
 
 matplotlib.rcParams["figure.figsize"] = [16.0, 4.8]
 
 torch.random.manual_seed(0)
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-print(torch.__version__)
-print(torchaudio.__version__)
-print(device)
 
-SPEECH_URL = "https://download.pytorch.org/torchaudio/tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
-SPEECH_FILE = "_assets/speech.wav"
+SPEECH_FILE = torchaudio.utils.download_asset("tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav")
 
-if not os.path.exists(SPEECH_FILE):
-    os.makedirs("_assets", exist_ok=True)
-    with open(SPEECH_FILE, "wb") as file:
-        file.write(requests.get(SPEECH_URL).content)
 
 ######################################################################
 # Generate frame-wise label probability
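For reference, a standalone sketch of the new asset handling. It assumes only what the hunk shows: torchaudio.utils.download_asset fetches the tutorial asset and returns a local path, which can then be passed to torchaudio.load as usual.

import torchaudio

# Download (or reuse a cached copy of) the tutorial asset and get its local path.
SPEECH_FILE = torchaudio.utils.download_asset(
    "tutorial-assets/Lab41-SRI-VOiCES-src-sp0307-ch127535-sg0042.wav"
)
# Load the audio to confirm the returned path works like the old _assets/speech.wav.
waveform, sample_rate = torchaudio.load(SPEECH_FILE)
print(SPEECH_FILE)
print(waveform.shape, sample_rate)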
@@ -156,8 +152,12 @@ def get_trellis(emission, tokens, blank_id=0):
     # Trellis has extra diemsions for both time axis and tokens.
     # The extra dim for tokens represents <SoS> (start-of-sentence)
     # The extra dim for time axis is for simplification of the code.
-    trellis = torch.full((num_frame + 1, num_tokens + 1), -float("inf"))
-    trellis[:, 0] = 0
+    trellis = torch.empty((num_frame + 1, num_tokens + 1))
+    trellis[0, 0] = 0
+    trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
+    trellis[0, -num_tokens:] = -float("inf")
+    trellis[-num_tokens:, 0] = float("inf")
+
     for t in range(num_frame):
         trellis[t + 1, 1:] = torch.maximum(
             # Score for staying at the same token
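To see the new initialization end to end, here is a toy sketch (not from the commit) that applies it together with the trellis recursion. Only the first term of the recursion appears in this hunk; the second term (trellis[t, :-1] + emission[t, tokens]) is an assumption reproduced from the surrounding tutorial code, and the emission and token values are made up.

import torch

torch.manual_seed(0)
num_frame, num_labels = 10, 5
emission = torch.log_softmax(torch.randn(num_frame, num_labels), dim=-1)
tokens = [1, 3, 2]  # toy token ids; blank_id is 0
num_tokens = len(tokens)

trellis = torch.empty((num_frame + 1, num_tokens + 1))
trellis[0, 0] = 0
trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)  # cumulative blank scores
trellis[0, -num_tokens:] = -float("inf")          # cannot have emitted tokens at t=0
trellis[-num_tokens:, 0] = float("inf")           # too few frames left to emit all tokens

for t in range(num_frame):
    trellis[t + 1, 1:] = torch.maximum(
        # Score for staying at the same token
        trellis[t, 1:] + emission[t, 0],
        # Score for changing to the next token (assumed from the tutorial)
        trellis[t, :-1] + emission[t, tokens],
    )

print(trellis[:, 0])  # 0, then cumulative blank scores, then inf in the last rows
print(trellis[0])     # -inf everywhere except the <SoS> entry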
@@ -250,7 +250,8 @@ def backtrack(trellis, emission, tokens, blank_id=0):
 
 
 path = backtrack(trellis, emission, tokens)
-print(path)
+for p in path:
+    print(p)
 
 
 ################################################################################
@@ -449,6 +450,8 @@ def plot_alignments(trellis, segments, word_segments, waveform):
 )
 plt.show()
 
+################################################################################
+#
 
 # A trick to embed the resulting audio to the generated file.
 # `IPython.display.Audio` has to be the last call in a cell,
@@ -458,10 +461,9 @@ def display_segment(i):
     word = word_segments[i]
     x0 = int(ratio * word.start)
     x1 = int(ratio * word.end)
-    filename = f"_assets/{i}_{word.label}.wav"
-    torchaudio.save(filename, waveform[:, x0:x1], bundle.sample_rate)
     print(f"{word.label} ({word.score:.2f}): {x0 / bundle.sample_rate:.3f} - {x1 / bundle.sample_rate:.3f} sec")
-    return IPython.display.Audio(filename)
+    segment = waveform[:, x0:x1]
+    return IPython.display.Audio(segment.numpy(), rate=bundle.sample_rate)
 
 
 ######################################################################
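This last hunk drops the round trip through _assets/*.wav: the word segment is sliced from the waveform in memory and handed to IPython.display.Audio directly. A minimal standalone sketch of that pattern, with a synthetic waveform and made-up frame boundaries standing in for the tutorial's waveform, x0, x1, and bundle.sample_rate:

import IPython
import torch

sample_rate = 16000
waveform = torch.randn(1, sample_rate)  # stand-in for the loaded speech, shape (channel, time)
x0, x1 = 2000, 8000                     # made-up frame boundaries of one word

segment = waveform[:, x0:x1]
# Returning this object (or ending a notebook cell with it) renders an audio player.
IPython.display.Audio(segment.numpy(), rate=sample_rate)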
