@@ -42,6 +42,7 @@ static Tensor padRight(const Tensor& tensor, std::optional<int64_t> has_bdim, in
 }
 
 template <typename F, F Func>
+static
 std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
 batch_norm_batch_rule(
     const Tensor& input, std::optional<int64_t> input_bdim,
@@ -70,10 +71,10 @@ batch_norm_batch_rule(
   if (!input_bdim && !running_mean_bdim && !running_var_bdim) {
     const auto dummy_weight = at::ones(input.size(1), input.options()); // cudnn and miopen require a weight
     const auto dummy_bias = at::zeros(input.size(1), input.options()); // without this, get "strides() called on undefined Tensor" on cuda
-    const auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
+    auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
     result0 = std::get<0>(result).transpose(0, 1); // [C, B, *]
-    mean = std::get<1>(result);
-    rstd = std::get<2>(result);
+    mean = std::move(std::get<1>(result));
+    rstd = std::move(std::get<2>(result));
   } else {
     bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
     auto input_ = moveBatchDimToFront(input, input_bdim);
@@ -95,12 +96,12 @@ batch_norm_batch_rule(
 
     const auto dummy_weight = at::ones(input_.size(1), input_.options()); // cudnn and miopen require a weight
     const auto dummy_bias = at::zeros(input_.size(1), input_.options()); // without this, get "strides() called on undefined Tensor" on cuda
-    const auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
+    auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
     result0 = std::get<0>(result).transpose(0, 1); // [(B0, C), B, *]
+    mean = std::move(std::get<1>(result));
+    rstd = std::move(std::get<2>(result));
     result0 = reshape_dim_outof(0, bdim_size.value(), result0); // [B0, C, B, *]
-    mean = std::get<1>(result);
     mean = reshape_dim_outof(0, bdim_size.value(), mean); // [B0, C]
-    rstd = std::get<2>(result);
     rstd = reshape_dim_outof(0, bdim_size.value(), rstd); // [B0, C]
   }
 
@@ -124,6 +125,7 @@ batch_norm_batch_rule(
 }
 
 template <typename F, F Func>
+static
 std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bias_batch_rule(
     const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
     const at::Tensor & input, std::optional<int64_t> input_bdim,
@@ -142,9 +144,9 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
     TORCH_INTERNAL_ASSERT(!mean_bdim);
     TORCH_INTERNAL_ASSERT(!rstd_bdim);
     const auto dummy_weight = at::ones(input.size(1), input.options());
-    const auto result = Func(
+    auto result = Func(
         grad_out, input, dummy_weight, running_mean_opt, running_var_opt, mean, rstd, training, eps, {true, false, false});
-    return std::make_tuple(std::get<0>(result), std::nullopt);
+    return {std::move(std::get<0>(result)), std::nullopt};
   }
 
   auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
@@ -196,6 +198,7 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
 }
 
 template <typename F, F Func>
+static
 std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
     const at::Tensor & grad_out,
     const at::Tensor & input,
@@ -270,15 +273,15 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
         unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level); // [B0, B, C, *]
 
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
+    auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
         grad_normalized_input_value, grad_normalized_input_bdim,
         input_value, input_bdim,
         running_mean_value, running_mean_bdim,
         running_var_value, running_var_bdim,
         save_mean_value, save_mean_bdim,
         save_rstd_value, save_rstd_bdim,
         training, eps);
-    grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
+    grad_input = makeBatched(std::move(std::get<0>(results)), std::get<1>(results), cur_level);
   }
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
@@ -312,16 +315,13 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
     const auto bdim_size = input_value.size(*input_bdim);
 
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
-    result0 = makeBatched(reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0, cur_level);
-    mean = makeBatched(reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0, cur_level);
-    rstd = makeBatched(reshape_dim_outof(0, bdim_size, std::get<2>(result)), 0, cur_level);
+    std::tie(result0, mean, rstd) = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
+    result0 = makeBatched(reshape_dim_outof(0, bdim_size, result0), 0, cur_level);
+    mean = makeBatched(reshape_dim_outof(0, bdim_size, mean), 0, cur_level);
+    rstd = makeBatched(reshape_dim_outof(0, bdim_size, rstd), 0, cur_level);
   } else {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
-    result0 = std::get<0>(result);
-    mean = std::get<1>(result);
-    rstd = std::get<2>(result);
+    std::tie(result0, mean, rstd) = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
   }
 
   if (weight.defined()) {
@@ -334,10 +334,10 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
     result0 = result0 + padded_bias;
   }
 
-  return std::make_tuple(result0, mean, rstd);
+  return std::make_tuple(std::move(result0), std::move(mean), std::move(rstd));
 }
 
-static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule(
+static at::Tensor group_norm_backward_no_weight_bias_batch_rule(
     const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
     const at::Tensor & input, std::optional<int64_t> input_bdim,
     const at::Tensor & mean, std::optional<int64_t> mean_bdim,
@@ -359,15 +359,13 @@ static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_wei
   mean_ = reshape_dim_into(0, 0, mean_); // [B0 * N, G]
   rstd_ = reshape_dim_into(0, 0, rstd_); // [B0 * N, G]
 
-  const auto result = native_group_norm_backward(
+  auto result0 = std::get<0>(native_group_norm_backward(
       grad_out_.contiguous(),
       input_.contiguous(),
      mean_.contiguous(),
       rstd_.contiguous(),
-      std::nullopt, N * bdim_size, C, HxW, group, {true, false, false});
-  auto result0 = std::get<0>(result);
-  result0 = reshape_dim_outof(0, bdim_size, result0);
-  return std::make_tuple(result0, 0);
+      std::nullopt, N * bdim_size, C, HxW, group, {true, false, false}));
+  return reshape_dim_outof(0, bdim_size, result0);
 }
 
 static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
@@ -422,19 +420,19 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
         unwrapTensorAtLevel(grad_normalized_input, cur_level);
 
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto res = group_norm_backward_no_weight_bias_batch_rule(
+    auto tensor = group_norm_backward_no_weight_bias_batch_rule(
         grad_normalized_input_value, grad_normalized_input_bdim,
         input_value, input_bdim,
         mean_value, mean_bdim,
         rstd_value, rstd_bdim,
         N, C, HxW, group
     );
-    grad_input = makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
+    grad_input = makeBatched(std::move(tensor), 0, cur_level);
   }
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
 
-C10_ALWAYS_INLINE bool has_same_shape(
+static bool has_same_shape(
     const Tensor& tensor, std::optional<int64_t> tensor_bdim,
     c10::SymIntArrayRef normalized_shape) {
   if (!tensor.defined()) {
@@ -457,7 +455,7 @@ C10_ALWAYS_INLINE bool has_same_shape(
   return true;
 }
 
-C10_ALWAYS_INLINE void check_same_shape(
+static C10_ALWAYS_INLINE void check_same_shape(
     const Tensor& tensor, std::optional<int64_t> tensor_bdim,
     c10::SymIntArrayRef normalized_shape, const std::string& name) {
   TORCH_CHECK(has_same_shape(tensor, tensor_bdim, normalized_shape),
@@ -469,7 +467,7 @@ C10_ALWAYS_INLINE void check_same_shape(
 }
 
 // Ugh, hard to deduplicate
-C10_ALWAYS_INLINE void _check_layer_norm_inputs(
+static C10_ALWAYS_INLINE void _check_layer_norm_inputs(
     SymIntArrayRef normalized_shape,
     const Tensor& weight, std::optional<int64_t> weight_bdim,
     const Tensor& bias, std::optional<int64_t> bias_bdim) {
@@ -493,11 +491,9 @@ native_layer_norm_batch_rule(
     double eps) {
   auto input_ = moveBatchDimToFront(input, input_bdim);
   if (!weight_bdim && !bias_bdim) {
-    const auto result = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
-    const auto mean = std::get<1>(result);
-    const auto rstd = std::get<2>(result);
+    auto [result0, mean, rstd] = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
     const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
-    return std::make_tuple(std::get<0>(result), 0, mean, stats_bdim, rstd, stats_bdim);
+    return std::make_tuple(std::move(result0), 0, std::move(mean), stats_bdim, std::move(rstd), stats_bdim);
   }
 
   // See [Note: hacky wrapper removal for optional tensor]
@@ -509,9 +505,7 @@ native_layer_norm_batch_rule(
 
   const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
   const auto result = at::native_layer_norm_symint(input_, normalized_shape, std::nullopt, std::nullopt, eps);
-  auto result0 = std::get<0>(result);
-  const auto mean = std::get<1>(result);
-  const auto rstd = std::get<2>(result);
+  auto [result0, mean, rstd] = result;
   const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
 
   if (weight.defined()) {
@@ -638,7 +632,7 @@ static std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward_p
         unwrapTensorAtLevel(grad_normalized_input, cur_level);
 
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
+    auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        normalized_shape,
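Aside on the pattern that repeats throughout this change: the returned tuple is bound to a non-const local and its elements are moved out (auto result = Func(...); mean = std::move(std::get<1>(result));) instead of being copied from a const tuple. For at::Tensor the copy that is avoided is an atomic refcount increment on the underlying handle. Below is a minimal standalone sketch of that difference; Counter is a hypothetical copy-counting stand-in, not ATen code.

#include <cstdio>
#include <tuple>
#include <utility>

// Hypothetical stand-in for a refcounted handle such as at::Tensor:
// copying it is observable (increments a counter), moving it is free.
struct Counter {
  static int copies;
  Counter() = default;
  Counter(const Counter&) { ++copies; }
  Counter(Counter&&) noexcept = default;
};
int Counter::copies = 0;

static std::tuple<Counter, Counter> make_result() {
  return {Counter{}, Counter{}};
}

int main() {
  auto result = make_result();                    // non-const, so elements may be moved from
  Counter mean = std::get<0>(result);             // copy: bumps the counter
  Counter rstd = std::move(std::get<1>(result));  // move: no copy
  std::printf("copies = %d\n", Counter::copies);  // prints "copies = 1"
}

The same reasoning applies to returning {std::move(std::get<0>(result)), std::nullopt} rather than std::make_tuple(std::get<0>(result), std::nullopt): std::get on an lvalue tuple yields an lvalue reference, so without std::move the element is copied into the new tuple.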