@@ -494,22 +494,47 @@ namespace xt
494
494
using size_type = std::size_t ;
495
495
using value_type = xtl::promote_type_t <typename std::decay_t <CT>::value_type...>;
496
496
497
- template <class S >
498
- inline value_type access (const tuple_type& t, size_type axis, S index ) const
497
+ template <class It >
498
+ inline value_type access (const tuple_type& t, size_type axis, It first, It last ) const
499
499
{
500
- auto match = [&index, axis](auto & arr)
500
+ // trim off extra indices if provided to match behavior of containers
501
+ auto dim_offset = std::distance (first, last) - std::get<0 >(t).dimension ();
502
+ size_t axis_dim = *(first + axis + dim_offset);
503
+ auto match = [&](auto & arr)
501
504
{
502
- if (index[axis] >= arr.shape ()[axis])
505
+ if (axis_dim >= arr.shape ()[axis])
503
506
{
504
- index[axis] -= arr.shape ()[axis];
507
+ axis_dim -= arr.shape ()[axis];
505
508
return false ;
506
509
}
507
510
return true ;
508
511
};
509
512
510
- auto get = [&index ](auto & arr)
513
+ auto get = [&](auto & arr)
511
514
{
512
- return arr[index];
515
+ size_t offset = 0 ;
516
+ const size_t end = arr.dimension ();
517
+ for (size_t i = 0 ; i < end; i++)
518
+ {
519
+ const auto & shape = arr.shape ();
520
+ const size_t stride = std::accumulate (
521
+ shape.begin () + i + 1 ,
522
+ shape.end (),
523
+ 1 ,
524
+ std::multiplies<size_t >()
525
+ );
526
+ if (i == axis)
527
+ {
528
+ offset += axis_dim * stride;
529
+ }
530
+ else
531
+ {
532
+ const auto len = (*(first + i + dim_offset));
533
+ offset += len * stride;
534
+ }
535
+ }
536
+ const auto element = arr.begin () + offset;
537
+ return *element;
513
538
};
514
539
515
540
size_type i = 0 ;
@@ -533,48 +558,68 @@ namespace xt
533
558
using size_type = std::size_t ;
534
559
using value_type = xtl::promote_type_t <typename std::decay_t <CT>::value_type...>;
535
560
536
- template <class S >
537
- inline value_type access (const tuple_type& t, size_type axis, S index ) const
561
+ template <class It >
562
+ inline value_type access (const tuple_type& t, size_type axis, It first, It ) const
538
563
{
539
- auto get_item = [&index ](auto & arr)
564
+ auto get_item = [&](auto & arr)
540
565
{
541
- return arr[index];
566
+ size_t offset = 0 ;
567
+ const size_t end = arr.dimension ();
568
+ size_t after_axis = 0 ;
569
+ for (size_t i = 0 ; i < end; i++)
570
+ {
571
+ if (i == axis)
572
+ {
573
+ after_axis = 1 ;
574
+ }
575
+ const auto & shape = arr.shape ();
576
+ const size_t stride = std::accumulate (
577
+ shape.begin () + i + 1 ,
578
+ shape.end (),
579
+ 1 ,
580
+ std::multiplies<size_t >()
581
+ );
582
+ const auto len = (*(first + i + after_axis));
583
+ offset += len * stride;
584
+ }
585
+ const auto element = arr.begin () + offset;
586
+ return *element;
542
587
};
543
- size_type i = index[axis];
544
- index.erase (index.begin () + std::ptrdiff_t (axis));
588
+ size_type i = *(first + axis);
545
589
return apply<value_type>(i, get_item, t);
546
590
}
547
591
};
548
592
549
593
template <class ... CT>
550
- class vstack_access : private concatenate_access <CT...>,
551
- private stack_access<CT...>
594
+ class vstack_access
552
595
{
553
596
public:
554
597
555
598
using tuple_type = std::tuple<CT...>;
556
599
using size_type = std::size_t ;
557
600
using value_type = xtl::promote_type_t <typename std::decay_t <CT>::value_type...>;
558
601
559
- using concatenate_base = concatenate_access<CT...>;
560
- using stack_base = stack_access<CT...>;
561
-
562
- template <class S >
563
- inline value_type access (const tuple_type& t, size_type axis, S index) const
602
+ template <class It >
603
+ inline value_type access (const tuple_type& t, size_type axis, It first, It last) const
564
604
{
565
605
if (std::get<0 >(t).dimension () == 1 )
566
606
{
567
- return stack_base:: access (t, axis, index );
607
+ return stack. access (t, axis, first, last );
568
608
}
569
609
else
570
610
{
571
- return concatenate_base:: access (t, axis, index );
611
+ return concatonate. access (t, axis, first, last );
572
612
}
573
613
}
614
+
615
+ private:
616
+
617
+ concatenate_access<CT...> concatonate;
618
+ stack_access<CT...> stack;
574
619
};
575
620
576
621
template <template <class ...> class F , class ... CT>
577
- class concatenate_invoker : private F <CT...>
622
+ class concatenate_invoker
578
623
{
579
624
public:
580
625
@@ -592,18 +637,19 @@ namespace xt
592
637
inline value_type operator ()(Args... args) const
593
638
{
594
639
// TODO: avoid memory allocation
595
- return this ->access (m_t , m_axis, xindex ({static_cast <size_type>(args)...}));
640
+ xindex index ({static_cast <size_type>(args)...});
641
+ return access_method.access (m_t , m_axis, index.begin (), index.end ());
596
642
}
597
643
598
644
template <class It >
599
645
inline value_type element (It first, It last) const
600
646
{
601
- // TODO: avoid memory allocation
602
- return this ->access (m_t , m_axis, xindex (first, last));
647
+ return access_method.access (m_t , m_axis, first, last);
603
648
}
604
649
605
650
private:
606
651
652
+ F<CT...> access_method;
607
653
tuple_type m_t ;
608
654
size_type m_axis;
609
655
};
0 commit comments