Skip to content

Commit 3280b90

Browse files
authored
Encoder: properly validate sample rate parameter (#624)
1 parent 44bef81 commit 3280b90

File tree

3 files changed

+40
-13
lines changed

3 files changed

+40
-13
lines changed

src/torchcodec/_core/Encoder.cpp

+35-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,40 @@
1+
#include <sstream>
2+
13
#include "src/torchcodec/_core/Encoder.h"
24
#include "torch/types.h"
35

46
namespace facebook::torchcodec {
57

8+
namespace {
9+
10+
// Checks that `sampleRate` appears in the codec's list of supported sample
// rates.
//
// - If the codec exposes no list (`supported_samplerates` is null), there is
//   nothing to validate against and the function returns without error.
// - Otherwise the list is walked until its 0 terminator; a match returns
//   silently.
// - If no entry matches, the function fails through TORCH_CHECK with a message
//   that names the rejected rate and enumerates every supported value.
void validateSampleRate(const AVCodec& avCodec, int sampleRate) {
  const auto* const rates = avCodec.supported_samplerates;
  if (rates == nullptr) {
    return;
  }

  // First pass: accept as soon as the requested rate is found.
  for (const auto* rate = rates; *rate != 0; ++rate) {
    if (*rate == sampleRate) {
      return;
    }
  }

  // Second pass (error path only): build a comma-separated list of the
  // supported rates for the error message.
  std::stringstream supportedRates;
  const char* separator = "";
  for (const auto* rate = rates; *rate != 0; ++rate) {
    supportedRates << separator << *rate;
    separator = ", ";
  }

  TORCH_CHECK(
      false,
      "invalid sample rate=",
      sampleRate,
      ". Supported sample rate values are: ",
      supportedRates.str());
}
35+
36+
} // namespace
37+
638
// Out-of-line destructor with an empty body: no explicit cleanup is done here.
AudioEncoder::~AudioEncoder() {}
739

840
// TODO-ENCODING: disable ffmpeg logs by default
@@ -12,7 +44,7 @@ AudioEncoder::AudioEncoder(
1244
int sampleRate,
1345
std::string_view fileName,
1446
std::optional<int64_t> bit_rate)
15-
: wf_(wf), sampleRate_(sampleRate) {
47+
: wf_(wf) {
1648
TORCH_CHECK(
1749
wf_.dtype() == torch::kFloat32,
1850
"waveform must have float32 dtype, got ",
@@ -57,7 +89,8 @@ AudioEncoder::AudioEncoder(
5789
// well when "-b:a" isn't specified.
5890
avCodecContext_->bit_rate = bit_rate.value_or(0);
5991

60-
avCodecContext_->sample_rate = sampleRate_;
92+
validateSampleRate(*avCodec, sampleRate);
93+
avCodecContext_->sample_rate = sampleRate;
6194

6295
// Note: This is the format of the **input** waveform. This doesn't determine
6396
// the output.

src/torchcodec/_core/Encoder.h

+4-8
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,10 @@ class AudioEncoder {
1414
// supported.
1515
AudioEncoder(
1616
const torch::Tensor wf,
17+
// The *output* sample rate. We can't really decide for the user what it
18+
// should be. Particularly, the sample rate of the input waveform should
19+
// match this, and that's up to the user. If sample rates don't match,
20+
// encoding will still work but audio will be distorted.
1721
int sampleRate,
1822
std::string_view fileName,
1923
std::optional<int64_t> bit_rate = std::nullopt);
@@ -30,13 +34,5 @@ class AudioEncoder {
3034
int streamIndex_;
3135

3236
const torch::Tensor wf_;
33-
// The *output* sample rate. We can't really decide for the user what it
34-
// should be. Particularly, the sample rate of the input waveform should match
35-
// this, and that's up to the user. If sample rates don't match, encoding will
36-
// still work but audio will be distorted.
37-
// We technically could let the user also specify the input sample rate, and
38-
// resample the waveform internally to match them, but that's not in scope for
39-
// an initial version (if at all).
40-
int sampleRate_;
4137
};
4238
} // namespace facebook::torchcodec

test/test_ops.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1107,9 +1107,7 @@ def test_bad_input(self, tmp_path):
11071107
wf=torch.rand(10, 10), sample_rate=10, filename="./file.bad_extension"
11081108
)
11091109

1110-
# TODO-ENCODING: raise more informative error message when sample rate
1111-
# isn't supported
1112-
with pytest.raises(RuntimeError, match="Invalid argument"):
1110+
with pytest.raises(RuntimeError, match="invalid sample rate=10"):
11131111
create_audio_encoder(
11141112
wf=self.decode(NASA_AUDIO_MP3),
11151113
sample_rate=10,

0 commit comments

Comments
 (0)