Skip to content

Commit f4bba8e

Browse files
committed
Make shape comparison flags more explicit.
1 parent f001847 commit f4bba8e

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

tests/einsum.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) {
655655
auto tile = TA::get_default_world().taskq.add(make_tile, it.make_range());
656656
*it = tile;
657657
}
658-
bool are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
658+
bool are_equal = ToTArrayFixture::are_equal<ShapeComp::False>(result, ref_result);
659659
BOOST_REQUIRE(are_equal);
660660
}
661661

@@ -879,13 +879,13 @@ BOOST_AUTO_TEST_CASE(ilkj_nm_eq_ij_mn_times_kl) {
879879
tot_type result;
880880
BOOST_REQUIRE_NO_THROW(result("i,l,k,j;n,m") = lhs("i,j;m,n") * rhs("k,l"));
881881

882-
const bool are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
882+
const bool are_equal = ToTArrayFixture::are_equal<ShapeComp::False>(result, ref_result);
883883
BOOST_CHECK(are_equal);
884884

885885
{ // reverse the order
886886
tot_type result;
887887
BOOST_REQUIRE_NO_THROW(result("i,l,k,j;n,m") = rhs("k,l") * lhs("i,j;m,n"));
888-
const bool are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
888+
const bool are_equal = ToTArrayFixture::are_equal<ShapeComp::False>(result, ref_result);
889889
BOOST_CHECK(are_equal);
890890
}
891891
}
@@ -976,11 +976,11 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_jk) {
976976

977977
// will try to make this work
978978
tot_type result = einsum(lhs("i,j;m,n"), rhs("j,k"), "i,j,k;m,n");
979-
bool are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
979+
bool are_equal = ToTArrayFixture::are_equal<ShapeComp::False>(result, ref_result);
980980
BOOST_REQUIRE(are_equal);
981981
{
982982
result = einsum(rhs("j,k"), lhs("i,j;m,n"), "i,j,k;m,n");
983-
are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
983+
are_equal = ToTArrayFixture::are_equal<ShapeComp::False>(result, ref_result);
984984
BOOST_REQUIRE(are_equal);
985985
}
986986
}
@@ -1073,7 +1073,7 @@ BOOST_AUTO_TEST_CASE(ij_mn_eq_ji_mn_times_ij) {
10731073
tot_type result;
10741074
BOOST_REQUIRE_NO_THROW(result("i,j;m,n") = lhs("j,i;m,n") * rhs("i,j"));
10751075

1076-
const bool are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
1076+
const bool are_equal = ToTArrayFixture::are_equal<ShapeComp::False>(result, ref_result);
10771077
BOOST_CHECK(are_equal);
10781078
}
10791079

tests/tot_array_fixture.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ using input_archive_type = madness::archive::BinaryFstreamInputArchive;
8888
// Type of an output archive
8989
using output_archive_type = madness::archive::BinaryFstreamOutputArchive;
9090

91+
enum class ShapeComp {
92+
True,
93+
False
94+
};
95+
96+
9197
/*
9298
*
9399
* When generating arrays containing tensors of tensors (ToT) we adopt simple
@@ -238,7 +244,7 @@ struct ToTArrayFixture {
238244
*
239245
* TODO: pmap comparisons
240246
*/
241-
template <bool ShapeCmp = true, typename LHSTileType, typename LHSPolicy,
247+
template <ShapeComp ShapeCompFlag = ShapeComp::True, typename LHSTileType, typename LHSPolicy,
242248
typename RHSTileType, typename RHSPolicy>
243249
static bool are_equal(const DistArray<LHSTileType, LHSPolicy>& lhs,
244250
const DistArray<RHSTileType, RHSPolicy>& rhs) {
@@ -254,7 +260,7 @@ struct ToTArrayFixture {
254260
if (&lhs.world() != &rhs.world()) return false;
255261

256262
// Same shape?
257-
if constexpr (ShapeCmp)
263+
if constexpr (ShapeCompFlag == ShapeComp::True)
258264
if (lhs.shape() != rhs.shape()) return false;
259265

260266
// Same pmap?

0 commit comments

Comments
 (0)