Commit 1fa27f6

cyyever authored and pytorchmergebot committed
[3/N] Avoid copy in std::get (pytorch#141843)
Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#141843
Approved by: https://github.com/Skylion007
1 parent add4a42 commit 1fa27f6

22 files changed: +130 −152 lines
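
For context on the pattern this commit removes: `std::get` on a tuple that is held as a `const` lvalue yields a reference from which the element can only be copied, while a non-const tuple lets the element be moved out or bound directly with a structured binding. A standalone sketch in plain standard C++ (not PyTorch code) of the difference:

```cpp
#include <string>
#include <tuple>
#include <utility>

std::tuple<std::string, int> make_result() {
  return {std::string(1000, 'x'), 42};  // an expensive-to-copy element plus a cheap one
}

int main() {
  // Holding the result as const forces a copy when reading the element out.
  const auto frozen = make_result();
  std::string a = std::get<0>(frozen);              // copies the 1000-char string

  // A non-const tuple allows moving the element out instead.
  auto movable = make_result();
  std::string b = std::move(std::get<0>(movable));  // moves, no copy

  // Structured bindings name the tuple's elements directly, also without a copy.
  auto [c, n] = make_result();
  return (a == b && b == c && n == 42) ? 0 : 1;
}
```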

aten/src/ATen/LegacyBatchedTensorImpl.cpp

+5-5
@@ -76,7 +76,7 @@ void BatchedTensorImpl::checkInvariants() const {
   }
 }
 
-// The following are publically exposed as methods of Tensor
+// The following are publicly exposed as methods of Tensor
 
 IntArrayRef BatchedTensorImpl::strides_custom() const {
   return strides_default();
@@ -113,7 +113,7 @@ const char* BatchedTensorImpl::tensorimpl_type_name() const {
   return "BatchedTensorImpl";
 }
 
-Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
+Tensor makeBatched(Tensor tensor, BatchDims bdims) {
   TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor));
   auto tensor_dim = tensor.dim();
   TORCH_CHECK(
@@ -124,15 +124,15 @@ Tensor makeBatched(const Tensor& tensor, BatchDims bdims) {
       std::all_of(bdims.begin(), bdims.end(),
           [](const BatchDim& bdim) { return bdim.level() < kVmapNumLevels; }),
       "We only support up to ", kVmapNumLevels, " nested vmaps");
-  return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
+  return at::detail::make_tensor<BatchedTensorImpl>(std::move(tensor), std::move(bdims));
 }
 
-Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim) {
+Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim) {
   const auto* batched = maybeGetBatchedImpl(tensor);
   if (!batched) {
     BatchDims bdims;
     bdims.emplace_back(level, dim);
-    return at::detail::make_tensor<BatchedTensorImpl>(tensor, std::move(bdims));
+    return at::detail::make_tensor<BatchedTensorImpl>(std::move(tensor), std::move(bdims));
   }
   BatchDims new_bdims(batched->bdims().begin(), batched->bdims().end());
   auto actual_bdim = batched->actualDim(dim, /*wrap_dim=*/true);
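
The signature change above is the usual sink-parameter idiom: take the tensor by value and `std::move` it into `make_tensor`, so an rvalue argument is moved all the way in and an lvalue argument costs exactly one copy at the call site. A minimal sketch of the idiom with a hypothetical `Payload`/`Wrapper` pair standing in for `Tensor`/`BatchedTensorImpl`:

```cpp
#include <utility>
#include <vector>

struct Payload {                 // stand-in for an expensive-to-copy handle
  std::vector<float> data;
};

struct Wrapper {
  Payload payload;
  explicit Wrapper(Payload p) : payload(std::move(p)) {}  // sink parameter
};

Wrapper wrap(Payload p) {        // by value: the caller decides copy vs. move
  return Wrapper(std::move(p));  // forward the value without another copy
}

int main() {
  Payload owned{std::vector<float>(1 << 20, 0.f)};
  Wrapper a = wrap(owned);             // one copy, made at the call site
  Wrapper b = wrap(std::move(owned));  // no copy: moved all the way through
  return (a.payload.data.size() == b.payload.data.size()) ? 0 : 1;
}
```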

aten/src/ATen/LegacyBatchedTensorImpl.h

+2-2
@@ -148,10 +148,10 @@ inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) {
 }
 
 // Use this to construct a BatchedTensor from a regular Tensor
-TORCH_API Tensor makeBatched(const Tensor& tensor, BatchDims bdims);
+TORCH_API Tensor makeBatched(Tensor tensor, BatchDims bdims);
 
 // Adds a batch dim to `tensor`, returning a BatchedTensor
-TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t level, int64_t dim);
+TORCH_API Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim);
 
 // Checks if an inplace operation on self and other is "vmap compatible".
 // See NOTE: [vmap-incompatible in-place operations] for the definition of this.

aten/src/ATen/core/boxing/KernelFunction_test.cpp

+1-3
@@ -353,19 +353,17 @@ void expectOutOfPlaceMultiUnboxedCallingWorks(const KernelFunction& func) {
   auto t1 = at::zeros({1});
   auto t2 = at::zeros({1});
 
-  std::tuple<at::Tensor&, at::Tensor&> tup = func.call<
+  auto [t1_out, t2_out] = func.call<
       std::tuple<at::Tensor&, at::Tensor&>, at::Scalar, at::Scalar, at::Tensor&, at::Tensor&
   >(dummy, CPU_TEST_SET, s1, s2, t1, t2);
 
   // kernel should have updated out args and returned them in a tuple
   EXPECT_EQ(t1.item().toFloat(), 1.0f);
   EXPECT_EQ(t2.item().toFloat(), 2.0f);
 
-  auto t1_out = std::get<0>(tup);
   EXPECT_EQ(t1_out.item().toFloat(), 1.0f);
   EXPECT_TRUE(t1_out.is_same(t1));
 
-  auto t2_out = std::get<1>(tup);
   EXPECT_EQ(t2_out.item().toFloat(), 2.0f);
   EXPECT_TRUE(t2_out.is_same(t2));
 }
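
The test change replaces `std::get` on a named tuple of references with a structured binding. With `auto t1_out = std::get<0>(tup);`, `auto` deduces a plain value type, so an extra object is constructed; a structured binding over the returned tuple of references simply names the referenced objects. A standalone sketch (plain standard C++, `std::string` standing in for the tensors):

```cpp
#include <cassert>
#include <string>
#include <tuple>

int main() {
  std::string s1 = "hello", s2 = "world";
  std::tuple<std::string&, std::string&> refs(s1, s2);

  auto copy = std::get<0>(refs);   // `auto` deduces std::string: this is a copy
  copy += "!";
  assert(s1 == "hello");           // the original is untouched

  auto [r1, r2] = refs;            // r1/r2 refer to s1/s2, no string is copied
  r1 += "!";
  assert(s1 == "hello!");          // modifying r1 modifies s1
  (void)r2;
  (void)copy;
  return 0;
}
```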

aten/src/ATen/functorch/BatchRulesBinaryOps.cpp

+7-11
@@ -218,47 +218,43 @@ static std::tuple<Tensor, std::optional<int64_t>> masked_select_backward_batch_r
 
 static std::tuple<Tensor, std::optional<int64_t>> cdist_backward_batch_rule(
     const Tensor& grad, std::optional<int64_t> grad_bdim,
-    const Tensor& x1, std::optional<int64_t> x1_bdim,
-    const Tensor& x2, std::optional<int64_t> x2_bdim,
+    Tensor x1, std::optional<int64_t> x1_bdim,
+    Tensor x2, std::optional<int64_t> x2_bdim,
     const double p,
     const Tensor& cdist, std::optional<int64_t> cdist_bdim) {
 
-  auto x1_ = x1;
   if (cdist_bdim && !x1_bdim) {
     // We need to make sure that x1 has batch dim if cdist has one
     // otherwise, we get
     // RuntimeError: Function CdistBackward0 returned an invalid gradient at index 1 - got [5]
     // but expected shape compatible with [4, 5]
     auto bs = cdist.size(*cdist_bdim);
-    x1_ = ensure_has_bdim(x1, false, bs);
-    x1_ = x1_.contiguous();
+    x1 = ensure_has_bdim(x1, false, bs).contiguous();
    x1_bdim = 0;
  }
 
  // We need to apply the same preprocessing on x1 and x2 as in the forward pass
  // _binary_pointwise_batch_rule
-  auto x12 = _binary_pointwise_helper(x1_, x1_bdim, x2, x2_bdim);
-  x1_ = std::move(std::get<0>(x12));
-  auto& x2_ = std::get<1>(x12);
+  std::tie(x1, x2) = _binary_pointwise_helper(x1, x1_bdim, x2, x2_bdim);
 
  auto grad_ = moveBatchDimToFront(grad, grad_bdim);
  if ((x1_bdim || x2_bdim) && !grad_bdim) {
    // We need to make sure that grad has batch dim if x1 or x2 have one
    // Probably, there is an assumption on the strides.
    // Otherwise grad input contains thrash values, e.g. -7.0816e+29, 7.0816e+29
-    auto bs = get_bdim_size2(x1_, 0, x2_, 0);
+    auto bs = get_bdim_size2(x1, 0, x2, 0);
    grad_ = ensure_has_bdim(grad_, grad_bdim.has_value(), bs);
    grad_ = grad_.contiguous();
  }
 
-  auto out = at::_cdist_backward(grad_, x1_, x2_, p, cdist);
+  auto out = at::_cdist_backward(grad_, x1, x2, p, cdist);
 
  std::optional<int64_t> out_bdim = std::nullopt;
  if (x1_bdim || x2_bdim) {
    out_bdim = 0;
  }
 
-  return std::make_tuple(out, out_bdim);
+  return std::make_tuple(std::move(out), out_bdim);
 }
 
 static void fill__Tensor_batch_rule(
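
Here the helper's result is unpacked with `std::tie` straight into the by-value parameters, instead of keeping the whole tuple alive and pulling elements out of it with `std::get`. A standalone sketch of that pattern with placeholder names (`preprocess` and `combine` are hypothetical, not functorch helpers):

```cpp
#include <string>
#include <tuple>
#include <utility>

// Stand-in for a helper that returns adjusted versions of both inputs.
std::tuple<std::string, std::string> preprocess(std::string a, std::string b) {
  a += "-padded";
  b += "-padded";
  return {std::move(a), std::move(b)};
}

std::string combine(std::string x1, std::string x2) {
  // Reuse the by-value parameters: std::tie move-assigns the elements of the
  // returned rvalue tuple into x1 and x2, with no named temporary tuple kept.
  std::tie(x1, x2) = preprocess(std::move(x1), std::move(x2));
  return x1 + "/" + x2;
}

int main() {
  return combine("a", "b") == "a-padded/b-padded" ? 0 : 1;
}
```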

aten/src/ATen/functorch/BatchRulesNorm.cpp

+32-38
@@ -42,6 +42,7 @@ static Tensor padRight(const Tensor& tensor, std::optional<int64_t> has_bdim, in
 }
 
 template<typename F, F Func>
+static
 std::tuple<Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>,Tensor, std::optional<int64_t>>
 batch_norm_batch_rule(
     const Tensor& input, std::optional<int64_t> input_bdim,
@@ -70,10 +71,10 @@ batch_norm_batch_rule(
   if (!input_bdim && !running_mean_bdim && !running_var_bdim) {
     const auto dummy_weight = at::ones(input.size(1), input.options());  // cudnn and miopen require a weight
     const auto dummy_bias = at::zeros(input.size(1), input.options());  // without this, get "strides() called on undefined Tensor" on cuda
-    const auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
+    auto result = Func(input, dummy_weight, dummy_bias, running_mean_opt, running_var_opt, training, momentum, eps);
     result0 = std::get<0>(result).transpose(0, 1);          // [C, B, *]
-    mean = std::get<1>(result);
-    rstd = std::get<2>(result);
+    mean = std::move(std::get<1>(result));
+    rstd = std::move(std::get<2>(result));
   } else {
     bdim_size = get_bdim_size3(input, input_bdim, running_mean, running_mean_bdim, running_var, running_var_bdim);
     auto input_ = moveBatchDimToFront(input, input_bdim);
@@ -95,12 +96,12 @@ batch_norm_batch_rule(
 
     const auto dummy_weight = at::ones(input_.size(1), input_.options());  // cudnn and miopen require a weight
     const auto dummy_bias = at::zeros(input_.size(1), input_.options());  // without this, get "strides() called on undefined Tensor" on cuda
-    const auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
+    auto result = Func(input_, dummy_weight, dummy_bias, running_mean_, running_var_, training, momentum, eps);
     result0 = std::get<0>(result).transpose(0, 1);              // [(B0, C), B, *]
+    mean = std::move(std::get<1>(result));
+    rstd = std::move(std::get<2>(result));
     result0 = reshape_dim_outof(0, bdim_size.value(), result0); // [B0, C, B, *]
-    mean = std::get<1>(result);
     mean = reshape_dim_outof(0, bdim_size.value(), mean);       // [B0, C]
-    rstd = std::get<2>(result);
     rstd = reshape_dim_outof(0, bdim_size.value(), rstd);       // [B0, C]
   }
 
@@ -124,6 +125,7 @@ batch_norm_batch_rule(
 }
 
 template<typename F, F Func>
+static
 std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bias_batch_rule(
     const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
     const at::Tensor & input, std::optional<int64_t> input_bdim,
@@ -142,9 +144,9 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
     TORCH_INTERNAL_ASSERT(!mean_bdim);
     TORCH_INTERNAL_ASSERT(!rstd_bdim);
     const auto dummy_weight = at::ones(input.size(1), input.options());
-    const auto result = Func(
+    auto result = Func(
         grad_out, input, dummy_weight, running_mean_opt, running_var_opt, mean, rstd, training, eps, {true, false, false});
-    return std::make_tuple(std::get<0>(result), std::nullopt);
+    return {std::move(std::get<0>(result)), std::nullopt};
   }
 
   auto grad_out_ = moveBatchDimToFront(grad_out, grad_out_bdim);
@@ -196,6 +198,7 @@ std::tuple<at::Tensor, std::optional<int64_t>> batch_norm_backward_no_weight_bia
 }
 
 template<typename F, F Func>
+static
 std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
     const at::Tensor & grad_out,
     const at::Tensor & input,
@@ -270,15 +273,15 @@ std::tuple<at::Tensor,at::Tensor,at::Tensor> batch_norm_backward_plumbing(
         unwrapTensorAtLevel(grad_normalized_input.transpose(0, 1), cur_level); // [B0, B, C, *]
 
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
+    auto results = batch_norm_backward_no_weight_bias_batch_rule<F, Func>(
         grad_normalized_input_value, grad_normalized_input_bdim,
         input_value, input_bdim,
         running_mean_value, running_mean_bdim,
         running_var_value, running_var_bdim,
         save_mean_value, save_mean_bdim,
         save_rstd_value, save_rstd_bdim,
         training, eps);
-    grad_input = makeBatched(std::get<0>(results), std::get<1>(results), cur_level);
+    grad_input = makeBatched(std::move(std::get<0>(results)), std::get<1>(results), cur_level);
   }
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
@@ -312,16 +315,13 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
     const auto bdim_size = input_value.size(*input_bdim);
 
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
-    result0 = makeBatched(reshape_dim_outof(0, bdim_size, std::get<0>(result)), 0, cur_level);
-    mean = makeBatched(reshape_dim_outof(0, bdim_size, std::get<1>(result)), 0, cur_level);
-    rstd = makeBatched(reshape_dim_outof(0, bdim_size, std::get<2>(result)), 0, cur_level);
+    std::tie(result0, mean, rstd) = at::native_group_norm(input_, std::nullopt, std::nullopt, N * bdim_size, C, HxW, group, eps);
+    result0 = makeBatched(reshape_dim_outof(0, bdim_size, result0), 0, cur_level);
+    mean = makeBatched(reshape_dim_outof(0, bdim_size, mean), 0, cur_level);
+    rstd = makeBatched(reshape_dim_outof(0, bdim_size, rstd), 0, cur_level);
   } else {
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto result = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
-    result0 = std::get<0>(result);
-    mean = std::get<1>(result);
-    rstd = std::get<2>(result);
+    std::tie(result0, mean, rstd) = at::native_group_norm(input_value, std::nullopt, std::nullopt, N, C, HxW, group, eps);
   }
 
   if (weight.defined()) {
@@ -334,10 +334,10 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_plumbing(
     result0 = result0 + padded_bias;
   }
 
-  return std::make_tuple(result0, mean, rstd);
+  return std::make_tuple(std::move(result0), std::move(mean), std::move(rstd));
 }
 
-static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_weight_bias_batch_rule(
+static at::Tensor group_norm_backward_no_weight_bias_batch_rule(
     const at::Tensor & grad_out, std::optional<int64_t> grad_out_bdim,
     const at::Tensor & input, std::optional<int64_t> input_bdim,
     const at::Tensor & mean, std::optional<int64_t> mean_bdim,
@@ -359,15 +359,13 @@ static std::tuple<at::Tensor, std::optional<int64_t>> group_norm_backward_no_wei
   mean_ = reshape_dim_into(0, 0, mean_);  // [B0 * N, G]
   rstd_ = reshape_dim_into(0, 0, rstd_);  // [B0 * N, G]
 
-  const auto result = native_group_norm_backward(
+  auto result0 = std::get<0>(native_group_norm_backward(
       grad_out_.contiguous(),
       input_.contiguous(),
       mean_.contiguous(),
       rstd_.contiguous(),
-      std::nullopt, N * bdim_size, C, HxW, group, {true, false, false});
-  auto result0 = std::get<0>(result);
-  result0 = reshape_dim_outof(0, bdim_size, result0);
-  return std::make_tuple(result0, 0);
+      std::nullopt, N * bdim_size, C, HxW, group, {true, false, false}));
+  return reshape_dim_outof(0, bdim_size, result0);
 }
 
 static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
@@ -422,19 +420,19 @@ static std::tuple<Tensor,Tensor,Tensor> native_group_norm_backward_plumbing(
         unwrapTensorAtLevel(grad_normalized_input, cur_level);
 
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto res = group_norm_backward_no_weight_bias_batch_rule(
+    auto tensor = group_norm_backward_no_weight_bias_batch_rule(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        mean_value, mean_bdim,
        rstd_value, rstd_bdim,
        N, C, HxW, group
    );
-    grad_input = makeBatched(std::get<0>(res), std::get<1>(res), cur_level);
+    grad_input = makeBatched(std::move(tensor), 0, cur_level);
   }
   return std::make_tuple(grad_input, grad_weight, grad_bias);
 }
 
-C10_ALWAYS_INLINE bool has_same_shape(
+static bool has_same_shape(
     const Tensor& tensor, std::optional<int64_t> tensor_bdim,
     c10::SymIntArrayRef normalized_shape) {
   if (!tensor.defined()) {
@@ -457,7 +455,7 @@ C10_ALWAYS_INLINE bool has_same_shape(
   return true;
 }
 
-C10_ALWAYS_INLINE void check_same_shape(
+static C10_ALWAYS_INLINE void check_same_shape(
     const Tensor& tensor, std::optional<int64_t> tensor_bdim,
     c10::SymIntArrayRef normalized_shape, const std::string& name) {
   TORCH_CHECK(has_same_shape(tensor, tensor_bdim, normalized_shape),
@@ -469,7 +467,7 @@ C10_ALWAYS_INLINE void check_same_shape(
 }
 
 // Ugh, hard to deduplicate
-C10_ALWAYS_INLINE void _check_layer_norm_inputs(
+static C10_ALWAYS_INLINE void _check_layer_norm_inputs(
     SymIntArrayRef normalized_shape,
     const Tensor& weight, std::optional<int64_t> weight_bdim,
     const Tensor& bias, std::optional<int64_t> bias_bdim) {
@@ -493,11 +491,9 @@ native_layer_norm_batch_rule(
     double eps) {
   auto input_ = moveBatchDimToFront(input, input_bdim);
   if (!weight_bdim && !bias_bdim) {
-    const auto result = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
-    const auto mean = std::get<1>(result);
-    const auto rstd = std::get<2>(result);
+    auto [result0, mean, rstd] = at::native_layer_norm_symint(input_, normalized_shape, weight_opt, bias_opt, eps);
     const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
-    return std::make_tuple(std::get<0>(result), 0, mean, stats_bdim, rstd, stats_bdim);
+    return std::make_tuple(std::move(result0), 0, std::move(mean), stats_bdim, std::move(rstd), stats_bdim);
   }
 
   // See [Note: hacky wrapper removal for optional tensor]
@@ -509,9 +505,7 @@ native_layer_norm_batch_rule(
 
   const auto input_logical_rank = rankWithoutBatchDim(input, input_bdim);
   const auto result = at::native_layer_norm_symint(input_, normalized_shape, std::nullopt, std::nullopt, eps);
-  auto result0 = std::get<0>(result);
-  const auto mean = std::get<1>(result);
-  const auto rstd = std::get<2>(result);
+  auto [result0, mean, rstd] = result;
   const auto stats_bdim = compute_stat_bdim(input_bdim, mean);
 
   if (weight.defined()) {
@@ -638,7 +632,7 @@ static std::tuple<at::Tensor,at::Tensor,at::Tensor> native_layer_norm_backward_p
       unwrapTensorAtLevel(grad_normalized_input, cur_level);
 
     c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
-    const auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
+    auto results = native_layer_norm_backward_no_weight_bias_batch_rule(
        grad_normalized_input_value, grad_normalized_input_bdim,
        input_value, input_bdim,
        normalized_shape,
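
Most edits in this file drop the `const` from tuples returned by the norm kernels so their elements can be moved out (`std::move(std::get<N>(result))`) or unpacked with structured bindings or `std::tie`; a few file-local helpers also gain `static` for internal linkage. A standalone sketch of why the `const` matters (placeholder `kernel`, not the ATen ops):

```cpp
#include <string>
#include <tuple>
#include <utility>

std::tuple<std::string, std::string, std::string> kernel() {
  return {std::string(256, 'a'), std::string(256, 'b'), std::string(256, 'c')};
}

int main() {
  const auto frozen = kernel();
  // std::get on a const tuple yields const&, so this line must copy:
  std::string copied = std::get<1>(frozen);

  auto result = kernel();
  // On a non-const tuple, elements can be transformed or moved out instead:
  std::string out0 = std::get<0>(result).substr(0, 8);  // derived value from element 0
  std::string mean = std::move(std::get<1>(result));    // moved, no copy
  std::string rstd = std::move(std::get<2>(result));    // moved, no copy

  return (copied.size() == 256 && out0.size() == 8 &&
          mean.size() == 256 && rstd.size() == 256) ? 0 : 1;
}
```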

aten/src/ATen/functorch/BatchedTensorImpl.cpp

+4-4
@@ -171,18 +171,18 @@ void BatchedTensorImpl::shallow_copy_from(const c10::intrusive_ptr<TensorImpl>&
   TORCH_CHECK(false, "mutating directly with `.data` under vmap transform is not allowed.");
 }
 
-Tensor makeBatched(const Tensor& tensor, int64_t bdim, int64_t level) {
+Tensor makeBatched(Tensor tensor, int64_t bdim, int64_t level) {
   DispatchKeySet key_set = getKeysToPropagateToWrapper(tensor);
   auto* batched = maybeGetBatchedImpl(tensor);
   if (batched) {
     auto batched_level = batched->level();
     TORCH_INTERNAL_ASSERT(level > batched_level, " batched_level: ", batched_level, " level: ", level);
   }
-  return at::detail::make_tensor<BatchedTensorImpl>(key_set, tensor, bdim, level);
+  return at::detail::make_tensor<BatchedTensorImpl>(key_set, std::move(tensor), bdim, level);
 }
 
-Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level) {
-  return makeBatched(tensor, dim, level);
+Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level) {
+  return makeBatched(std::move(tensor), dim, level);
 }
 
 } // namespace at::functorch
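
With both `addBatchDim` and `makeBatched` now taking the tensor by value, an rvalue argument is relayed by move through every layer down to `make_tensor`, and an lvalue argument is copied (for `Tensor`, a refcount bump) exactly once at the outermost call. A rough sketch of that relay with a hypothetical refcounted `Handle` type:

```cpp
#include <memory>
#include <utility>
#include <vector>

struct Handle {                       // stand-in for a refcounted Tensor handle
  std::shared_ptr<std::vector<float>> impl;
};

struct Batched {
  Handle inner;
  int dim, level;
};

Batched makeBatched(Handle t, int dim, int level) {
  return Batched{std::move(t), dim, level};     // last hop: move into the wrapper
}

Batched addBatchDim(Handle t, int dim, int level) {
  return makeBatched(std::move(t), dim, level); // relay the value onward, no copy
}

int main() {
  Handle h{std::make_shared<std::vector<float>>(16, 0.f)};
  Batched kept = addBatchDim(h, 0, 1);              // one copy (refcount bump) here
  Batched sunk = addBatchDim(std::move(h), 0, 2);   // no copy anywhere on this path
  return (kept.inner.impl && sunk.inner.impl) ? 0 : 1;
}
```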

aten/src/ATen/functorch/BatchedTensorImpl.h

+2-2
@@ -144,10 +144,10 @@ inline std::bitset<kVmapNumLevels> createVmapLevelsBitset(int64_t level) {
 }
 
 // Use this to construct a BatchedTensor from a regular Tensor
-TORCH_API Tensor makeBatched(const Tensor& tensor, int64_t dim, int64_t level);
+TORCH_API Tensor makeBatched(Tensor tensor, int64_t dim, int64_t level);
 
 // Adds a batch dim to `tensor`, returning a BatchedTensor
-TORCH_API Tensor addBatchDim(const Tensor& tensor, int64_t dim, int64_t level);
+TORCH_API Tensor addBatchDim(Tensor tensor, int64_t dim, int64_t level);
 
 // Certain dispatch keys must be propagated to the BatchedTensor (or, in general,
 // any wrapper Tensor subclasses). This is because there are methods on Tensor
