Skip to content

Commit 2c137e7

Browse files
authored
Encoding: support wav, flac etc. (#630)
1 parent 12cdaa8 commit 2c137e7

File tree

3 files changed

+109
-31
lines changed

3 files changed

+109
-31
lines changed

src/torchcodec/_core/Encoder.cpp

+84-19
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,44 @@ void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
3333
supportedRates.str());
3434
}
3535

36+
// Output sample formats we are willing to encode into, from most to least
// preferred. FLT[P] comes first because it matches the format of the input
// waveform; after that the list goes from the widest sample type down to the
// narrowest, with each planar ("P") variant listed before its packed
// counterpart.
static const std::vector<AVSampleFormat> preferredFormatsOrder{
    // Float formats.
    AV_SAMPLE_FMT_FLTP,
    AV_SAMPLE_FMT_FLT,
    AV_SAMPLE_FMT_DBLP,
    AV_SAMPLE_FMT_DBL,
    // Signed integer formats, widest first.
    AV_SAMPLE_FMT_S64P,
    AV_SAMPLE_FMT_S64,
    AV_SAMPLE_FMT_S32P,
    AV_SAMPLE_FMT_S32,
    AV_SAMPLE_FMT_S16P,
    AV_SAMPLE_FMT_S16,
    // Unsigned 8-bit formats, used as a last resort.
    AV_SAMPLE_FMT_U8P,
    AV_SAMPLE_FMT_U8,
};
49+
50+
// Chooses the sample format the encoder will be configured to output.
//
// The input waveform is FLT[P], so that is what we prefer when the encoder
// supports it. When it doesn't, the AVFrame's format has to be converted, and
// our heuristic is to pick the supported format with the highest resolution,
// following the order of preferredFormatsOrder.
//
// @param avCodec The encoder whose supported sample formats are inspected.
// @return The first entry of preferredFormatsOrder that the encoder supports.
//   AV_SAMPLE_FMT_FLTP if the encoder doesn't advertise any format. The
//   encoder's first advertised format if none of them is in
//   preferredFormatsOrder.
AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) {
  // `sample_fmts` is either null (unknown) or an array terminated by
  // AV_SAMPLE_FMT_NONE.
  const AVSampleFormat* supportedFormats = avCodec.sample_fmts;

  if (supportedFormats == nullptr ||
      supportedFormats[0] == AV_SAMPLE_FMT_NONE) {
    // Can't really validate anything in this case, best we can do is hope that
    // FLTP is supported by the encoder. If not, FFmpeg will raise. An empty
    // list is handled like an unknown one so that we never hand the
    // AV_SAMPLE_FMT_NONE terminator back to the caller as a "format".
    return AV_SAMPLE_FMT_FLTP;
  }

  for (AVSampleFormat preferredFormat : preferredFormatsOrder) {
    for (int i = 0; supportedFormats[i] != AV_SAMPLE_FMT_NONE; ++i) {
      if (supportedFormats[i] == preferredFormat) {
        return preferredFormat;
      }
    }
  }

  // We should always find a match in preferredFormatsOrder, so we should always
  // return earlier. But in the event that a future FFmpeg version defines an
  // additional sample format that isn't in preferredFormatsOrder, we fall back
  // to the first format the encoder advertises:
  return supportedFormats[0];
}
73+
3674
} // namespace
3775

3876
// Nothing to release by hand here: cleanup is presumably handled by the
// destructors of the Unique* members (TODO confirm against their definitions).
AudioEncoder::~AudioEncoder() = default;
@@ -47,6 +85,8 @@ AudioEncoder::AudioEncoder(
4785
wf_.dtype() == torch::kFloat32,
4886
"waveform must have float32 dtype, got ",
4987
wf_.dtype());
88+
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
89+
// planar (fltp).
5090
TORCH_CHECK(
5191
wf_.dim() == 2, "waveform must have 2 dimensions, got ", wf_.dim());
5292

@@ -92,14 +132,10 @@ AudioEncoder::AudioEncoder(
92132
validateSampleRate(*avCodec, sampleRate);
93133
avCodecContext_->sample_rate = sampleRate;
94134

95-
// Note: This is the format of the **input** waveform. This doesn't determine
96-
// the output.
97-
// TODO-ENCODING check contiguity of the input wf to ensure that it is indeed
98-
// planar.
99-
// TODO-ENCODING If the encoder doesn't support FLTP (like flac), FFmpeg will
100-
// raise. We need to handle this, probably converting the format with
101-
// libswresample.
102-
avCodecContext_->sample_fmt = AV_SAMPLE_FMT_FLTP;
135+
// Input waveform is expected to be FLTP. Not all encoders support FLTP, so we
136+
// may need to convert the wf into a supported output sample format, which is
137+
// what the `.sample_fmt` defines.
138+
avCodecContext_->sample_fmt = findBestOutputSampleFormat(*avCodec);
103139

104140
int numChannels = static_cast<int>(wf_.sizes()[0]);
105141
TORCH_CHECK(
@@ -120,12 +156,6 @@ AudioEncoder::AudioEncoder(
120156
"avcodec_open2 failed: ",
121157
getFFMPEGErrorStringFromErrorCode(status));
122158

123-
TORCH_CHECK(
124-
avCodecContext_->frame_size > 0,
125-
"frame_size is ",
126-
avCodecContext_->frame_size,
127-
". Cannot encode. This should probably never happen?");
128-
129159
// We're allocating the stream here. Streams are meant to be freed by
130160
// avformat_free_context(avFormatContext), which we call in the
131161
// avFormatContext_'s destructor.
@@ -143,8 +173,11 @@ AudioEncoder::AudioEncoder(
143173
void AudioEncoder::encode() {
144174
UniqueAVFrame avFrame(av_frame_alloc());
145175
TORCH_CHECK(avFrame != nullptr, "Couldn't allocate AVFrame.");
146-
avFrame->nb_samples = avCodecContext_->frame_size;
147-
avFrame->format = avCodecContext_->sample_fmt;
176+
// Default to 256 like in torchaudio
177+
int numSamplesAllocatedPerFrame =
178+
avCodecContext_->frame_size > 0 ? avCodecContext_->frame_size : 256;
179+
avFrame->nb_samples = numSamplesAllocatedPerFrame;
180+
avFrame->format = AV_SAMPLE_FMT_FLTP;
148181
avFrame->sample_rate = avCodecContext_->sample_rate;
149182
avFrame->pts = 0;
150183
setChannelLayout(avFrame, avCodecContext_);
@@ -160,7 +193,6 @@ void AudioEncoder::encode() {
160193
uint8_t* pwf = static_cast<uint8_t*>(wf_.data_ptr());
161194
int numSamples = static_cast<int>(wf_.sizes()[1]); // per channel
162195
int numEncodedSamples = 0; // per channel
163-
int numSamplesPerFrame = avCodecContext_->frame_size; // per channel
164196
int numBytesPerSample = static_cast<int>(wf_.element_size());
165197
int numBytesPerChannel = numSamples * numBytesPerSample;
166198

@@ -178,7 +210,7 @@ void AudioEncoder::encode() {
178210
getFFMPEGErrorStringFromErrorCode(status));
179211

180212
int numSamplesToEncode =
181-
std::min(numSamplesPerFrame, numSamples - numEncodedSamples);
213+
std::min(numSamplesAllocatedPerFrame, numSamples - numEncodedSamples);
182214
int numBytesToEncode = numSamplesToEncode * numBytesPerSample;
183215

184216
for (int ch = 0; ch < wf_.sizes()[0]; ch++) {
@@ -211,7 +243,37 @@ void AudioEncoder::encode() {
211243

212244
void AudioEncoder::encodeInnerLoop(
213245
AutoAVPacket& autoAVPacket,
214-
const UniqueAVFrame& avFrame) {
246+
const UniqueAVFrame& srcAVFrame) {
247+
bool mustConvert =
248+
(avCodecContext_->sample_fmt != AV_SAMPLE_FMT_FLTP &&
249+
srcAVFrame != nullptr);
250+
UniqueAVFrame convertedAVFrame;
251+
if (mustConvert) {
252+
if (!swrContext_) {
253+
swrContext_.reset(createSwrContext(
254+
avCodecContext_,
255+
AV_SAMPLE_FMT_FLTP,
256+
avCodecContext_->sample_fmt,
257+
srcAVFrame->sample_rate, // No sample rate conversion
258+
srcAVFrame->sample_rate));
259+
}
260+
convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
261+
swrContext_,
262+
srcAVFrame,
263+
avCodecContext_->sample_fmt,
264+
srcAVFrame->sample_rate, // No sample rate conversion
265+
srcAVFrame->sample_rate);
266+
TORCH_CHECK(
267+
convertedAVFrame->nb_samples == srcAVFrame->nb_samples,
268+
"convertedAVFrame->nb_samples=",
269+
convertedAVFrame->nb_samples,
270+
" differs from ",
271+
"srcAVFrame->nb_samples=",
272+
srcAVFrame->nb_samples,
273+
"This is unexpected, please report on the TorchCodec bug tracker.");
274+
}
275+
const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame;
276+
215277
auto status = avcodec_send_frame(avCodecContext_.get(), avFrame.get());
216278
TORCH_CHECK(
217279
status == AVSUCCESS,
@@ -248,6 +310,9 @@ void AudioEncoder::encodeInnerLoop(
248310
}
249311

250312
void AudioEncoder::flushBuffers() {
313+
// We flush the main FFmpeg buffers, but not swresample buffers. Flushing
314+
// swresample is only necessary when converting sample rates, which we don't
315+
// do for encoding.
251316
AutoAVPacket autoAVPacket;
252317
encodeInnerLoop(autoAVPacket, UniqueAVFrame(nullptr));
253318
}

src/torchcodec/_core/Encoder.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ class AudioEncoder {
2626
private:
2727
void encodeInnerLoop(
2828
AutoAVPacket& autoAVPacket,
29-
const UniqueAVFrame& avFrame);
29+
const UniqueAVFrame& srcAVFrame);
3030
void flushBuffers();
3131

3232
UniqueEncodingAVFormatContext avFormatContext_;
3333
UniqueAVCodecContext avCodecContext_;
3434
int streamIndex_;
35+
UniqueSwrContext swrContext_;
3536

3637
const torch::Tensor wf_;
3738
};

test/test_ops.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from .utils import (
4545
assert_frames_equal,
4646
cpu_and_cuda,
47+
get_ffmpeg_major_version,
4748
in_fbcode,
4849
NASA_AUDIO,
4950
NASA_AUDIO_MP3,
@@ -1122,33 +1123,40 @@ def test_bad_input(self, tmp_path):
11221123
bit_rate=-1, # bad
11231124
)
11241125

1125-
def test_round_trip(self, tmp_path):
1126-
# Check that decode(encode(samples)) == samples
1126+
@pytest.mark.parametrize("output_format", ("wav", "flac"))
1127+
def test_round_trip(self, output_format, tmp_path):
1128+
# Check that decode(encode(samples)) == samples on lossless formats
1129+
1130+
if get_ffmpeg_major_version() == 4 and output_format == "wav":
1131+
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
1132+
11271133
asset = NASA_AUDIO_MP3
11281134
source_samples = self.decode(asset)
11291135

1130-
encoded_path = tmp_path / "output.mp3"
1136+
encoded_path = tmp_path / f"output.{output_format}"
11311137
encoder = create_audio_encoder(
11321138
wf=source_samples, sample_rate=asset.sample_rate, filename=str(encoded_path)
11331139
)
11341140
encode_audio(encoder)
11351141

1136-
# TODO-ENCODING: tol should be stricter. We probably need to encode
1137-
# into a lossless format.
1142+
rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
11381143
torch.testing.assert_close(
1139-
self.decode(encoded_path), source_samples, rtol=0, atol=0.07
1144+
self.decode(encoded_path), source_samples, rtol=rtol, atol=atol
11401145
)
11411146

1142-
# TODO-ENCODING: test more encoding formats
11431147
@pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI")
11441148
@pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32))
11451149
@pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999))
1146-
def test_against_cli(self, asset, bit_rate, tmp_path):
1150+
@pytest.mark.parametrize("output_format", ("mp3", "wav", "flac"))
1151+
def test_against_cli(self, asset, bit_rate, output_format, tmp_path):
11471152
# Encodes samples with our encoder and with the FFmpeg CLI, and checks
11481153
# that both decoded outputs are equal
11491154

1150-
encoded_by_ffmpeg = tmp_path / "ffmpeg_output.mp3"
1151-
encoded_by_us = tmp_path / "our_output.mp3"
1155+
if get_ffmpeg_major_version() == 4 and output_format == "wav":
1156+
pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files")
1157+
1158+
encoded_by_ffmpeg = tmp_path / f"ffmpeg_output.{output_format}"
1159+
encoded_by_us = tmp_path / f"our_output.{output_format}"
11521160

11531161
subprocess.run(
11541162
["ffmpeg", "-i", str(asset.path)]
@@ -1168,8 +1176,12 @@ def test_against_cli(self, asset, bit_rate, tmp_path):
11681176
)
11691177
encode_audio(encoder)
11701178

1179+
rtol, atol = (0, 1e-4) if output_format == "wav" else (None, None)
11711180
torch.testing.assert_close(
1172-
self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us)
1181+
self.decode(encoded_by_ffmpeg),
1182+
self.decode(encoded_by_us),
1183+
rtol=rtol,
1184+
atol=atol,
11731185
)
11741186

11751187

0 commit comments

Comments
 (0)