Skip to content

Commit c4a03e4

Browse files
janeyx99pytorchmergebot
authored andcommitted
[einsum] keep the promise that we contract left to right (pytorch#87199)
We promise that if path is not defined, we would go left to right. The previous code did not keep that promise as we push'd combined ops to the back of the list. For most use cases this is fine (einsum with 3 or fewer inputs), but we should do what we say. Test plan: Added a print statement to print the sizes of ops we're contracting to see if the order is fixed. Code run: ``` import torch a = torch.rand(1) b = torch.rand(2) c = torch.rand(3) d = torch.rand(4) torch.einsum('a,b,c,d->abcd', a,b,c,d) ``` BEFORE--it does a+b, then c+d, then a+b+c+d, which...is right, but it's not the order specified by the user. ``` /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 1, 1]and b: [1, 2, 1, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 3, 1]and b: [1, 1, 1, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 1, 1]and b: [1, 1, 3, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] ``` WITH THIS CHANGE--it actually goes left to right: a+b, a+b+c, a+b+c+d ``` /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 1, 1, 1]and b: [1, 2, 1, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 1, 1]and b: [1, 1, 3, 1] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] /Users/janeyx/pytorch/torch/functional.py:378: UserWarning: Contracting a: [1, 2, 3, 1]and b: [1, 1, 1, 4] (Triggered internally at /Users/janeyx/pytorch/aten/src/ATen/native/Linear.cpp:507.) return _VF.einsum(equation, operands) # type: ignore[attr-defined] ``` Pull Request resolved: pytorch#87199 Approved by: https://github.com/soulitzer
1 parent d06d569 commit c4a03e4

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

aten/src/ATen/native/Linear.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ Tensor einsum(c10::string_view equation, TensorList operands, at::OptionalIntArr
405405
std::vector<SymInt> label_size(TOTAL_LABELS, 1);
406406
std::vector<SymInt> ell_sizes(ell_num_dim, 1);
407407
std::vector<uint64_t> dim_counts(perm_index, 0);
408-
std::vector<Tensor> ops;
408+
std::deque<Tensor> ops;
409409
for (const auto i : irange(num_ops)) {
410410
auto op = operands[i];
411411
std::vector<int64_t> permutation(perm_index, -1);
@@ -536,7 +536,11 @@ Tensor einsum(c10::string_view equation, TensorList operands, at::OptionalIntArr
536536
b = b.sum(b_dims_to_sum, true);
537537
}
538538

539-
ops.emplace_back(sumproduct_pair(a, b, sum_dims, true));
539+
if (path.has_value()) {
540+
ops.emplace_back(sumproduct_pair(a, b, sum_dims, true));
541+
} else {
542+
ops.emplace_front(sumproduct_pair(a, b, sum_dims, true));
543+
}
540544
}
541545

542546
// Sum out contraction dims

0 commit comments

Comments
 (0)