21
21
22
22
#include " array.h"
23
23
24
+ #include < iostream>
25
+
24
26
namespace nda {
25
27
26
28
namespace internal {
@@ -73,7 +75,7 @@ struct ein_op_base {
73
75
template <class Op , size_t ... Is>
74
76
struct ein_op : public ein_op_base <ein_op<Op, Is...>> {
75
77
Op op;
76
- ein_op (const Op& op) : op(op ) {}
78
+ ein_op (Op op) : op(std::move(op) ) {}
77
79
78
80
// The largest dimension used by this operand.
79
81
static constexpr index_t MaxIndex = sizeof ...(Is) == 0 ? -1 : variadic_max(Is...);
@@ -108,7 +110,7 @@ struct ein_op : public ein_op_base<ein_op<Op, Is...>> {
108
110
// A unary operation on an Einstein operand.
// `Op` is the wrapped operand type. `Derived` is the concrete operation type
// (e.g. ein_cast_op passes itself) and is forwarded to ein_op_base.
109
111
template <class Op , class Derived >
110
112
struct ein_unary_op : public ein_op_base <Derived> {
// NOTE(review): this change replaces the by-value member (the '-' line) with
// a const reference (the '+' line) bound to the constructor argument. The node
// no longer owns its operand, so it is only valid while that operand is alive.
// If an expression built from temporaries is stored beyond the end of its
// full-expression, this reference dangles -- confirm every caller consumes the
// expression immediately.
111
- Op op;
113
+ const Op& op;
112
114
ein_unary_op (const Op& op) : op(op) {}
113
115
// The largest index used by this operation is that of its single operand.
static constexpr index_t MaxIndex = Op::MaxIndex;
114
116
};
@@ -143,8 +145,8 @@ struct ein_cast_op : public ein_unary_op<Op, ein_cast_op<Type, Op>> {
143
145
// A binary operation of two operands.
// `OpA`/`OpB` are the operand types; `Derived` is forwarded to ein_op_base.
144
146
template <class OpA , class OpB , class Derived >
145
147
struct ein_binary_op : public ein_op_base <Derived> {
// NOTE(review): this change replaces the two by-value members (the '-' lines)
// with const references (the '+' lines) bound to the constructor arguments.
// Nested sub-expressions are usually temporaries, so a node built this way is
// only valid until the end of the full-expression that created them. Storing
// the expression for later use would leave these references dangling --
// confirm this is never done.
146
- OpA op_a;
147
- OpB op_b;
148
+ const OpA& op_a;
149
+ const OpB& op_b;
148
150
ein_binary_op (const OpA& a, const OpB& b) : op_a(a), op_b(b) {}
149
151
// The largest index used by this operation is the larger of its operands'.
static constexpr index_t MaxIndex = std::max(OpA::MaxIndex, OpB::MaxIndex);
150
152
};
@@ -223,20 +225,6 @@ using nda::cast;
223
225
using nda::max;
224
226
using nda::min;
225
227
226
// NOTE(review): every line of the three overloads below is marked '-', i.e.
// they are deleted by this change. A tuple-based with_stride is added later in
// the file, after dims_of.
- // Helper to reinterpret a dim/shape with a new stride.
227
// Single dim: keeps the min and extent of `d`, with NewStride as the
// compile-time stride parameter of the returned dim type.
- template <index_t NewStride, index_t Min, index_t Extent, index_t Stride>
228
- auto with_stride (const dim<Min, Extent, Stride>& d) {
229
- return dim<Min, Extent, NewStride>(d.min (), d.extent ());
230
- }
231
// Tuple of dims: applies the single-dim overload to each element selected by Is.
- template <index_t NewStride, class ... Dims, size_t ... Is>
232
- auto with_stride (const std::tuple<Dims...>& dims, index_sequence<Is...>) {
233
- return std::make_tuple (with_stride<NewStride>(std::get<Is>(dims))...);
234
- }
235
// Tuple of dims: entry point that supplies the index sequence over all elements.
- template <index_t NewStride, class ... Dims>
236
- auto with_stride (const std::tuple<Dims...>& dims) {
237
- return with_stride<NewStride>(dims, make_index_sequence<sizeof ...(Dims)>());
238
- }
239
-
240
228
// If multiple operands provide the same dim, we need to reconcile them
241
229
// to one dim. If you follow a compiler or runtime error here, your
242
230
// Einstein expression tries to address two dimensions that have different
@@ -246,7 +234,7 @@ auto with_stride(const std::tuple<Dims...>& dims) {
246
234
template <class Dim0 , class ... Dims,
247
235
class = std::enable_if_t <!any(not_equal(Dim0::Min, Dims::Min)...)>,
248
236
class = std::enable_if_t <!any(not_equal(Dim0::Extent, Dims::Extent)...)>>
249
- const Dim0& reconcile_dim (const Dim0& dim0, const Dims&... dims) {
237
+ NDARRAY_UNIQUE const Dim0& reconcile_dim (const Dim0& dim0, const Dims&... dims) {
250
238
if (dim0.stride () != 0 ) {
251
239
// If the first dim is an output dimension, just require the other dims
252
240
// to be in-bounds. This is a slightly relaxed requirement compared to the
@@ -266,59 +254,74 @@ const Dim0& reconcile_dim(const Dim0& dim0, const Dims&... dims) {
266
254
// Overload taking no dims: returns a value-initialized dim<0, 1, 0>. Going by
// the parameter names used for dim elsewhere in this file, that is Min 0,
// Extent 1, Stride 0.
inline dim<0 , 1 , 0 > reconcile_dim () { return {}; }
267
255
268
256
// Tuple overload: unpacks the elements of `dims` selected by `Is...` and
// forwards them to the reconcile_dim overloads that take dims as separate
// arguments. The only change here is the added NDARRAY_UNIQUE qualifier (a
// macro defined outside this view).
template <class ... Dims, size_t ... Is>
269
- auto reconcile_dim (const std::tuple<Dims...>& dims, index_sequence<Is...>) {
257
+ NDARRAY_UNIQUE auto reconcile_dim (const std::tuple<Dims...>& dims, index_sequence<Is...>) {
270
258
return reconcile_dim (std::get<Is>(dims)...);
271
259
}
260
// New in this change: reconciles every element of `dims` by generating the
// index sequence 0 .. sizeof...(Dims) - 1 and delegating to the
// (tuple, index_sequence) overload.
+ template <class ... Dims>
261
+ NDARRAY_UNIQUE auto reconcile_dim (const std::tuple<Dims...>& dims) {
262
+ return reconcile_dim (dims, make_index_sequence<sizeof ...(Dims)>());
263
+ }
272
264
273
265
// Get the shape of an ein_reduce operand, or an empty shape if not an array.
// This overload handles array_ref operands and returns the dims of its shape.
// NOTE(review): the result is returned by reference (const auto&). That is
// only safe if op.shape() and .dims() both return references into `op` rather
// than temporaries -- confirm against array_ref/shape.
274
266
template <class T , class Shape >
275
- const auto & dims_of (const array_ref<T, Shape>& op) {
267
+ NDARRAY_UNIQUE const auto & dims_of (const array_ref<T, Shape>& op) {
276
268
return op.shape ().dims ();
277
269
}
278
270
// Fallback for operands that are not array_ref: `op` is ignored and an empty
// tuple is returned, so such operands contribute no dims.
template <class T >
279
- std::tuple<> dims_of (const T& op) {
271
+ NDARRAY_UNIQUE std::tuple<> dims_of (const T& op) {
280
272
return std::tuple<>();
281
273
}
282
274
275
+ // Helper to reinterpret a dim/shape with a new stride.
// Takes a one-element tuple holding a dim and returns a one-element tuple
// holding a dim with the same min and extent, but with NewStride as the
// compile-time stride parameter. Only min() and extent() are passed to the new
// dim's constructor; the runtime stride of the input is not carried over.
276
+ template <index_t NewStride, index_t Min, index_t Extent, index_t Stride>
277
+ NDARRAY_UNIQUE auto with_stride (const std::tuple<dim<Min, Extent, Stride>>& maybe_dim) {
278
+ const dim<Min, Extent, Stride>& d = std::get<0 >(maybe_dim);
279
+ return std::make_tuple (dim<Min, Extent, NewStride>(d.min (), d.extent ()));
280
+ }
281
// Empty-tuple overload: there is no dim to restride, so the (empty) tuple is
// returned unchanged. NewStride is unused here.
+ template <index_t NewStride>
282
+ NDARRAY_UNIQUE std::tuple<> with_stride (std::tuple<> maybe_dim) {
283
+ return maybe_dim;
284
+ }
285
+
283
286
// These types are flags that let us overload behavior based on these 3 options.
// They are empty tag types, passed by value to select a gather_dims overload.
284
287
// Inferred shape: gathered dims are given `dynamic` strides.
class is_inferred_shape {};
285
288
// Result shape: gathered dims keep their own strides.
class is_result_shape {};
286
289
// Operand shape: gathered dims are given stride 0.
class is_operand_shape {};
287
290
288
291
// Get a dim from an operand, depending on the intended use of the shape.
// Result-shape overload for a leaf ein_op. This change renames gather_dim to
// gather_dims and adds NDARRAY_UNIQUE. The result is whatever get_or_empty
// (defined outside this view) yields for the position of `Dim` within `Is...`
// -- presumably a one-element tuple when `Dim` is present and an empty tuple
// otherwise; TODO confirm against get_or_empty/index_of.
289
292
template <size_t Dim, class Dims , size_t ... Is>
290
- auto gather_dim (is_result_shape, const ein_op<Dims, Is...>& op) {
293
+ NDARRAY_UNIQUE auto gather_dims (is_result_shape, const ein_op<Dims, Is...>& op) {
291
294
// If this is part of the result, we want to keep its strides.
292
295
return get_or_empty<index_of<Dim, Is...>()>(dims_of (op.op ));
293
296
}
294
297
// Inferred-shape overload for a leaf ein_op (renamed from gather_dim).
// The order of operations changes in this edit: the old line restrided every
// dim of the operand and then selected one; the new line selects first and
// restrides only the selected tuple. This matches the with_stride overloads
// visible in this file, which accept std::tuple<dim<...>> or std::tuple<>.
template <size_t Dim, class Dims , size_t ... Is>
295
- auto gather_dim (is_inferred_shape, const ein_op<Dims, Is...>& op) {
298
+ NDARRAY_UNIQUE auto gather_dims (is_inferred_shape, const ein_op<Dims, Is...>& op) {
296
299
// For inferred shapes, we want shapes without any constexpr strides, so it can be reshaped.
297
- return get_or_empty<index_of<Dim, Is...>()>(with_stride<dynamic >(dims_of (op.op )));
300
+ return with_stride<dynamic>( get_or_empty<index_of<Dim, Is...>()>(dims_of (op.op )));
298
301
}
299
302
// Operand-shape overload for a leaf ein_op (renamed from gather_dim).
// As with the inferred-shape overload, the new line selects the dim for `Dim`
// first and only then restrides it, instead of restriding all dims up front.
template <size_t Dim, class Dims , size_t ... Is>
300
- auto gather_dim (is_operand_shape, const ein_op<Dims, Is...>& op) {
303
+ NDARRAY_UNIQUE auto gather_dims (is_operand_shape, const ein_op<Dims, Is...>& op) {
301
304
// If this is an operand shape, we want all of its dimensions to be stride 0.
302
- return get_or_empty<index_of<Dim, Is...>()>(with_stride< 0 >(dims_of (op.op )));
305
+ return with_stride< 0 >( get_or_empty<index_of<Dim, Is...>()>(dims_of (op.op )));
303
306
}
304
307
305
308
// Unary-operation overload: a unary op has no dims of its own, so this
// recurses into its single operand with the same `kind` tag. The change is a
// rename (gather_dim -> gather_dims) plus the NDARRAY_UNIQUE qualifier.
template <size_t Dim, class Kind , class Op , class X >
306
- auto gather_dim (Kind kind, const ein_unary_op<Op, X>& op) {
307
- return gather_dim <Dim>(kind, op.op );
309
+ NDARRAY_UNIQUE auto gather_dims (Kind kind, const ein_unary_op<Op, X>& op) {
310
+ return gather_dims <Dim>(kind, op.op );
308
311
}
309
312
// Binary-operation overload: gathers the dims for `Dim` from both operands
// with the same `kind` tag and concatenates the two resulting tuples. The
// change is a rename (gather_dim -> gather_dims) plus NDARRAY_UNIQUE.
template <size_t Dim, class Kind , class OpA , class OpB , class X >
310
- auto gather_dim (Kind kind, const ein_binary_op<OpA, OpB, X>& op) {
311
- return std::tuple_cat (gather_dim <Dim>(kind, op.op_a ), gather_dim <Dim>(kind, op.op_b ));
313
+ NDARRAY_UNIQUE auto gather_dims (Kind kind, const ein_binary_op<OpA, OpB, X>& op) {
314
+ return std::tuple_cat (gather_dims <Dim>(kind, op.op_a ), gather_dims <Dim>(kind, op.op_b ));
312
315
}
313
316
314
// Multi-operand gather for loop dimension `Dim`.
// Old version ('-' lines): took any number of (kind, op) pairs, concatenated
// the dims gathered from each, and reconciled them before returning.
// New version ('+' lines): takes exactly two kind/op pairs as four separate
// arguments and returns the concatenated tuple without reconciling; the
// reconcile_dim call moves to make_ein_reduce_shape.
// NOTE(review): the old variadic form accepted any number of operands; the
// new one is fixed at two. Confirm no caller passes three or more.
- template <size_t Dim, class ... Ops>
315
- auto gather_dims (const Ops&... ops) {
316
- auto dims = std::tuple_cat (gather_dim<Dim>(std::get<0 >(ops), std::get<1 >(ops))...);
317
- return reconcile_dim (dims, make_index_sequence<std::tuple_size<decltype (dims)>::value>());
317
+ template <size_t Dim, class Kind0 , class Op0 , class Kind1 , class Op1 >
318
+ NDARRAY_UNIQUE auto gather_dims (Kind0 kind0, const Op0& op0, Kind1 kind1, const Op1& op1) {
319
+ return std::tuple_cat (gather_dims<Dim>(kind0, op0), gather_dims<Dim>(kind1, op1));
318
320
}
319
// Builds a shape with one dim per index in `Is...`. For each index, the dims
// for that loop dimension are gathered from the operands and reduced to one.
// Old version ('-' lines): operands arrived as pair-like objects and
// gather_dims reconciled internally.
// New version ('+' lines): kinds and operands arrive as a flat argument list
// (kind, op[, kind, op]); gather_dims returns a tuple and reconcile_dim is
// applied here. With one kind/op pair the call resolves to the per-operand
// gather_dims overloads, with two pairs to the four-argument overload.
- template <size_t ... Is, class ... Ops>
320
- auto make_ein_reduce_shape (index_sequence<Is...>, const Ops&... ops) {
321
- return make_shape (gather_dims<Is>(ops...)...);
321
+ template <size_t ... Is, class ... KindAndOps>
322
+ NDARRAY_UNIQUE auto make_ein_reduce_shape (
323
+ index_sequence<Is...>, const KindAndOps&... kind_and_ops) {
324
+ return make_shape (reconcile_dim (gather_dims<Is>(kind_and_ops...))...);
322
325
}
323
326
324
327
} // namespace internal
@@ -331,7 +334,7 @@ auto make_ein_reduce_shape(index_sequence<Is...>, const Ops&... ops) {
331
334
* arguments of `a`. See `ein_reduce()` for more details. */
332
335
// Wraps a callable `op` as an Einstein operand indexed by `Is...`. The
// enable_if_callable constraint restricts this overload to types invocable
// with one argument per index. `op` is taken by value; the new line moves it
// into the ein_op instead of copying it via brace-initialization.
// NOTE(review): std::move needs <utility>, which is not among the includes
// visible here -- presumably pulled in transitively via array.h; confirm.
template <size_t ... Is, class Op , class = internal::enable_if_callable<Op, decltype(Is)...>>
333
336
auto ein (Op op) {
334
- return internal::ein_op<Op, Is...>{op} ;
337
+ return internal::ein_op<Op, Is...>( std::move (op)) ;
335
338
}
336
339
template <size_t ... Is, class T , class Shape , class Alloc ,
337
340
class = std::enable_if_t <sizeof ...(Is) == Shape::rank()>>
@@ -406,8 +409,7 @@ NDARRAY_UNIQUE auto ein_reduce(const Expr& expr) {
406
409
// is present. If not, this selects one of the operand dimensions, which are
407
410
// given stride 0.
408
411
auto reduction_shape = internal::make_ein_reduce_shape (internal::make_index_sequence<LoopRank>(),
409
- std::make_pair (internal::is_result_shape (), expr.op_a ),
410
- std::make_pair (internal::is_operand_shape (), expr.op_b ));
412
+ internal::is_result_shape (), expr.op_a , internal::is_operand_shape (), expr.op_b );
411
413
412
414
// TODO: Try to compile-time optimize reduction_shape? :)
413
415
// This is maybe actually somewhat doable, simply moving the for_each_index
@@ -429,7 +431,7 @@ NDARRAY_UNIQUE auto ein_reduce(const Expr& expr) {
429
431
// Computes a shape for the dimensions `ResultIs...` of the Einstein expression
// `expr`. Dims are gathered with the is_inferred_shape tag and the result is
// passed through make_compact (defined outside this view) before returning.
// The change passes the tag and expression as two separate arguments instead
// of wrapping them in std::make_pair.
template <size_t ... ResultIs, class Expr , class = internal::enable_if_ein_op<Expr>>
430
432
auto make_ein_reduce_shape (const Expr& expr) {
431
433
auto result_shape = internal::make_ein_reduce_shape (
432
- internal::index_sequence<ResultIs...>(), std::make_pair ( internal::is_inferred_shape (), expr) );
434
+ internal::index_sequence<ResultIs...>(), internal::is_inferred_shape (), expr);
433
435
return make_compact (result_shape);
434
436
}
435
437
0 commit comments