1
+ #include < sstream>
2
+
1
3
#include " src/torchcodec/_core/Encoder.h"
2
4
#include " torch/types.h"
3
5
4
6
namespace facebook ::torchcodec {
5
7
8
+ namespace {
9
+
10
+ void validateSampleRate (const AVCodec& avCodec, int sampleRate) {
11
+ if (avCodec.supported_samplerates == nullptr ) {
12
+ return ;
13
+ }
14
+
15
+ for (auto i = 0 ; avCodec.supported_samplerates [i] != 0 ; ++i) {
16
+ if (sampleRate == avCodec.supported_samplerates [i]) {
17
+ return ;
18
+ }
19
+ }
20
+ std::stringstream supportedRates;
21
+ for (auto i = 0 ; avCodec.supported_samplerates [i] != 0 ; ++i) {
22
+ if (i > 0 ) {
23
+ supportedRates << " , " ;
24
+ }
25
+ supportedRates << avCodec.supported_samplerates [i];
26
+ }
27
+
28
+ TORCH_CHECK (
29
+ false ,
30
+ " invalid sample rate=" ,
31
+ sampleRate,
32
+ " . Supported sample rate values are: " ,
33
+ supportedRates.str ());
34
+ }
35
+
36
+ } // namespace
37
+
6
38
AudioEncoder::~AudioEncoder () {}
7
39
8
40
// TODO-ENCODING: disable ffmpeg logs by default
@@ -12,7 +44,7 @@ AudioEncoder::AudioEncoder(
12
44
int sampleRate,
13
45
std::string_view fileName,
14
46
std::optional<int64_t > bit_rate)
15
- : wf_(wf), sampleRate_(sampleRate) {
47
+ : wf_(wf) {
16
48
TORCH_CHECK (
17
49
wf_.dtype () == torch::kFloat32 ,
18
50
" waveform must have float32 dtype, got " ,
@@ -57,7 +89,8 @@ AudioEncoder::AudioEncoder(
57
89
// well when "-b:a" isn't specified.
58
90
avCodecContext_->bit_rate = bit_rate.value_or (0 );
59
91
60
- avCodecContext_->sample_rate = sampleRate_;
92
+ validateSampleRate (*avCodec, sampleRate);
93
+ avCodecContext_->sample_rate = sampleRate;
61
94
62
95
// Note: This is the format of the **input** waveform. This doesn't determine
63
96
// the output.
0 commit comments