-
Notifications
You must be signed in to change notification settings - Fork 103
/
Copy pathBatchRulesLoss.cpp
317 lines (282 loc) · 11.6 KB
/
BatchRulesLoss.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <functorch/csrc/BatchRulesHelper.h>
#include <functorch/csrc/PlumbingHelper.h>
#include <functorch/csrc/BatchedFallback.h>
#include <ATen/core/dispatch/Dispatcher.h>
namespace at { namespace functorch {
// Reshapes `tensor` so that its batch dim (when `bdim` is present) sits at
// position 0 and all remaining logical dims are collapsed into one trailing
// dim: [B, d0, d1, ...] -> [B, d0*d1*...]. A tensor whose only dim is the
// batch dim is returned as [B]. Without a batch dim the whole tensor is
// flattened to 1-D.
at::Tensor flatten_logical(const Tensor& tensor, optional<int64_t> bdim) {
  if (!bdim.has_value()) {
    return tensor.flatten();
  }
  auto batch_first = moveBatchDimToFront(tensor, bdim);
  return batch_first.dim() > 1 ? batch_first.flatten(1) : batch_first;
}
// Batch rule for at::mse_loss.
//
// Both operands are flattened with flatten_logical (batched -> [B, K] or [B],
// unbatched -> [K]). The loss is then computed unreduced, and the requested
// reduction is applied over the trailing dim only, so the batch dim (always
// returned at position 0) is preserved.
std::tuple<at::Tensor,optional<int64_t>>
mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
                    optional<int64_t> target_bdim, int64_t reduction) {
  const auto unreduced = at::mse_loss(
      flatten_logical(self, self_bdim),
      flatten_logical(target, target_bdim),
      Reduction::None);

  // A 1-D unreduced loss has no trailing dim to reduce over (presumably the
  // logical inputs were scalars, so each example is a single element); it is
  // returned as-is, exactly like the Reduction::None case.
  if (unreduced.dim() == 1 || reduction == Reduction::None) {
    return std::make_tuple(unreduced, 0);
  }
  switch (reduction) {
    case Reduction::Sum:
      return std::make_tuple(unreduced.sum(-1), 0);
    case Reduction::Mean:
      return std::make_tuple(unreduced.mean(-1), 0);
    default:
      break;
  }
  TORCH_INTERNAL_ASSERT(false);
}
// Batch rule for at::mse_loss_backward.
//
// The kernel is always invoked with Reduction::None on batch-first operands.
// - Reduced losses: a batched grad_output holds one value per example, shape
//   [N]. It is viewed as [N, 1, ..., 1] (rank of self's logical shape + 1) so
//   that it broadcasts against the batch-first self/target.
// - Mean: the per-element gradient is divided by self's logical numel.
// The result's batch dim is always at position 0.
std::tuple<at::Tensor,optional<int64_t>>
mse_loss_backward_batch_rule(
    const at::Tensor& grad_output, optional<int64_t> grad_output_bdim,
    const at::Tensor& self, optional<int64_t> self_bdim,
    const at::Tensor& target, optional<int64_t> target_bdim,
    int64_t reduction) {
  auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
  const auto self_ = moveBatchDimToFront(self, self_bdim);
  const auto target_ = moveBatchDimToFront(target, target_bdim);

  const bool grad_is_per_example_scalar =
      reduction != Reduction::None && grad_output_bdim.has_value();
  if (grad_is_per_example_scalar) {
    DimVector broadcast_shape(rankWithoutBatchDim(self, self_bdim) + 1, 1);
    broadcast_shape[0] = grad_output_.size(0);
    grad_output_ = grad_output_.view(broadcast_shape);
  }

  auto grad_input = at::mse_loss_backward(grad_output_, self_, target_, Reduction::None);
  if (reduction == Reduction::Mean) {
    grad_input = grad_input / numelWithoutBatchDim(self, self_bdim);
  }
  return std::make_tuple(grad_input, 0);
}
// Applies a loss reduction to an elementwise (unreduced) loss tensor.
// Mean -> mean over all elements, Sum -> sum over all elements, anything
// else (e.g. Reduction::None) -> the input unchanged.
static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
  switch (reduction) {
    case at::Reduction::Mean:
      return unreduced.mean();
    case at::Reduction::Sum:
      return unreduced.sum();
    default:
      return unreduced;
  }
}
// Hand-written vmap plumbing for at::binary_cross_entropy.
//
// Steps:
//   1. Unwrap `self` and `target` at the current dynamic layer.
//   2. Compute the *unreduced*, *unweighted* loss on the physical tensors
//      with the batched key excluded.
//   3. Re-wrap the result as batched along dim 0.
//   4. Apply `weight` and the requested reduction with ordinary ops.
//
// `weight` is deliberately not forwarded to the kernel (nullopt is passed).
// Multiplying it in afterwards lets a batched `weight` go through the usual
// batching of `mul` instead of needing special handling here.
Tensor binary_cross_entropy_plumbing(
    const Tensor& self, const Tensor& target,
    const optional<Tensor>& weight, int64_t reduction) {
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();
  // Fast path: nothing is batched at this level, so call the op unchanged
  // with the batched key excluded.
  if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)
      && !isBatchedAtLevel(weight, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::binary_cross_entropy(self, target, weight, reduction);
  }
  Tensor self_value;
  optional<int64_t> self_bdim;
  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
  Tensor target_value;
  optional<int64_t> target_bdim;
  std::tie(target_value, target_bdim) = unwrapTensorAtLevel(target, cur_level);
  Tensor result;
  if (self_bdim || target_bdim) {
    // The guard covers only the physical computation. makeBatched re-wraps
    // the result at `cur_level` with its batch dim at position 0.
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto bdim_size = get_bdim_size2(self_value, self_bdim, target_value, target_bdim);
    auto self_ = moveBatchDimToFront(self_value, self_bdim);
    auto target_ = moveBatchDimToFront(target_value, target_bdim);
    // NOTE(review): presumably ensure_has_bdim gives an operand that lacks a
    // batch dim a leading dim of size bdim_size, so both operands line up.
    // Confirm against its definition.
    self_ = ensure_has_bdim(self_, self_bdim.has_value(), bdim_size);
    target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);
    result = at::binary_cross_entropy(self_, target_, nullopt, Reduction::None);
    result = makeBatched(result, 0, cur_level);
  } else {
    // Neither self nor target carries a batch dim here. Given the early
    // return above, this is the case where only `weight` is batched, so the
    // unweighted loss itself needs no vmap handling.
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    result = at::binary_cross_entropy(self_value, target_value, nullopt, Reduction::None);
  }
  // These run outside any guard, so they dispatch through vmap: a batched
  // `result` and/or `weight` is handled by the batch rules of mul/mean/sum.
  if (weight.has_value() && weight->defined()) {
    result = result * weight.value();
  }
  return apply_loss_reduction(result, reduction);
}
// Hand-written vmap plumbing for at::binary_cross_entropy_backward.
//
// Mirrors binary_cross_entropy_plumbing: the kernel is always called on the
// physical tensors with Reduction::None and no weight. The weight and the
// Mean scaling are applied afterwards with ordinary (vmap-aware) ops.
Tensor binary_cross_entropy_backward_plumbing(
    const Tensor& grad, const Tensor& input, const Tensor& target,
    const c10::optional<Tensor>& weight_opt, int64_t reduction) {
  auto maybe_layer = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
  int64_t cur_level = maybe_layer->layerId();
  // Fast path: nothing is batched at this level.
  if (!areAnyBatchedAtLevel({grad, input, target, weight_opt}, cur_level)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::binary_cross_entropy_backward(grad, input, target, weight_opt, reduction);
  }
  Tensor grad_value;
  optional<int64_t> grad_bdim;
  // For a reduced loss, `grad` is expanded to `input`'s shape *before*
  // unwrapping (i.e. still under vmap), so the Reduction::None kernel below
  // can consume it elementwise.
  std::tie(grad_value, grad_bdim) = unwrapTensorAtLevel(
      reduction == Reduction::None ? grad : grad.expand_as(input), cur_level);
  Tensor input_value;
  optional<int64_t> input_bdim;
  std::tie(input_value, input_bdim) = unwrapTensorAtLevel(input, cur_level);
  Tensor target_value;
  optional<int64_t> target_bdim;
  std::tie(target_value, target_bdim) = unwrapTensorAtLevel(target, cur_level);
  Tensor grad_input;
  if (grad_bdim || input_bdim || target_bdim) {
    // The guard covers only the physical computation; the result is re-wrapped
    // at `cur_level` with its batch dim at position 0.
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    const auto bdim_size = get_bdim_size3(
        grad_value, grad_bdim, input_value, input_bdim, target_value, target_bdim);
    auto grad_ = moveBatchDimToFront(grad_value, grad_bdim);
    auto input_ = moveBatchDimToFront(input_value, input_bdim);
    auto target_ = moveBatchDimToFront(target_value, target_bdim);
    // NOTE(review): presumably materializes a leading batch dim of size
    // bdim_size on operands that lack one. Confirm against ensure_has_bdim.
    grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bdim_size);
    input_ = ensure_has_bdim(input_, input_bdim.has_value(), bdim_size);
    target_ = ensure_has_bdim(target_, target_bdim.has_value(), bdim_size);
    grad_input = at::binary_cross_entropy_backward(
        grad_, input_, target_, nullopt, Reduction::None);
    grad_input = makeBatched(grad_input, 0, cur_level);
  } else {
    // Only `weight_opt` can be batched here (see the early return above).
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    grad_input = at::binary_cross_entropy_backward(
        grad_value, input_value, target_value, nullopt, Reduction::None);
  }
  // Outside any guard: mul/div_ dispatch through vmap.
  if (weight_opt.has_value() && weight_opt->defined()) {
    grad_input = grad_input * weight_opt.value();
  }
  // Sum needs no extra scaling because `grad` was already expanded above.
  // NOTE(review): for a batched `input` this presumably relies on numel()
  // reporting the logical (per-example) element count. Confirm.
  if (reduction == Reduction::Mean) {
    grad_input.div_(input.numel());
  }
  return grad_input;
}
// Decomposition of nll_loss_forward / nll_loss2d_forward into ops that have
// their own batch rules (mul, gather, sum, ...). It is registered directly at
// the batched key and never unwraps BatchedTensors itself.
//
// Logical shapes:
//   self:   [N, C, ...] or [C]
//   target: [N, ...]    or []
// Returns (output, total_weight), the two outputs of the aten op.
//
// Targets equal to `ignore_index` contribute nothing, for *any* value of
// `ignore_index`, including PyTorch's default negative sentinel -100 that is
// commonly used to mark padding. This matches native nll_loss. Such targets
// are replaced by class 0 before the gather, so gather never receives an
// out-of-range (e.g. negative) index; their contribution is then zeroed by
// the mask.
std::tuple<Tensor, Tensor> nll_loss_forward_decomposition(
    const Tensor & self,
    const Tensor & target,
    const c10::optional<Tensor> & weight,
    int64_t reduction, int64_t ignore_index) {
  const int64_t channel_dim = self.dim() < 2 ? 0 : 1;

  auto self_ = self;
  Tensor weight_;
  if (weight && weight->defined()) {
    // Here is a specific case with reduction mean and non-batched tensors
    // https://github.com/pytorch/pytorch/issues/61309
    // In this case weight is cancelled: w * x[t] / w -> x[t]
    if (!(reduction == Reduction::Mean && self_.dim() < 2)) {
      // reshape weights to [1, C, 1, ..., 1]
      auto shape = weight->sizes();
      VmapDimVector new_shape(self_.dim(), 1);
      new_shape[channel_dim] = shape[0];
      weight_ = weight->reshape(new_shape);
      self_ = self_ * weight_;
    }
  }

  // True where the target takes part in the loss; same shape as `target`.
  const auto ignore_index_mask = target != ignore_index;
  // int64 * bool promotes to int64: ignored targets become the (valid) class 0.
  const auto safe_target = target * ignore_index_mask;
  auto target_ = safe_target.unsqueeze(channel_dim);
  // target_ can be [N, 1, ...] or [1]
  auto result = -at::gather(self_, channel_dim, target_).squeeze(channel_dim);
  result = result * ignore_index_mask;

  Tensor total_weight;
  if (reduction == Reduction::None && self.dim() >= 2) {
    // Unreduced multi-element output: total_weight is reported as 0.
    total_weight = at::full(
        {}, 0.0, self_.scalar_type(),
        self_.layout(), self_.device(), nullopt);
  } else {
    // Number of non-ignored targets (equals numel when nothing is ignored).
    total_weight = ignore_index_mask.sum().to(self_);
  }

  // Apply the reduction
  if (result.dim() > 0) {
    if (reduction == Reduction::Sum) {
      result = result.sum();
    } else if (reduction == Reduction::Mean) {
      if (!weight || !weight->defined()) {
        // Average over non-ignored targets only.
        result = result.sum() / total_weight;
      } else {
        TORCH_INTERNAL_ASSERT(weight_.defined());
        // Weighted mean: divide by the sum of the weights of the
        // non-ignored targets.
        auto wsum = at::gather(weight_.expand(self_.sizes()), channel_dim, target_)
                        .squeeze(channel_dim);
        wsum = (wsum * ignore_index_mask).sum();
        result = result.sum() / wsum;
        total_weight = wsum;
      }
    }
  } else if (reduction == Reduction::Mean && weight && weight->defined()) {
    // here weight is [C] and target is [1]; the weight was not applied to
    // self_ above (it cancels), but total_weight still reports w[target].
    auto wsum = at::gather(*weight, channel_dim, target_).squeeze(channel_dim);
    total_weight = (wsum * ignore_index_mask).sum();
  }
  return std::make_tuple(result, total_weight);
}
// Decomposition of nll_loss_backward / nll_loss2d_backward into ops that have
// their own batch rules. It is registered directly at the batched key.
//
// The gradient w.r.t. `self` is -1 at each element's target class (0
// elsewhere). That one-hot pattern is scaled by grad_output, by the class
// weight (if any), and by 1/total_weight for Mean; ignored targets get a
// zero gradient.
//
// Targets equal to `ignore_index` are masked out for *any* value of
// `ignore_index`, including the default negative sentinel -100. They are
// replaced by class 0 before the scatter so scatter never receives an
// out-of-range index; the mask then zeroes their gradient.
at::Tensor nll_loss_backward_decomposition(
    const at::Tensor & grad_output, const at::Tensor & self,
    const at::Tensor & target, const c10::optional<at::Tensor> & weight,
    int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) {
  const int64_t channel_dim = self.dim() < 2 ? 0 : 1;

  // [N, 1, ...] or [1]
  const auto target_unsqueezed = target.unsqueeze(channel_dim);
  // True where the target takes part in the loss; same shape as target_unsqueezed.
  const auto ignore_index_mask = target_unsqueezed != ignore_index;
  // int64 * bool promotes to int64: ignored targets become the (valid) class 0.
  const auto target_ = target_unsqueezed * ignore_index_mask;

  auto grad_output_ = grad_output;
  if (reduction == Reduction::Mean) {
    grad_output_ = grad_output_ / total_weight;
  }

  auto grad_input = at::zeros_like(self);
  grad_input = at::scatter(grad_input, channel_dim, target_, -1.0);

  // Unreduced grad_output lacks the channel dim; add it so it broadcasts
  // against grad_input.
  if (grad_output_.dim() < grad_input.dim() && grad_output_.dim() > 0) {
    grad_output_ = grad_output_.unsqueeze(channel_dim);
  }

  if (weight && weight->defined()) {
    // reshape weights to [1, C, 1, ..., 1] (or [C] when self is [C])
    VmapDimVector new_shape(self.dim(), 1);
    new_shape[channel_dim] = weight->sizes()[0];
    grad_output_ = grad_output_ * weight->reshape(new_shape);
  }

  grad_output_ = grad_output_ * ignore_index_mask;
  return grad_input * grad_output_;
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  // Batch rules registered through the VMAP_SUPPORT helper macro.
  VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
  VMAP_SUPPORT(mse_loss_backward, mse_loss_backward_batch_rule);

  // Decompositions: these do not unwrap BatchedTensors, so every op they
  // call is itself dispatched through vmap. The 1-D and 2-D nll variants
  // share one implementation each.
  m.impl("nll_loss_forward", nll_loss_forward_decomposition);
  m.impl("nll_loss2d_forward", nll_loss_forward_decomposition);
  m.impl("nll_loss_backward", nll_loss_backward_decomposition);
  m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);

  // Hand-written plumbing that unwraps/re-wraps batched inputs itself.
  m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
  m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
}
}}