Skip to content

Commit e0adee8

Browse files
committed
[flang] Correct folding of CSHIFT and EOSHIFT for DIM>1
The algorithm was wrong for higher dimensions, and so were the expected test results. Rework. Differential Revision: https://reviews.llvm.org/D127018
1 parent 47ec8b5 commit e0adee8

File tree

3 files changed

+71
-49
lines changed

3 files changed

+71
-49
lines changed

flang/lib/Evaluate/fold-implementation.h

+67-45
Original file line numberDiff line numberDiff line change
@@ -613,26 +613,33 @@ template <typename T> Expr<T> Folder<T>::CSHIFT(FunctionRef<T> &&funcRef) {
613613
}
614614
if (ok) {
615615
std::vector<Scalar<T>> resultElements;
616-
ConstantSubscripts arrayAt{array->lbounds()};
617-
ConstantSubscript dimLB{arrayAt[zbDim]};
616+
ConstantSubscripts arrayLB{array->lbounds()};
617+
ConstantSubscripts arrayAt{arrayLB};
618+
ConstantSubscript &dimIndex{arrayAt[zbDim]};
619+
ConstantSubscript dimLB{dimIndex}; // initial value
618620
ConstantSubscript dimExtent{array->shape()[zbDim]};
619-
ConstantSubscripts shiftAt{shift->lbounds()};
620-
for (auto n{GetSize(array->shape())}; n > 0; n -= dimExtent) {
621-
ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
622-
ConstantSubscript zbDimIndex{shiftCount % dimExtent};
623-
if (zbDimIndex < 0) {
624-
zbDimIndex += dimExtent;
625-
}
626-
for (ConstantSubscript j{0}; j < dimExtent; ++j) {
627-
arrayAt[zbDim] = dimLB + zbDimIndex;
628-
resultElements.push_back(array->At(arrayAt));
629-
if (++zbDimIndex == dimExtent) {
630-
zbDimIndex = 0;
621+
ConstantSubscripts shiftLB{shift->lbounds()};
622+
for (auto n{GetSize(array->shape())}; n > 0; --n) {
623+
ConstantSubscript origDimIndex{dimIndex};
624+
ConstantSubscripts shiftAt;
625+
if (shift->Rank() > 0) {
626+
int k{0};
627+
for (int j{0}; j < rank; ++j) {
628+
if (j != zbDim) {
629+
shiftAt.emplace_back(shiftLB[k++] + arrayAt[j] - arrayLB[j]);
630+
}
631631
}
632632
}
633-
arrayAt[zbDim] = dimLB + std::max<ConstantSubscript>(dimExtent, 1) - 1;
633+
ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
634+
dimIndex = dimLB + ((dimIndex - dimLB + shiftCount) % dimExtent);
635+
if (dimIndex < dimLB) {
636+
dimIndex += dimExtent;
637+
} else if (dimIndex >= dimLB + dimExtent) {
638+
dimIndex -= dimExtent;
639+
}
640+
resultElements.push_back(array->At(arrayAt));
641+
dimIndex = origDimIndex;
634642
array->IncrementSubscripts(arrayAt);
635-
shift->IncrementSubscripts(shiftAt);
636643
}
637644
return Expr<T>{PackageConstant<T>(
638645
std::move(resultElements), *array, array->shape())};
@@ -714,42 +721,57 @@ template <typename T> Expr<T> Folder<T>::EOSHIFT(FunctionRef<T> &&funcRef) {
714721
}
715722
if (ok) {
716723
std::vector<Scalar<T>> resultElements;
717-
ConstantSubscripts arrayAt{array->lbounds()};
718-
ConstantSubscript dimLB{arrayAt[zbDim]};
724+
ConstantSubscripts arrayLB{array->lbounds()};
725+
ConstantSubscripts arrayAt{arrayLB};
726+
ConstantSubscript &dimIndex{arrayAt[zbDim]};
727+
ConstantSubscript dimLB{dimIndex}; // initial value
719728
ConstantSubscript dimExtent{array->shape()[zbDim]};
720-
ConstantSubscripts shiftAt{shift->lbounds()};
721-
ConstantSubscripts boundaryAt;
729+
ConstantSubscripts shiftLB{shift->lbounds()};
730+
ConstantSubscripts boundaryLB;
722731
if (boundary) {
723-
boundaryAt = boundary->lbounds();
732+
boundaryLB = boundary->lbounds();
724733
}
725-
for (auto n{GetSize(array->shape())}; n > 0; n -= dimExtent) {
734+
for (auto n{GetSize(array->shape())}; n > 0; --n) {
735+
ConstantSubscript origDimIndex{dimIndex};
736+
ConstantSubscripts shiftAt;
737+
if (shift->Rank() > 0) {
738+
int k{0};
739+
for (int j{0}; j < rank; ++j) {
740+
if (j != zbDim) {
741+
shiftAt.emplace_back(shiftLB[k++] + arrayAt[j] - arrayLB[j]);
742+
}
743+
}
744+
}
726745
ConstantSubscript shiftCount{shift->At(shiftAt).ToInt64()};
727-
for (ConstantSubscript j{0}; j < dimExtent; ++j) {
728-
ConstantSubscript zbAt{shiftCount + j};
729-
if (zbAt >= 0 && zbAt < dimExtent) {
730-
arrayAt[zbDim] = dimLB + zbAt;
731-
resultElements.push_back(array->At(arrayAt));
732-
} else if (boundary) {
733-
resultElements.push_back(boundary->At(boundaryAt));
734-
} else if constexpr (T::category == TypeCategory::Integer ||
735-
T::category == TypeCategory::Real ||
736-
T::category == TypeCategory::Complex ||
737-
T::category == TypeCategory::Logical) {
738-
resultElements.emplace_back();
739-
} else if constexpr (T::category == TypeCategory::Character) {
740-
auto len{static_cast<std::size_t>(array->LEN())};
741-
typename Scalar<T>::value_type space{' '};
742-
resultElements.emplace_back(len, space);
743-
} else {
744-
DIE("no derived type boundary");
746+
dimIndex += shiftCount;
747+
if (dimIndex >= dimLB && dimIndex < dimLB + dimExtent) {
748+
resultElements.push_back(array->At(arrayAt));
749+
} else if (boundary) {
750+
ConstantSubscripts boundaryAt;
751+
if (boundary->Rank() > 0) {
752+
for (int j{0}; j < rank; ++j) {
753+
int k{0};
754+
if (j != zbDim) {
755+
boundaryAt.emplace_back(
756+
boundaryLB[k++] + arrayAt[j] - arrayLB[j]);
757+
}
758+
}
745759
}
760+
resultElements.push_back(boundary->At(boundaryAt));
761+
} else if constexpr (T::category == TypeCategory::Integer ||
762+
T::category == TypeCategory::Real ||
763+
T::category == TypeCategory::Complex ||
764+
T::category == TypeCategory::Logical) {
765+
resultElements.emplace_back();
766+
} else if constexpr (T::category == TypeCategory::Character) {
767+
auto len{static_cast<std::size_t>(array->LEN())};
768+
typename Scalar<T>::value_type space{' '};
769+
resultElements.emplace_back(len, space);
770+
} else {
771+
DIE("no derived type boundary");
746772
}
747-
arrayAt[zbDim] = dimLB + std::max<ConstantSubscript>(dimExtent, 1) - 1;
773+
dimIndex = origDimIndex;
748774
array->IncrementSubscripts(arrayAt);
749-
shift->IncrementSubscripts(shiftAt);
750-
if (boundary) {
751-
boundary->IncrementSubscripts(boundaryAt);
752-
}
753775
}
754776
return Expr<T>{PackageConstant<T>(
755777
std::move(resultElements), *array, array->shape())};

flang/test/Evaluate/folding23.f90

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ module m
99
logical, parameter :: test_eoshift_3 = all(eoshift([1., 2., 3.], 1) == [2., 3., 0.])
1010
logical, parameter :: test_eoshift_4 = all(eoshift(['ab', 'cd', 'ef'], -1, 'x') == ['x ', 'ab', 'cd'])
1111
logical, parameter :: test_eoshift_5 = all([eoshift(arr, 1, dim=1)] == [2, 0, 4, 0, 6, 0])
12-
logical, parameter :: test_eoshift_6 = all([eoshift(arr, 1, dim=2)] == [3, 5, 0, 4, 6, 0])
12+
logical, parameter :: test_eoshift_6 = all([eoshift(arr, 1, dim=2)] == [3, 4, 5, 6, 0, 0])
1313
logical, parameter :: test_eoshift_7 = all([eoshift(arr, [1, -1, 0])] == [2, 0, 0, 3, 5, 6])
14-
logical, parameter :: test_eoshift_8 = all([eoshift(arr, [1, -1], dim=2)] == [3, 5, 0, 0, 2, 4])
14+
logical, parameter :: test_eoshift_8 = all([eoshift(arr, [1, -1], dim=2)] == [3, 0, 5, 2, 0, 4])
1515
end module

flang/test/Evaluate/folding27.f90

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ module m
99
logical, parameter :: test_cshift_3 = all(cshift([1, 2, 3], 4) == [2, 3, 1])
1010
logical, parameter :: test_cshift_4 = all(cshift([1, 2, 3], -1) == [3, 1, 2])
1111
logical, parameter :: test_cshift_5 = all([cshift(arr, 1, dim=1)] == [2, 1, 4, 3, 6, 5])
12-
logical, parameter :: test_cshift_6 = all([cshift(arr, 1, dim=2)] == [3, 5, 1, 4, 6, 2])
12+
logical, parameter :: test_cshift_6 = all([cshift(arr, 1, dim=2)] == [3, 4, 5, 6, 1, 2])
1313
logical, parameter :: test_cshift_7 = all([cshift(arr, [1, 2, 3])] == [2, 1, 3, 4, 6, 5])
14-
logical, parameter :: test_cshift_8 = all([cshift(arr, [1, 2], dim=2)] == [3, 5, 1, 6, 2, 4])
14+
logical, parameter :: test_cshift_8 = all([cshift(arr, [1, 2], dim=2)] == [3, 6, 5, 2, 1, 4])
1515
end module

0 commit comments

Comments
 (0)