// autograd.cpp

#include <libtorchaudio/rnnt/compute.h>

namespace torchaudio {
namespace rnnt {

// Autograd wrapper around the RNN-T loss kernel. The kernel computes the
// gradients together with the costs during the forward pass, so backward
// only needs to rescale the saved gradients by the incoming grad output.
class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 public:
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor& logits,
      const torch::Tensor& targets,
      const torch::Tensor& logit_lengths,
      const torch::Tensor& target_lengths,
      int64_t blank,
      double clamp,
      bool fused_log_softmax = true) {
    torch::Tensor undef;
    auto result = rnnt_loss(
        logits,
        targets,
        logit_lengths,
        target_lengths,
        blank,
        clamp,
        fused_log_softmax);
    auto costs = std::get<0>(result);
    // The kernel may not produce gradients; fall back to an undefined tensor.
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
    return {costs, grads};
  }
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto grad = saved[0];
    // Broadcast the per-sequence grad output over the saved
    // (batch, time, target, class) gradient tensor.
    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
    auto result = grad * grad_out;
    // One entry per forward input (excluding ctx): only logits receives a
    // gradient; the remaining six inputs get undefined tensors.
    torch::Tensor undef;
    return {result, undef, undef, undef, undef, undef, undef};
  }
};

// Entry point registered below for the Autograd dispatch key; it routes the
// call through RNNTLossFunction so that gradients are tracked.
std::tuple<torch::Tensor, std::optional<torch::Tensor>> rnnt_loss_autograd(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp,
    bool fused_log_softmax = true) {
  // Exclude the autograd keys below this point so the inner rnnt_loss call
  // dispatches straight to the backend kernel instead of re-entering here.
  at::AutoDispatchBelowADInplaceOrView guard;
  auto results = RNNTLossFunction::apply(
      logits,
      targets,
      logit_lengths,
      target_lengths,
      blank,
      clamp,
      fused_log_softmax);
  return std::make_tuple(results[0], results[1]);
}

TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
  m.impl("rnnt_loss", rnnt_loss_autograd);
}
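
// TORCH_LIBRARY_IMPL only provides the Autograd kernel; the op's schema is
// assumed to be defined elsewhere in the library (e.g. in compute.cpp),
// roughly as sketched below. The exact schema string is an assumption for
// illustration:
//
//   TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
//     m.def(
//         "rnnt_loss(Tensor logits, Tensor targets, Tensor logit_lengths,"
//         " Tensor target_lengths, int blank, float clamp,"
//         " bool fused_log_softmax=True) -> (Tensor, Tensor?)");
//   }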
} // namespace rnnt
} // namespace torchaudio
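
// Usage sketch (an assumption, not part of the original file): once this
// library is loaded, the op can be called from C++ through the dispatcher.
// The typed signature below mirrors rnnt_loss_autograd above.
//
//   auto op = c10::Dispatcher::singleton()
//       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
//       .typed<std::tuple<torch::Tensor, std::optional<torch::Tensor>>(
//           torch::Tensor&,
//           const torch::Tensor&,
//           const torch::Tensor&,
//           const torch::Tensor&,
//           int64_t,
//           double,
//           bool)>();
//   // logits: (batch, max_T, max_U, num_classes); targets and the two
//   // length tensors are int32, as expected by the RNN-T kernel.
//   auto [costs, grads] = op.call(
//       logits, targets, logit_lengths, target_lengths,
//       /*blank=*/0, /*clamp=*/-1.0, /*fused_log_softmax=*/true);
//
// From Python, the same op surfaces as torch.ops.torchaudio.rnnt_loss.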