@@ -181,50 +181,51 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
 
   using Index = Einsum::Index<size_t>;
 
-  if constexpr (std::tuple_size<decltype(cs)>::value > 1) {
-    TA_ASSERT(e);
-  } else if (!e) {  // hadamard reduction
-    auto &[A, B] = AB;
-    TiledRange trange(range_map[i]);
-    RangeProduct tiles;
-    for (auto idx : i) {
-      tiles *= Range(range_map[idx].tiles_range());
-    }
-    auto pa = A.permutation;
-    auto pb = B.permutation;
-    for (Index h : H.tiles) {
-      if (!C.array.is_local(h)) continue;
-      size_t batch = 1;
-      for (size_t i = 0; i < h.size(); ++i) {
-        batch *= H.batch[i].at(h[i]);
+  if constexpr (std::tuple_size<decltype(cs)>::value > 1) TA_ASSERT(e);
+  if constexpr (AreArraySame<ArrayA, ArrayB>) {
+    if (!e) {  // hadamard reduction
+      auto &[A, B] = AB;
+      TiledRange trange(range_map[i]);
+      RangeProduct tiles;
+      for (auto idx : i) {
+        tiles *= Range(range_map[idx].tiles_range());
       }
-      ResultTensor tile(TiledArray::Range{batch},
-                        typename ResultTensor::value_type{});
-      for (Index i : tiles) {
-        // skip this unless both input tiles exist
-        const auto pahi_inv = apply_inverse(pa, h + i);
-        const auto pbhi_inv = apply_inverse(pb, h + i);
-        if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
-
-        auto ai = A.array.find(pahi_inv).get();
-        auto bi = B.array.find(pbhi_inv).get();
-        if (pa) ai = ai.permute(pa);
-        if (pb) bi = bi.permute(pb);
-        auto shape = trange.tile(i);
-        ai = ai.reshape(shape, batch);
-        bi = bi.reshape(shape, batch);
-        for (size_t k = 0; k < batch; ++k) {
-          auto hk = ai.batch(k).dot(bi.batch(k));
-          tile({k}) += hk;
+      auto pa = A.permutation;
+      auto pb = B.permutation;
+      for (Index h : H.tiles) {
+        if (!C.array.is_local(h)) continue;
+        size_t batch = 1;
+        for (size_t i = 0; i < h.size(); ++i) {
+          batch *= H.batch[i].at(h[i]);
         }
+        ResultTensor tile(TiledArray::Range{batch},
+                          typename ResultTensor::value_type{});
+        for (Index i : tiles) {
+          // skip this unless both input tiles exist
+          const auto pahi_inv = apply_inverse(pa, h + i);
+          const auto pbhi_inv = apply_inverse(pb, h + i);
+          if (A.array.is_zero(pahi_inv) || B.array.is_zero(pbhi_inv)) continue;
+
+          auto ai = A.array.find(pahi_inv).get();
+          auto bi = B.array.find(pbhi_inv).get();
+          if (pa) ai = ai.permute(pa);
+          if (pb) bi = bi.permute(pb);
+          auto shape = trange.tile(i);
+          ai = ai.reshape(shape, batch);
+          bi = bi.reshape(shape, batch);
+          for (size_t k = 0; k < batch; ++k) {
+            auto hk = ai.batch(k).dot(bi.batch(k));
+            tile({k}) += hk;
+          }
+        }
+        auto pc = C.permutation;
+        auto shape = apply_inverse(pc, C.array.trange().tile(h));
+        tile = tile.reshape(shape);
+        if (pc) tile = tile.permute(pc);
+        C.array.set(h, tile);
       }
-      auto pc = C.permutation;
-      auto shape = apply_inverse(pc, C.array.trange().tile(h));
-      tile = tile.reshape(shape);
-      if (pc) tile = tile.permute(pc);
-      C.array.set(h, tile);
+      return C.array;
     }
-    return C.array;
   }
 
   // generalized contraction
@@ -468,7 +469,8 @@ auto einsum(expressions::TsrExpr<T> A, expressions::TsrExpr<U> B,
             const std::string &cs, World &world = get_default_world()) {
   using ECT = expressions::TsrExpr<const T>;
   using ECU = expressions::TsrExpr<const U>;
-  return Einsum::einsum(ECT(A), ECU(B), Einsum::idx<T>(cs), world);
+  using ResultExprT = std::conditional_t<Einsum::IsArrayToT<T>, T, U>;
+  return Einsum::einsum(ECT(A), ECU(B), Einsum::idx<ResultExprT>(cs), world);
 }
 
 template <typename T, typename U, typename V>
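
The second hunk changes which operand's type is used to parse the result annotation `cs`: `ResultExprT` picks `T` when `T` is an array-of-tensors (`Einsum::IsArrayToT<T>`) and falls back to `U` otherwise, so a mixed call where only one operand is a tensor-of-tensors array parses the nested result indices correctly regardless of argument order. The following is a minimal usage sketch, not part of the patch; the type aliases, index labels, and the `contract` helper are illustrative assumptions, and it assumes TiledArray's ';'-separated annotation for inner (nested) indices.

// Illustrative sketch only: mixed tensor-of-tensors (ToT) x plain-array einsum.
#include <tiledarray.h>

using TensorD = TA::Tensor<double>;
using ToTArray = TA::DistArray<TA::Tensor<TensorD>, TA::SparsePolicy>;  // array of tensors
using PlainArray = TA::DistArray<TensorD, TA::SparsePolicy>;            // ordinary array

// With the patched overload, the result annotation "i,k;m,n" is parsed with the
// ToT operand's index type even though the ToT array is only one of the inputs.
auto contract(ToTArray& a, PlainArray& b) {
  return TA::einsum(a("i,j;m,n"), b("j,k"), "i,k;m,n");
}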