Skip to content

Commit 80790ec

Browse files
janeyx99pytorchmergebot
authored andcommitted
[einsum] Call view instead of sum to remediate MPS regression (pytorch#87135)
Fixes pytorch#87010. It turns out that squeeze is much faster than sum, and view is faster than squeeze, so we should default to that whenever possible. Benchmarking results show that, on MPS, we would be going from the following code taking **29.89ms instead of the current 1466ms, almost a 50x speedup**. ``` q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float) k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float) torch.einsum('b i d, b j d -> b i j', q, k).max().item() ``` And a regular einsum will now take **.506ms instead of 2.76ms.** ``` q = torch.rand(16, 4096, 40, device='mps', dtype=torch.float) k = torch.rand(16, 4096, 40, device='mps', dtype=torch.float) torch.einsum('b i d, b j d -> b i j', q, k) ``` Special thanks to @soulitzer for helping me experiment + figure out how to squash the remaining 5x regression due to squeeze being slower than view!! Pull Request resolved: pytorch#87135 Approved by: https://github.com/soulitzer, https://github.com/malfet, https://github.com/albanD
1 parent c4a03e4 commit 80790ec

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

aten/src/ATen/native/Linear.cpp

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -545,9 +545,17 @@ Tensor einsum(c10::string_view equation, TensorList operands, at::OptionalIntArr
545545

546546
// Sum out contraction dims
547547
if (perm_index - out_num_dim > 0) {
548-
std::vector<int64_t> sum_dims(perm_index - out_num_dim);
549-
std::iota(sum_dims.begin(), sum_dims.end(), out_num_dim);
550-
ops[0] = ops[0].sum(sum_dims);
548+
if (num_ops > 1) {
549+
auto sizes = ops[0].sym_sizes().vec();
550+
for (auto dim = perm_index - 1; dim >= out_num_dim; --dim) {
551+
sizes.erase(sizes.begin() + dim);
552+
}
553+
return ops[0].view_symint(sizes);
554+
} else {
555+
std::vector<int64_t> sum_dims(perm_index - out_num_dim);
556+
std::iota(sum_dims.begin(), sum_dims.end(), out_num_dim);
557+
return ops[0].sum(sum_dims);
558+
}
551559
}
552560

553561
return ops[0];

0 commit comments

Comments
 (0)