Skip to content

Commit f001847

Browse files
committed
einsum tot x tot 'i,j;m,n * j,k;m,n -> i,jk;m,n' unit-test compares results
1 parent a5b253b commit f001847

File tree

1 file changed

+48
-3
lines changed

1 file changed

+48
-3
lines changed

tests/einsum.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -581,21 +581,24 @@ BOOST_AUTO_TEST_CASE(ij_mn_eq_ij_mn_times_ji_mn) {
581581
}
582582

583583
BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) {
584-
using dist_array_t = DistArray<Tensor<Tensor<double>>, DensePolicy>;
584+
using tot_type = DistArray<Tensor<Tensor<double>>, DensePolicy>;
585585
using matrix_il = TiledArray::detail::matrix_il<Tensor<double>>;
586586
auto& world = TiledArray::get_default_world();
587587

588588
auto random_tot = [](TA::Range const& rng) {
589589
TA::Range inner_rng{7,14};
590590
TA::Tensor<double> t{inner_rng};
591+
std::generate(t.begin(),t.end(),[]()->double{
592+
return TA::detail::MakeRandom<double>::generate_value();
593+
});
591594
TA::Tensor<TA::Tensor<double>> result{rng};
592595
for (auto& e: result) e = t;
593596
return result;
594597
};
595598

596599
auto random_tot_darr = [&random_tot](World& world,
597600
TiledRange const& tr) {
598-
dist_array_t result(world, tr);
601+
tot_type result(world, tr);
599602
for (auto it = result.begin(); it != result.end(); ++it) {
600603
auto tile =
601604
TA::get_default_world().taskq.add(random_tot, it.make_range());
@@ -609,9 +612,51 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) {
609612

610613
TiledRange rhs_trange{{0, 2, 4, 6}, {0, 2, 5}};
611614
auto rhs = random_tot_darr(world, rhs_trange);
612-
dist_array_t result;
615+
tot_type result;
613616
BOOST_REQUIRE_NO_THROW(
614617
result = einsum(lhs("i,j;m,n"), rhs("k,j;m,n"), "i,j,k;m,n"));
618+
619+
// i,j,k;m,n = i,j;m,n * k,j;m,n
620+
TiledRange ref_result_trange{lhs.trange().dim(0), lhs.trange().dim(1),
621+
rhs.trange().dim(0)};
622+
tot_type ref_result(world, ref_result_trange);
623+
624+
//
625+
// why cannot lhs and rhs be captured by ref?
626+
//
627+
auto make_tile = [lhs, rhs](TA::Range const& rng) {
628+
tot_type::value_type result_tile{rng};
629+
for (auto&& res_ix: result_tile.range()) {
630+
auto i = res_ix[0];
631+
auto j = res_ix[1];
632+
auto k = res_ix[2];
633+
using Ix2 = std::array<decltype(i), 2>;
634+
using Ix3 = std::array<decltype(i), 3>;
635+
636+
auto lhs_tile_ix = lhs.trange().element_to_tile(Ix2{i, j});
637+
auto lhs_tile = lhs.find(lhs_tile_ix).get(/* dowork = */ false);
638+
auto rhs_tile_ix = rhs.trange().element_to_tile(Ix2{k, j});
639+
auto rhs_tile = rhs.find(rhs_tile_ix).get(/* dowork = */ false);
640+
641+
auto& res_el =
642+
result_tile.at_ordinal(result_tile.range().ordinal(Ix3{i, j, k}));
643+
auto const& lhs_el =
644+
lhs_tile.at_ordinal(lhs_tile.range().ordinal(Ix2{i, j}));
645+
auto rhs_el = rhs_tile.at_ordinal(rhs_tile.range().ordinal(Ix2{k, j}));
646+
res_el = lhs_el.mult(rhs_el); // m,n * m,n -> m,n
647+
}
648+
return result_tile;
649+
};
650+
651+
using std::begin;
652+
using std::end;
653+
654+
for (auto it = begin(ref_result); it != end(ref_result); ++it) {
655+
auto tile = TA::get_default_world().taskq.add(make_tile, it.make_range());
656+
*it = tile;
657+
}
658+
bool are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
659+
BOOST_REQUIRE(are_equal);
615660
}
616661

617662
BOOST_AUTO_TEST_CASE(xxx) {

0 commit comments

Comments
 (0)