Skip to content

Commit a5b253b

Browse files
committed
Implement ToT x T (and reverse) generalized contraction.
1 parent a08026c commit a5b253b

File tree

2 files changed

+53
-45
lines changed

2 files changed

+53
-45
lines changed

src/TiledArray/einsum/tiledarray.h

Lines changed: 43 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -181,50 +181,51 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
181181

182182
using Index = Einsum::Index<size_t>;
183183

184-
if constexpr (std::tuple_size<decltype(cs)>::value > 1) {
185-
TA_ASSERT(e);
186-
} else if (!e) { // hadamard reduction
187-
auto &[A, B] = AB;
188-
TiledRange trange(range_map[i]);
189-
RangeProduct tiles;
190-
for (auto idx : i) {
191-
tiles *= Range(range_map[idx].tiles_range());
192-
}
193-
auto pa = A.permutation;
194-
auto pb = B.permutation;
195-
for (Index h : H.tiles) {
196-
if (!C.array.is_local(h)) continue;
197-
size_t batch = 1;
198-
for (size_t i = 0; i < h.size(); ++i) {
199-
batch *= H.batch[i].at(h[i]);
184+
if constexpr (std::tuple_size<decltype(cs)>::value > 1) TA_ASSERT(e);
185+
if constexpr (AreArraySame<ArrayA, ArrayB>) {
186+
if (!e) { // hadamard reduction
187+
auto &[A, B] = AB;
188+
TiledRange trange(range_map[i]);
189+
RangeProduct tiles;
190+
for (auto idx : i) {
191+
tiles *= Range(range_map[idx].tiles_range());
200192
}
201-
ResultTensor tile(TiledArray::Range{batch},
202-
typename ResultTensor::value_type{});
203-
for (Index i : tiles) {
204-
// skip this unless both input tiles exist
205-
const auto pahi_inv = apply_inverse(pa, h + i);
206-
const auto pbhi_inv = apply_inverse(pb, h + i);
207-
if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
208-
209-
auto ai = A.array.find(pahi_inv).get();
210-
auto bi = B.array.find(pbhi_inv).get();
211-
if (pa) ai = ai.permute(pa);
212-
if (pb) bi = bi.permute(pb);
213-
auto shape = trange.tile(i);
214-
ai = ai.reshape(shape, batch);
215-
bi = bi.reshape(shape, batch);
216-
for (size_t k = 0; k < batch; ++k) {
217-
auto hk = ai.batch(k).dot(bi.batch(k));
218-
tile({k}) += hk;
193+
auto pa = A.permutation;
194+
auto pb = B.permutation;
195+
for (Index h : H.tiles) {
196+
if (!C.array.is_local(h)) continue;
197+
size_t batch = 1;
198+
for (size_t i = 0; i < h.size(); ++i) {
199+
batch *= H.batch[i].at(h[i]);
219200
}
201+
ResultTensor tile(TiledArray::Range{batch},
202+
typename ResultTensor::value_type{});
203+
for (Index i : tiles) {
204+
// skip this unless both input tiles exist
205+
const auto pahi_inv = apply_inverse(pa, h + i);
206+
const auto pbhi_inv = apply_inverse(pb, h + i);
207+
if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
208+
209+
auto ai = A.array.find(pahi_inv).get();
210+
auto bi = B.array.find(pbhi_inv).get();
211+
if (pa) ai = ai.permute(pa);
212+
if (pb) bi = bi.permute(pb);
213+
auto shape = trange.tile(i);
214+
ai = ai.reshape(shape, batch);
215+
bi = bi.reshape(shape, batch);
216+
for (size_t k = 0; k < batch; ++k) {
217+
auto hk = ai.batch(k).dot(bi.batch(k));
218+
tile({k}) += hk;
219+
}
220+
}
221+
auto pc = C.permutation;
222+
auto shape = apply_inverse(pc, C.array.trange().tile(h));
223+
tile = tile.reshape(shape);
224+
if (pc) tile = tile.permute(pc);
225+
C.array.set(h, tile);
220226
}
221-
auto pc = C.permutation;
222-
auto shape = apply_inverse(pc, C.array.trange().tile(h));
223-
tile = tile.reshape(shape);
224-
if (pc) tile = tile.permute(pc);
225-
C.array.set(h, tile);
227+
return C.array;
226228
}
227-
return C.array;
228229
}
229230

230231
// generalized contraction
@@ -468,7 +469,8 @@ auto einsum(expressions::TsrExpr<T> A, expressions::TsrExpr<U> B,
468469
const std::string &cs, World &world = get_default_world()) {
469470
using ECT = expressions::TsrExpr<const T>;
470471
using ECU = expressions::TsrExpr<const U>;
471-
return Einsum::einsum(ECT(A), ECU(B), Einsum::idx<T>(cs), world);
472+
using ResultExprT = std::conditional_t<Einsum::IsArrayToT<T>, T, U>;
473+
return Einsum::einsum(ECT(A), ECU(B), Einsum::idx<ResultExprT>(cs), world);
472474
}
473475

474476
template <typename T, typename U, typename V>

tests/einsum.cpp

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -845,7 +845,7 @@ BOOST_AUTO_TEST_CASE(ilkj_nm_eq_ij_mn_times_kl) {
845845
}
846846
}
847847

848-
BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) {
848+
BOOST_AUTO_TEST_CASE(ijk_mn_eq_ij_mn_times_jk) {
849849
using t_type = DistArray<Tensor<double>, SparsePolicy>;
850850
using tot_type = DistArray<Tensor<Tensor<double>>, SparsePolicy>;
851851
using matrix_il = TiledArray::detail::matrix_il<Tensor<double>>;
@@ -877,7 +877,6 @@ BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) {
877877
t_type rhs(world, rhs_trange);
878878
rhs.fill_random();
879879

880-
// TODO compute ref_result
881880
// i,j;m,n * j,k => i,j,k;m,n
882881
TiledRange ref_result_trange{lhs_trange.dim(0), rhs_trange.dim(0),
883882
rhs_trange.dim(1)};
@@ -928,10 +927,17 @@ BOOST_AUTO_TEST_CASE(ikj_mn_eq_ij_mn_times_jk) {
928927
// - general product w.r.t. outer indices
929928
// - involves ToT * T
930929
// tot_type result;
931-
// BOOST_REQUIRE_NO_THROW(result("k,i,j;n,m") = lhs("i,j;m,n") * rhs("j,k"));
930+
// BOOST_REQUIRE_NO_THROW(result("i,j,k;m,n") = lhs("i,j;m,n") * rhs("j,k"));
932931

933932
// will try to make this work
934-
// tot_type out = einsum(lhs("i,j;m,n"), rhs("j,k"), "k,i,j;n,m");
933+
tot_type result = einsum(lhs("i,j;m,n"), rhs("j,k"), "i,j,k;m,n");
934+
bool are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
935+
BOOST_REQUIRE(are_equal);
936+
{
937+
result = einsum(rhs("j,k"), lhs("i,j;m,n"), "i,j,k;m,n");
938+
are_equal = ToTArrayFixture::are_equal<false>(result, ref_result);
939+
BOOST_REQUIRE(are_equal);
940+
}
935941
}
936942

937943
BOOST_AUTO_TEST_CASE(ij_mn_eq_ji_mn_times_ij) {

0 commit comments

Comments
 (0)