Skip to content

Commit 7671cf4

Browse files
committed
Update TimeStretch doc and tutorial (#3694)
1 parent 383548c commit 7671cf4

File tree

2 files changed

+49
-25
lines changed

2 files changed

+49
-25
lines changed

examples/tutorials/audio_feature_augmentation_tutorial.py

+36-8
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525

2626
import librosa
2727
import matplotlib.pyplot as plt
28+
from IPython.display import Audio
2829
from torchaudio.utils import download_asset
2930

3031
######################################################################
@@ -69,11 +70,6 @@ def get_spectrogram(
6970
return spectrogram(waveform)
7071

7172

72-
def plot_spec(ax, spec, title, ylabel="freq_bin"):
73-
ax.set_title(title)
74-
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
75-
76-
7773
######################################################################
7874
# SpecAugment
7975
# -----------
@@ -98,11 +94,15 @@ def plot_spec(ax, spec, title, ylabel="freq_bin"):
9894
spec_12 = stretch(spec, overriding_rate=1.2)
9995
spec_09 = stretch(spec, overriding_rate=0.9)
10096

101-
######################################################################
102-
#
103-
10497

98+
######################################################################
99+
# Visualization
100+
# ~~~~~~~~~~~~~
105101
def plot():
102+
def plot_spec(ax, spec, title):
103+
ax.set_title(title)
104+
ax.imshow(librosa.amplitude_to_db(spec), origin="lower", aspect="auto")
105+
106106
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
107107
plot_spec(axes[0], torch.abs(spec_12[0]), title="Stretched x1.2")
108108
plot_spec(axes[1], torch.abs(spec[0]), title="Original")
@@ -112,6 +112,30 @@ def plot():
112112

113113
plot()
114114

115+
116+
######################################################################
117+
# Audio Samples
118+
# ~~~~~~~~~~~~~
119+
def preview(spec, rate=16000):
120+
ispec = T.InverseSpectrogram()
121+
waveform = ispec(spec)
122+
123+
return Audio(waveform[0].numpy().T, rate=rate)
124+
125+
126+
preview(spec)
127+
128+
129+
######################################################################
130+
#
131+
preview(spec_12)
132+
133+
134+
######################################################################
135+
#
136+
preview(spec_09)
137+
138+
115139
######################################################################
116140
# Time and Frequency Masking
117141
# --------------------------
@@ -131,6 +155,10 @@ def plot():
131155

132156

133157
def plot():
158+
def plot_spec(ax, spec, title):
159+
ax.set_title(title)
160+
ax.imshow(librosa.power_to_db(spec), origin="lower", aspect="auto")
161+
134162
fig, axes = plt.subplots(3, 1, sharex=True, sharey=True)
135163
plot_spec(axes[0], spec[0], title="Original")
136164
plot_spec(axes[1], time_masked[0], title="Masked along time axis")

torchaudio/transforms/_transforms.py

+13-17
Original file line number | Diff line number | Diff line change
@@ -1020,31 +1020,27 @@ class TimeStretch(torch.nn.Module):
10201020
Proposed in *SpecAugment* :cite:`specaugment`.
10211021
10221022
Args:
1023-
hop_length (int or None, optional): Length of hop between STFT windows. (Default: ``win_length // 2``)
1023+
hop_length (int or None, optional): Length of hop between STFT windows.
1024+
(Default: ``n_fft // 2``, where ``n_fft == (n_freq - 1) * 2``)
10241025
n_freq (int, optional): Number of filter banks from STFT. (Default: ``201``)
10251026
fixed_rate (float or None, optional): rate to speed up or slow down by.
10261027
If None is provided, rate must be passed to the forward method. (Default: ``None``)
10271028
1029+
.. note::
1030+
1031+
The expected input is a raw, complex-valued spectrogram.
1032+
10281033
Example
1029-
>>> spectrogram = torchaudio.transforms.Spectrogram()
1034+
>>> spectrogram = torchaudio.transforms.Spectrogram(power=None)
10301035
>>> stretch = torchaudio.transforms.TimeStretch()
10311036
>>>
10321037
>>> original = spectrogram(waveform)
1033-
>>> streched_1_2 = stretch(original, 1.2)
1034-
>>> streched_0_9 = stretch(original, 0.9)
1035-
1036-
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_1.png
1037-
:width: 600
1038-
:alt: Spectrogram streched by 1.2
1038+
>>> stretched_1_2 = stretch(original, 1.2)
1039+
>>> stretched_0_9 = stretch(original, 0.9)
10391040
1040-
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_2.png
1041+
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch.png
10411042
:width: 600
1042-
:alt: The original spectrogram
1043-
1044-
.. image:: https://download.pytorch.org/torchaudio/doc-assets/specaugment_time_stretch_3.png
1045-
:width: 600
1046-
:alt: Spectrogram streched by 0.9
1047-
1043+
:alt: The visualization of stretched spectrograms.
10481044
"""
10491045
__constants__ = ["fixed_rate"]
10501046

@@ -1067,8 +1063,8 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] =
10671063
10681064
Returns:
10691065
Tensor:
1070-
Stretched spectrogram. The resulting tensor is of the same dtype as the input
1071-
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
1066+
Stretched spectrogram. The resulting tensor has the complex dtype corresponding to
1067+
the dtype of the input spectrogram, and the number of frames is changed to ``ceil(num_frame / rate)``.
10721068
"""
10731069
if overriding_rate is None:
10741070
if self.fixed_rate is None:

0 commit comments

Comments
 (0)