-
Notifications
You must be signed in to change notification settings - Fork 686
/
Copy pathcompute.cpp
33 lines (31 loc) · 908 Bytes
/
compute.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include <libtorchaudio/rnnt/compute.h>
/// C++ entry point for the `torchaudio::rnnt_loss` operator.
///
/// This function does no computation itself. It looks up the operator schema
/// named "torchaudio::rnnt_loss" and forwards every argument, unchanged and in
/// order, through the PyTorch dispatcher. The dispatcher calls whichever kernel
/// is registered for that operator; kernel registration is not visible here.
///
/// The typed handle is keyed on this function's own type
/// (`decltype(rnnt_loss)`). The C++ signature therefore has to stay consistent
/// with the registered schema string.
///
/// @param logits            forwarded as the schema's `logits` argument.
///                          NOTE(review): taken by non-const reference even
///                          though the schema carries no mutation/alias
///                          annotation. This presumably mirrors the declaration
///                          in compute.h; confirm before changing it.
/// @param targets           forwarded as `targets`.
/// @param logit_lengths     forwarded as `logit_lengths`.
/// @param target_lengths    forwarded as `target_lengths`.
/// @param blank             forwarded as `blank` (schema type `int`).
/// @param clamp             forwarded as `clamp` (schema type `float`).
/// @param fused_log_softmax forwarded as `fused_log_softmax`; defaults to true
///                          for C++ callers that can see this default.
/// @return the (Tensor, optional Tensor) pair produced by the dispatched kernel.
/// @throws whatever `findSchemaOrThrow` raises when no schema named
///         "torchaudio::rnnt_loss" is registered (only on the first call).
std::tuple<torch::Tensor, std::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_softmax = true) {
  // The lookup runs exactly once. Function-local static initialization is
  // thread-safe, and the cached typed handle is reused on every later call.
  static auto typed_handle = [] {
    auto& dispatcher = torch::Dispatcher::singleton();
    auto schema_handle =
        dispatcher.findSchemaOrThrow("torchaudio::rnnt_loss", "");
    return schema_handle.typed<decltype(rnnt_loss)>();
  }();

  return typed_handle.call(
      logits,
      targets,
      logit_lengths,
      target_lengths,
      blank,
      clamp,
      fused_log_softmax);
}
// Declares the schema of the `rnnt_loss` operator in the `torchaudio` library
// namespace. This block attaches no kernel; it only defines the signature.
// Because it is a FRAGMENT, other TORCH_LIBRARY_FRAGMENT(torchaudio, ...) blocks
// can contribute further definitions to the same namespace.
//
// The argument list mirrors the C++ `rnnt_loss` wrapper's parameters, and the
// `(Tensor, Tensor?)` result corresponds to its
// std::tuple<Tensor, std::optional<Tensor>> return type.
// NOTE(review): the schema gives `fused_log_softmax` no default value, while
// the C++ wrapper defaults it to true. Callers going through the schema must
// pass it explicitly.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  // The adjacent literals are concatenated at compile time into a single
  // schema string. There is no space after the commas in the argument list.
  constexpr char kRnntLossSchema[] =
      "rnnt_loss(Tensor logits,"
      "Tensor targets,"
      "Tensor logit_lengths,"
      "Tensor target_lengths,"
      "int blank,"
      "float clamp,"
      "bool fused_log_softmax) -> (Tensor, Tensor?)";
  m.def(kRnntLossSchema);
}