Skip to content

Commit ff4224c

Browse files
committed
Refactor make_ein_reduce_shape to hopefully generate less code and avoid copies (fixes #42).
1 parent 78264ba commit ff4224c

File tree

3 files changed

+58
-43
lines changed

3 files changed

+58
-43
lines changed

ein_reduce.h

+44-42
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121

2222
#include "array.h"
2323

24+
#include <iostream>
25+
2426
namespace nda {
2527

2628
namespace internal {
@@ -73,7 +75,7 @@ struct ein_op_base {
7375
template <class Op, size_t... Is>
7476
struct ein_op : public ein_op_base<ein_op<Op, Is...>> {
7577
Op op;
76-
ein_op(const Op& op) : op(op) {}
78+
ein_op(Op op) : op(std::move(op)) {}
7779

7880
// The largest dimension used by this operand.
7981
static constexpr index_t MaxIndex = sizeof...(Is) == 0 ? -1 : variadic_max(Is...);
@@ -108,7 +110,7 @@ struct ein_op : public ein_op_base<ein_op<Op, Is...>> {
108110
// A unary operation on an Einstein operand.
109111
template <class Op, class Derived>
110112
struct ein_unary_op : public ein_op_base<Derived> {
111-
Op op;
113+
const Op& op;
112114
ein_unary_op(const Op& op) : op(op) {}
113115
static constexpr index_t MaxIndex = Op::MaxIndex;
114116
};
@@ -143,8 +145,8 @@ struct ein_cast_op : public ein_unary_op<Op, ein_cast_op<Type, Op>> {
143145
// A binary operation of two operands.
144146
template <class OpA, class OpB, class Derived>
145147
struct ein_binary_op : public ein_op_base<Derived> {
146-
OpA op_a;
147-
OpB op_b;
148+
const OpA& op_a;
149+
const OpB& op_b;
148150
ein_binary_op(const OpA& a, const OpB& b) : op_a(a), op_b(b) {}
149151
static constexpr index_t MaxIndex = std::max(OpA::MaxIndex, OpB::MaxIndex);
150152
};
@@ -223,20 +225,6 @@ using nda::cast;
223225
using nda::max;
224226
using nda::min;
225227

226-
// Helper to reinterpret a dim/shape with a new stride.
227-
template <index_t NewStride, index_t Min, index_t Extent, index_t Stride>
228-
auto with_stride(const dim<Min, Extent, Stride>& d) {
229-
return dim<Min, Extent, NewStride>(d.min(), d.extent());
230-
}
231-
template <index_t NewStride, class... Dims, size_t... Is>
232-
auto with_stride(const std::tuple<Dims...>& dims, index_sequence<Is...>) {
233-
return std::make_tuple(with_stride<NewStride>(std::get<Is>(dims))...);
234-
}
235-
template <index_t NewStride, class... Dims>
236-
auto with_stride(const std::tuple<Dims...>& dims) {
237-
return with_stride<NewStride>(dims, make_index_sequence<sizeof...(Dims)>());
238-
}
239-
240228
// If multiple operands provide the same dim, we need to reconcile them
241229
// to one dim. If you follow a compiler or runtime error here, your
242230
// Einstein expression tries to address two dimensions that have different
@@ -246,7 +234,7 @@ auto with_stride(const std::tuple<Dims...>& dims) {
246234
template <class Dim0, class... Dims,
247235
class = std::enable_if_t<!any(not_equal(Dim0::Min, Dims::Min)...)>,
248236
class = std::enable_if_t<!any(not_equal(Dim0::Extent, Dims::Extent)...)>>
249-
const Dim0& reconcile_dim(const Dim0& dim0, const Dims&... dims) {
237+
NDARRAY_UNIQUE const Dim0& reconcile_dim(const Dim0& dim0, const Dims&... dims) {
250238
if (dim0.stride() != 0) {
251239
// If the first dim is an output dimension, just require the other dims
252240
// to be in-bounds. This is a slightly relaxed requirement compared to the
@@ -266,59 +254,74 @@ const Dim0& reconcile_dim(const Dim0& dim0, const Dims&... dims) {
266254
inline dim<0, 1, 0> reconcile_dim() { return {}; }
267255

268256
template <class... Dims, size_t... Is>
269-
auto reconcile_dim(const std::tuple<Dims...>& dims, index_sequence<Is...>) {
257+
NDARRAY_UNIQUE auto reconcile_dim(const std::tuple<Dims...>& dims, index_sequence<Is...>) {
270258
return reconcile_dim(std::get<Is>(dims)...);
271259
}
260+
template <class... Dims>
261+
NDARRAY_UNIQUE auto reconcile_dim(const std::tuple<Dims...>& dims) {
262+
return reconcile_dim(dims, make_index_sequence<sizeof...(Dims)>());
263+
}
272264

273265
// Get the shape of an ein_reduce operand, or an empty shape if not an array.
274266
template <class T, class Shape>
275-
const auto& dims_of(const array_ref<T, Shape>& op) {
267+
NDARRAY_UNIQUE const auto& dims_of(const array_ref<T, Shape>& op) {
276268
return op.shape().dims();
277269
}
278270
template <class T>
279-
std::tuple<> dims_of(const T& op) {
271+
NDARRAY_UNIQUE std::tuple<> dims_of(const T& op) {
280272
return std::tuple<>();
281273
}
282274

275+
// Helper to reinterpret a dim/shape with a new stride.
276+
template <index_t NewStride, index_t Min, index_t Extent, index_t Stride>
277+
NDARRAY_UNIQUE auto with_stride(const std::tuple<dim<Min, Extent, Stride>>& maybe_dim) {
278+
const dim<Min, Extent, Stride>& d = std::get<0>(maybe_dim);
279+
return std::make_tuple(dim<Min, Extent, NewStride>(d.min(), d.extent()));
280+
}
281+
template <index_t NewStride>
282+
NDARRAY_UNIQUE std::tuple<> with_stride(std::tuple<> maybe_dim) {
283+
return maybe_dim;
284+
}
285+
283286
// These types are flags that let us overload behavior based on these 3 options.
284287
class is_inferred_shape {};
285288
class is_result_shape {};
286289
class is_operand_shape {};
287290

288291
// Get a dim from an operand, depending on the intended use of the shape.
289292
template <size_t Dim, class Dims, size_t... Is>
290-
auto gather_dim(is_result_shape, const ein_op<Dims, Is...>& op) {
293+
NDARRAY_UNIQUE auto gather_dims(is_result_shape, const ein_op<Dims, Is...>& op) {
291294
// If this is part of the result, we want to keep its strides.
292295
return get_or_empty<index_of<Dim, Is...>()>(dims_of(op.op));
293296
}
294297
template <size_t Dim, class Dims, size_t... Is>
295-
auto gather_dim(is_inferred_shape, const ein_op<Dims, Is...>& op) {
298+
NDARRAY_UNIQUE auto gather_dims(is_inferred_shape, const ein_op<Dims, Is...>& op) {
296299
// For inferred shapes, we want shapes without any constexpr strides, so it can be reshaped.
297-
return get_or_empty<index_of<Dim, Is...>()>(with_stride<dynamic>(dims_of(op.op)));
300+
return with_stride<dynamic>(get_or_empty<index_of<Dim, Is...>()>(dims_of(op.op)));
298301
}
299302
template <size_t Dim, class Dims, size_t... Is>
300-
auto gather_dim(is_operand_shape, const ein_op<Dims, Is...>& op) {
303+
NDARRAY_UNIQUE auto gather_dims(is_operand_shape, const ein_op<Dims, Is...>& op) {
301304
// If this is an operand shape, we want all of its dimensions to be stride 0.
302-
return get_or_empty<index_of<Dim, Is...>()>(with_stride<0>(dims_of(op.op)));
305+
return with_stride<0>(get_or_empty<index_of<Dim, Is...>()>(dims_of(op.op)));
303306
}
304307

305308
template <size_t Dim, class Kind, class Op, class X>
306-
auto gather_dim(Kind kind, const ein_unary_op<Op, X>& op) {
307-
return gather_dim<Dim>(kind, op.op);
309+
NDARRAY_UNIQUE auto gather_dims(Kind kind, const ein_unary_op<Op, X>& op) {
310+
return gather_dims<Dim>(kind, op.op);
308311
}
309312
template <size_t Dim, class Kind, class OpA, class OpB, class X>
310-
auto gather_dim(Kind kind, const ein_binary_op<OpA, OpB, X>& op) {
311-
return std::tuple_cat(gather_dim<Dim>(kind, op.op_a), gather_dim<Dim>(kind, op.op_b));
313+
NDARRAY_UNIQUE auto gather_dims(Kind kind, const ein_binary_op<OpA, OpB, X>& op) {
314+
return std::tuple_cat(gather_dims<Dim>(kind, op.op_a), gather_dims<Dim>(kind, op.op_b));
312315
}
313316

314-
template <size_t Dim, class... Ops>
315-
auto gather_dims(const Ops&... ops) {
316-
auto dims = std::tuple_cat(gather_dim<Dim>(std::get<0>(ops), std::get<1>(ops))...);
317-
return reconcile_dim(dims, make_index_sequence<std::tuple_size<decltype(dims)>::value>());
317+
template <size_t Dim, class Kind0, class Op0, class Kind1, class Op1>
318+
NDARRAY_UNIQUE auto gather_dims(Kind0 kind0, const Op0& op0, Kind1 kind1, const Op1& op1) {
319+
return std::tuple_cat(gather_dims<Dim>(kind0, op0), gather_dims<Dim>(kind1, op1));
318320
}
319-
template <size_t... Is, class... Ops>
320-
auto make_ein_reduce_shape(index_sequence<Is...>, const Ops&... ops) {
321-
return make_shape(gather_dims<Is>(ops...)...);
321+
template <size_t... Is, class... KindAndOps>
322+
NDARRAY_UNIQUE auto make_ein_reduce_shape(
323+
index_sequence<Is...>, const KindAndOps&... kind_and_ops) {
324+
return make_shape(reconcile_dim(gather_dims<Is>(kind_and_ops...))...);
322325
}
323326

324327
} // namespace internal
@@ -331,7 +334,7 @@ auto make_ein_reduce_shape(index_sequence<Is...>, const Ops&... ops) {
331334
* arguments of `a`. See `ein_reduce()` for more details. */
332335
template <size_t... Is, class Op, class = internal::enable_if_callable<Op, decltype(Is)...>>
333336
auto ein(Op op) {
334-
return internal::ein_op<Op, Is...>{op};
337+
return internal::ein_op<Op, Is...>(std::move(op));
335338
}
336339
template <size_t... Is, class T, class Shape, class Alloc,
337340
class = std::enable_if_t<sizeof...(Is) == Shape::rank()>>
@@ -406,8 +409,7 @@ NDARRAY_UNIQUE auto ein_reduce(const Expr& expr) {
406409
// is present. If not, this selects one of the operand dimensions, which are
407410
// given stride 0.
408411
auto reduction_shape = internal::make_ein_reduce_shape(internal::make_index_sequence<LoopRank>(),
409-
std::make_pair(internal::is_result_shape(), expr.op_a),
410-
std::make_pair(internal::is_operand_shape(), expr.op_b));
412+
internal::is_result_shape(), expr.op_a, internal::is_operand_shape(), expr.op_b);
411413

412414
// TODO: Try to compile-time optimize reduction_shape? :)
413415
// This is maybe actually somewhat doable, simply moving the for_each_index
@@ -429,7 +431,7 @@ NDARRAY_UNIQUE auto ein_reduce(const Expr& expr) {
429431
template <size_t... ResultIs, class Expr, class = internal::enable_if_ein_op<Expr>>
430432
auto make_ein_reduce_shape(const Expr& expr) {
431433
auto result_shape = internal::make_ein_reduce_shape(
432-
internal::index_sequence<ResultIs...>(), std::make_pair(internal::is_inferred_shape(), expr));
434+
internal::index_sequence<ResultIs...>(), internal::is_inferred_shape(), expr);
433435
return make_compact(result_shape);
434436
}
435437

examples/linear_algebra/Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
CFLAGS := $(CFLAGS) -O2 -ffast-math -fstrict-aliasing -march=native
1+
CFLAGS := $(CFLAGS) -O2 -march=native -ffast-math -fstrict-aliasing -fno-exceptions -DNDEBUG
22
CXXFLAGS := $(CXXFLAGS) -std=c++14 -Wall
33
LDFLAGS := $(LDFLAGS)
44

test/ein_reduce.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -334,4 +334,17 @@ TEST(ein_reduce_dft) {
334334
}
335335
}
336336

337+
TEST(ein_reduce_no_copy) {
338+
constexpr index_t N = 30;
339+
340+
move_only token;
341+
auto non_copyable_f = [token = std::move(token)](int i) { return i; };
342+
vector<int, N> sum;
343+
ein_reduce(ein<i>(sum) = cast<int>(-ein<i>(std::move(non_copyable_f))));
344+
345+
for (index_t i : sum.x()) {
346+
ASSERT_EQ(sum(i), -i);
347+
}
348+
}
349+
337350
} // namespace nda

0 commit comments

Comments (0)