@@ -581,21 +581,24 @@ BOOST_AUTO_TEST_CASE(ij_mn_eq_ij_mn_times_ji_mn) {
581
581
}
582
582
583
583
BOOST_AUTO_TEST_CASE (ijk_mn_eq_ij_mn_times_kj_mn) {
584
- using dist_array_t = DistArray<Tensor<Tensor<double >>, DensePolicy>;
584
+ using tot_type = DistArray<Tensor<Tensor<double >>, DensePolicy>;
585
585
using matrix_il = TiledArray::detail::matrix_il<Tensor<double >>;
586
586
auto & world = TiledArray::get_default_world ();
587
587
588
588
auto random_tot = [](TA::Range const & rng) {
589
589
TA::Range inner_rng{7 ,14 };
590
590
TA::Tensor<double > t{inner_rng};
591
+ std::generate (t.begin (),t.end (),[]()->double {
592
+ return TA::detail::MakeRandom<double >::generate_value ();
593
+ });
591
594
TA::Tensor<TA::Tensor<double >> result{rng};
592
595
for (auto & e: result) e = t;
593
596
return result;
594
597
};
595
598
596
599
auto random_tot_darr = [&random_tot](World& world,
597
600
TiledRange const & tr) {
598
- dist_array_t result (world, tr);
601
+ tot_type result (world, tr);
599
602
for (auto it = result.begin (); it != result.end (); ++it) {
600
603
auto tile =
601
604
TA::get_default_world ().taskq .add (random_tot, it.make_range ());
@@ -609,9 +612,51 @@ BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_kj_mn) {
609
612
610
613
TiledRange rhs_trange{{0 , 2 , 4 , 6 }, {0 , 2 , 5 }};
611
614
auto rhs = random_tot_darr (world, rhs_trange);
612
- dist_array_t result;
615
+ tot_type result;
613
616
BOOST_REQUIRE_NO_THROW (
614
617
result = einsum (lhs (" i,j;m,n" ), rhs (" k,j;m,n" ), " i,j,k;m,n" ));
618
+
619
+ // i,j,k;m,n = i,j;m,n * k,j;m,n
620
+ TiledRange ref_result_trange{lhs.trange ().dim (0 ), lhs.trange ().dim (1 ),
621
+ rhs.trange ().dim (0 )};
622
+ tot_type ref_result (world, ref_result_trange);
623
+
624
+ //
625
+ // why cannot lhs and rhs be captured by ref?
626
+ //
627
+ auto make_tile = [lhs, rhs](TA::Range const & rng) {
628
+ tot_type::value_type result_tile{rng};
629
+ for (auto && res_ix: result_tile.range ()) {
630
+ auto i = res_ix[0 ];
631
+ auto j = res_ix[1 ];
632
+ auto k = res_ix[2 ];
633
+ using Ix2 = std::array<decltype (i), 2 >;
634
+ using Ix3 = std::array<decltype (i), 3 >;
635
+
636
+ auto lhs_tile_ix = lhs.trange ().element_to_tile (Ix2{i, j});
637
+ auto lhs_tile = lhs.find (lhs_tile_ix).get (/* dowork = */ false );
638
+ auto rhs_tile_ix = rhs.trange ().element_to_tile (Ix2{k, j});
639
+ auto rhs_tile = rhs.find (rhs_tile_ix).get (/* dowork = */ false );
640
+
641
+ auto & res_el =
642
+ result_tile.at_ordinal (result_tile.range ().ordinal (Ix3{i, j, k}));
643
+ auto const & lhs_el =
644
+ lhs_tile.at_ordinal (lhs_tile.range ().ordinal (Ix2{i, j}));
645
+ auto rhs_el = rhs_tile.at_ordinal (rhs_tile.range ().ordinal (Ix2{k, j}));
646
+ res_el = lhs_el.mult (rhs_el); // m,n * m,n -> m,n
647
+ }
648
+ return result_tile;
649
+ };
650
+
651
+ using std::begin;
652
+ using std::end;
653
+
654
+ for (auto it = begin (ref_result); it != end (ref_result); ++it) {
655
+ auto tile = TA::get_default_world ().taskq .add (make_tile, it.make_range ());
656
+ *it = tile;
657
+ }
658
+ bool are_equal = ToTArrayFixture::are_equal<false >(result, ref_result);
659
+ BOOST_REQUIRE (are_equal);
615
660
}
616
661
617
662
BOOST_AUTO_TEST_CASE (xxx) {
0 commit comments