Skip to content

Commit f268c6d

Browse files
authored
Merge pull request #2759 from spectre-ns/concatenate_access
[Optimization] Updated `concatenate_access` and `stack_access` to remove allocations
2 parents 1aa7099 + 4045fb2 commit f268c6d

File tree

2 files changed

+73
-27
lines changed

2 files changed

+73
-27
lines changed

include/xtensor/xbuilder.hpp

Lines changed: 72 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -494,22 +494,47 @@ namespace xt
494494
using size_type = std::size_t;
495495
using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
496496

497-
template <class S>
498-
inline value_type access(const tuple_type& t, size_type axis, S index) const
497+
template <class It>
498+
inline value_type access(const tuple_type& t, size_type axis, It first, It last) const
499499
{
500-
auto match = [&index, axis](auto& arr)
500+
// trim off extra indices if provided to match behavior of containers
501+
auto dim_offset = std::distance(first, last) - std::get<0>(t).dimension();
502+
size_t axis_dim = *(first + axis + dim_offset);
503+
auto match = [&](auto& arr)
501504
{
502-
if (index[axis] >= arr.shape()[axis])
505+
if (axis_dim >= arr.shape()[axis])
503506
{
504-
index[axis] -= arr.shape()[axis];
507+
axis_dim -= arr.shape()[axis];
505508
return false;
506509
}
507510
return true;
508511
};
509512

510-
auto get = [&index](auto& arr)
513+
auto get = [&](auto& arr)
511514
{
512-
return arr[index];
515+
size_t offset = 0;
516+
const size_t end = arr.dimension();
517+
for (size_t i = 0; i < end; i++)
518+
{
519+
const auto& shape = arr.shape();
520+
const size_t stride = std::accumulate(
521+
shape.begin() + i + 1,
522+
shape.end(),
523+
1,
524+
std::multiplies<size_t>()
525+
);
526+
if (i == axis)
527+
{
528+
offset += axis_dim * stride;
529+
}
530+
else
531+
{
532+
const auto len = (*(first + i + dim_offset));
533+
offset += len * stride;
534+
}
535+
}
536+
const auto element = arr.begin() + offset;
537+
return *element;
513538
};
514539

515540
size_type i = 0;
@@ -533,48 +558,68 @@ namespace xt
533558
using size_type = std::size_t;
534559
using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
535560

536-
template <class S>
537-
inline value_type access(const tuple_type& t, size_type axis, S index) const
561+
template <class It>
562+
inline value_type access(const tuple_type& t, size_type axis, It first, It) const
538563
{
539-
auto get_item = [&index](auto& arr)
564+
auto get_item = [&](auto& arr)
540565
{
541-
return arr[index];
566+
size_t offset = 0;
567+
const size_t end = arr.dimension();
568+
size_t after_axis = 0;
569+
for (size_t i = 0; i < end; i++)
570+
{
571+
if (i == axis)
572+
{
573+
after_axis = 1;
574+
}
575+
const auto& shape = arr.shape();
576+
const size_t stride = std::accumulate(
577+
shape.begin() + i + 1,
578+
shape.end(),
579+
1,
580+
std::multiplies<size_t>()
581+
);
582+
const auto len = (*(first + i + after_axis));
583+
offset += len * stride;
584+
}
585+
const auto element = arr.begin() + offset;
586+
return *element;
542587
};
543-
size_type i = index[axis];
544-
index.erase(index.begin() + std::ptrdiff_t(axis));
588+
size_type i = *(first + axis);
545589
return apply<value_type>(i, get_item, t);
546590
}
547591
};
548592

549593
template <class... CT>
550-
class vstack_access : private concatenate_access<CT...>,
551-
private stack_access<CT...>
594+
class vstack_access
552595
{
553596
public:
554597

555598
using tuple_type = std::tuple<CT...>;
556599
using size_type = std::size_t;
557600
using value_type = xtl::promote_type_t<typename std::decay_t<CT>::value_type...>;
558601

559-
using concatenate_base = concatenate_access<CT...>;
560-
using stack_base = stack_access<CT...>;
561-
562-
template <class S>
563-
inline value_type access(const tuple_type& t, size_type axis, S index) const
602+
template <class It>
603+
inline value_type access(const tuple_type& t, size_type axis, It first, It last) const
564604
{
565605
if (std::get<0>(t).dimension() == 1)
566606
{
567-
return stack_base::access(t, axis, index);
607+
return stack.access(t, axis, first, last);
568608
}
569609
else
570610
{
571-
return concatenate_base::access(t, axis, index);
611+
return concatonate.access(t, axis, first, last);
572612
}
573613
}
614+
615+
private:
616+
617+
concatenate_access<CT...> concatonate;
618+
stack_access<CT...> stack;
574619
};
575620

576621
template <template <class...> class F, class... CT>
577-
class concatenate_invoker : private F<CT...>
622+
class concatenate_invoker
578623
{
579624
public:
580625

@@ -592,18 +637,19 @@ namespace xt
592637
inline value_type operator()(Args... args) const
593638
{
594639
// TODO: avoid memory allocation
595-
return this->access(m_t, m_axis, xindex({static_cast<size_type>(args)...}));
640+
xindex index({static_cast<size_type>(args)...});
641+
return access_method.access(m_t, m_axis, index.begin(), index.end());
596642
}
597643

598644
template <class It>
599645
inline value_type element(It first, It last) const
600646
{
601-
// TODO: avoid memory allocation
602-
return this->access(m_t, m_axis, xindex(first, last));
647+
return access_method.access(m_t, m_axis, first, last);
603648
}
604649

605650
private:
606651

652+
F<CT...> access_method;
607653
tuple_type m_t;
608654
size_type m_axis;
609655
};

test/test_xbuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -453,7 +453,7 @@ namespace xt
453453
ASSERT_EQ(11, c(1, 1, 1, 2));
454454
ASSERT_EQ(11, c(1, 1, 2, 2));
455455

456-
auto e = arange(1, 4);
456+
xarray<double> e = arange(1, 4);
457457
xarray<double> f = {2, 3, 4};
458458
xarray<double> k = stack(xtuple(e, f));
459459
xarray<double> l = stack(xtuple(e, f), 1);

0 commit comments

Comments
 (0)