Skip to content

Commit 9df28ff

Browse files
authored
[Cherry-pick] Properly set #samples passed to encoder (#3204) (#3239)
Summary: Some audio encoders expect specific, exact number of samples described as in `AVCodecContext.frame_size`. The `AVFrame.nb_samples` is set for the frames passed to `AVFilterGraph`, but frames coming out of the graph do not necessarily have the same numbr of frames. This causes issues with encoding OPUS (among others). This commit fixes it by inserting `asetnsamples` to filter graph if a fixed number of samples is requested. Note: It turned out that FFmpeg 4.1 has issue with OPUS encoding. It does not properly discard some sample. We should probably move the minimum required FFmpeg to 4.2, but I am not sure if we can enforce it via ABI. Work around will be to issue an warning if encoding OPUS with 4.1. (follow-up) Pull Request resolved: #3204 Reviewed By: nateanl Differential Revision: D44374668 Pulled By: mthrok fbshipit-source-id: 10ef5333dc0677dfb83c8e40b78edd8ded1b21dc
1 parent 3b40834 commit 9df28ff

File tree

2 files changed

+61
-15
lines changed

2 files changed

+61
-15
lines changed

test/torchaudio_unittest/io/stream_writer_test.py

+58-15
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from parameterized import parameterized, parameterized_class
55
from torchaudio_unittest.common_utils import (
66
get_asset_path,
7+
get_sinusoid,
78
is_ffmpeg_available,
89
nested_params,
910
rgb_to_yuv_ccir,
@@ -293,28 +294,58 @@ def test_video_num_frames(self, framerate, resolution, format):
293294
pass
294295

295296
@nested_params(
296-
["wav", "mp3", "flac"],
297+
["wav", "flac"],
297298
[8000, 16000, 44100],
298299
[1, 2],
299300
)
300-
def test_audio_num_frames(self, ext, sample_rate, num_channels):
301-
""""""
301+
def test_audio_num_frames_lossless(self, ext, sample_rate, num_channels):
302+
"""Lossless format preserves the data"""
302303
filename = f"test.{ext}"
303304

305+
data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, dtype="int16", channels_first=False)
306+
304307
# Write data
305308
dst = self.get_dst(filename)
306309
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
307-
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
310+
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels, format="s16")
311+
with s.open():
312+
s.write_audio_chunk(0, data)
308313

309-
freq = 300
310-
duration = 60
311-
theta = torch.linspace(0, freq * 2 * 3.14 * duration, sample_rate * duration)
312-
if num_channels == 1:
313-
chunk = torch.sin(theta).unsqueeze(-1)
314-
else:
315-
chunk = torch.stack([torch.sin(theta), torch.cos(theta)], dim=-1)
314+
if self.test_fileobj:
315+
dst.flush()
316+
317+
# Load data
318+
s = torchaudio.io.StreamReader(src=self.get_temp_path(filename))
319+
s.add_audio_stream(-1)
320+
s.process_all_packets()
321+
(saved,) = s.pop_chunks()
322+
323+
self.assertEqual(saved, data)
324+
325+
@parameterized.expand(
326+
[
327+
("mp3", 1, 8000),
328+
("mp3", 1, 16000),
329+
("mp3", 1, 44100),
330+
("mp3", 2, 8000),
331+
("mp3", 2, 16000),
332+
("mp3", 2, 44100),
333+
("opus", 1, 48000),
334+
]
335+
)
336+
def test_audio_num_frames_lossy(self, ext, num_channels, sample_rate):
337+
"""Saving audio preserves the number of channels and frames"""
338+
filename = f"test.{ext}"
339+
340+
data = get_sinusoid(sample_rate=sample_rate, n_channels=num_channels, channels_first=False)
341+
342+
# Write data
343+
dst = self.get_dst(filename)
344+
s = torchaudio.io.StreamWriter(dst=dst, format=ext)
345+
s.add_audio_stream(sample_rate=sample_rate, num_channels=num_channels)
316346
with s.open():
317-
s.write_audio_chunk(0, chunk)
347+
s.write_audio_chunk(0, data)
348+
318349
if self.test_fileobj:
319350
dst.flush()
320351

@@ -324,9 +355,21 @@ def test_audio_num_frames(self, ext, sample_rate, num_channels):
324355
s.process_all_packets()
325356
(saved,) = s.pop_chunks()
326357

327-
assert saved.shape == chunk.shape
328-
if format in ["wav", "flac"]:
329-
self.assertEqual(saved, chunk)
358+
# This test fails for OPUS if FFmpeg is 4.1, but it passes for 4.2+
359+
# 4.1 produces 48312 samples (extra 312)
360+
# Probably this commit fixes it.
361+
# https://github.com/FFmpeg/FFmpeg/commit/18aea7bdd96b320a40573bccabea56afeccdd91c
362+
# TODO: issue warning if 4.1?
363+
if ext == "opus":
364+
ver = torchaudio.utils.ffmpeg_utils.get_versions()["libavcodec"]
365+
# 5.1 libavcodec 59. 18.100
366+
# 4.4 libavcodec 58.134.100
367+
# 4.3 libavcodec 58. 91.100
368+
# 4.2 libavcodec 58. 54.100
369+
# 4.1 libavcodec 58. 35.100
370+
if ver[0] < 59 and ver[1] < 54:
371+
return
372+
self.assertEqual(saved.shape, data.shape)
330373

331374
def test_preserve_fps(self):
332375
"""Decimal point frame rate is properly saved

torchaudio/csrc/ffmpeg/stream_writer/stream_writer.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,9 @@ std::unique_ptr<FilterGraph> _get_audio_filter(
409409
AVCodecContextPtr& ctx) {
410410
std::stringstream desc;
411411
desc << "aformat=" << av_get_sample_fmt_name(ctx->sample_fmt);
412+
if (ctx->frame_size) {
413+
desc << ",asetnsamples=n=" << ctx->frame_size << ":p=0";
414+
}
412415

413416
auto p = std::make_unique<FilterGraph>(AVMEDIA_TYPE_AUDIO);
414417
p->add_audio_src(fmt, ctx->time_base, ctx->sample_rate, ctx->channel_layout);

0 commit comments

Comments
 (0)