-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmultiArray.h
184 lines (161 loc) · 5.64 KB
/
multiArray.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
/*
@copyright Russell Standish 2019
@author Russell Standish
This file is part of Classdesc
Open source licensed under the MIT license. See LICENSE for details.
*/
#ifndef CLASSDESC_MULTIARRAY_H
#define CLASSDESC_MULTIARRAY_H
#include <assert.h>
#include <cstddef>
#include <cstring>
#include <iterator>
namespace classdesc
{
template <class T, int R> class MultiArray;
/// Moves the data pointer of a MultiArray by n elements. Each MultiArray
/// rank declares this function a friend so it may touch the private pointer.
template <class T, int R> void advance(MultiArray<T,R>& arr, std::ptrdiff_t n);

/// Empty tag bases for metaprogramming: MultiArrayBase is inherited by
/// every MultiArray rank, MultiArrayIterator by MultiArray's iterator.
struct MultiArrayBase {};
struct MultiArrayIterator {};
// turn a pointer back into multidimensional array
/// MultiArray<T,Rank> is a non-owning view that reinterprets a flat
/// buffer of T as a Rank-dimensional array.
///
/// Dimensions are stored innermost-first: dim[0] is the fastest-varying
/// extent and dim[Rank-1] the outermost one. operator[] selects along the
/// outermost dimension and yields a MultiArray of rank Rank-1 beginning
/// i*stride() elements into the buffer, where stride() is the product of
/// dim[0..Rank-2].
template <class T, int Rank>
class MultiArray: public MultiArrayBase
{
  T* m_data;
  size_t dim[Rank];
  size_t m_stride=1;
  friend void advance<T,Rank>(MultiArray<T,Rank>&,std::ptrdiff_t);

  /// Copy Rank extents from a_dim into dim and recompute m_stride as the
  /// product of every extent except the outermost one. Building the product
  /// directly (instead of multiplying all extents and dividing the last one
  /// back out) avoids a division by zero when the outermost extent is 0.
  /// Because every constructor goes through here, the dimension-array
  /// constructors get a correct stride too.
  void setDims(const size_t a_dim[])
  {
    m_stride=1;
    for (int i=0; i<Rank; ++i)
      {
        dim[i]=a_dim[i];
        if (i<Rank-1)
          m_stride*=a_dim[i];
      }
  }

  /// Return the rank-(Rank-1) slice at outer index i. In this const member
  /// dim decays to const size_t*, so the slice is always built through its
  /// dimension-array constructor and never through the variadic one.
  MultiArray<T,Rank-1> slice(size_t i) const
  {
    return MultiArray<T,Rank-1>(m_data+i*m_stride, dim);
  }

public:
  typedef MultiArray<T,Rank-1> value_type;
  typedef size_t size_type;
  static const int rank=Rank;

  /// Create a MultiArray given data and dimensions passed as arguments
  /// @param data buffer holding at least the product of all dimensions
  /// @param args exactly Rank extents, innermost first
  template <class... Args>
  MultiArray(T* data, Args... args): m_data(data)
  {
    static_assert(sizeof...(Args)==Rank,
                  "MultiArray: number of dimensions must equal Rank");
    const size_t dims[]={static_cast<size_t>(args)...};
    setDims(dims);
  }

  /// Create a MultiArray given data and vector of dimensions:
  /// @param data must point to at least \prod_i a_dim[i] elements
  /// @param a_dim must be of length Rank, innermost extent first
  MultiArray(T* data, const size_t a_dim[]): m_data(data) {setDims(a_dim);}

  /// Same as above for a mutable dimension array. Without this overload a
  /// size_t* argument is an exact match for the variadic constructor,
  /// which beats the qualification conversion to const size_t*.
  MultiArray(T* data, size_t a_dim[]): m_data(data) {setDims(a_dim);}

  /// number of slices along the outermost dimension
  size_t size() const {return dim[Rank-1];}
  /// number of T elements separating consecutive slices
  size_t stride() const {return m_stride;}
  const MultiArray<T,Rank-1> operator[](size_t i) const {return slice(i);}
  MultiArray<T,Rank-1> operator[](size_t i) {return slice(i);}
  /// return if this refers to the same memory location as x
  bool same(const MultiArray& x) const {return m_data==x.m_data;}

  /// Iterator over the outermost dimension; dereferencing yields the
  /// current rank-(Rank-1) slice.
  /// NOTE(review): iterator_category advertises random access, but only
  /// ++, --, += and (in)equality are implemented here.
  struct iterator: public MultiArrayIterator
  {
    using difference_type=std::ptrdiff_t;
    using value_type=MultiArray<T,Rank-1>;
    using pointer=MultiArray<T,Rank-1>*;
    using reference=MultiArray<T,Rank-1>&;
    using iterator_category=std::random_access_iterator_tag;
    using underlying_type=T;
    static const int rank=Rank;
    MultiArray<T,Rank-1> array;
    size_t stride;
    iterator(const MultiArray& a): array(a.slice(0)), stride(a.stride()) {}

    // Steps are converted to difference_type before calling advance, for
    // two reasons. Negating the unsigned stride would wrap. An exact
    // ptrdiff_t argument also keeps classdesc's (more specialised) advance
    // preferred over std::advance when T pulls namespace std into
    // argument-dependent lookup.
    iterator& operator++() {advance(array,difference_type(stride)); return *this;}
    iterator operator++(int) {iterator tmp=*this; ++*this; return tmp;}
    iterator& operator--() {advance(array,-difference_type(stride)); return *this;}
    iterator operator--(int) {iterator tmp=*this; --*this; return tmp;}
    iterator& operator+=(size_t i)
    {advance(array,difference_type(i*stride)); return *this;}
    MultiArray<T,Rank-1>& operator*() {return array;}
    bool operator==(const iterator& x) const {return array.same(x.array);}
    bool operator!=(const iterator& x) const {return !operator==(x);}
  };

  /// Iterator whose dereference yields a const slice.
  struct const_iterator: public iterator
  {
    const_iterator(const MultiArray& a): iterator(a) {}
    const MultiArray<T,Rank-1>& operator*() {return iterator::operator*();}
  };

  iterator begin() {return iterator(*this);}
  iterator end() {iterator tmp=begin(); tmp+=dim[Rank-1]; return tmp;}
  const_iterator begin() const {return const_iterator(*this);}
  const_iterator end() const
  {const_iterator tmp=begin(); tmp+=dim[Rank-1]; return tmp;}
};
/// Rank-1 specialisation, terminating the primary template's recursion.
/// Indexing yields element references and plain pointers act as iterators.
template <class T>
class MultiArray<T, 1>: public MultiArrayBase
{
  T* m_data;
  size_t m_size;
  friend void advance<T,1>(MultiArray<T,1>&,std::ptrdiff_t);
public:
  using value_type=T;
  using size_type=size_t;
  using iterator=T*;
  using const_iterator=const T*;
  static const int rank=1;

  /// view `count` contiguous elements starting at `first`
  MultiArray(T* first, size_t count): m_data(first), m_size(count) {}
  /// view starting at `first` whose length is dims[0]; this is the form a
  /// higher-rank MultiArray uses when it passes its dimension array down
  MultiArray(T* first, const size_t dims[]): m_data(first), m_size(*dims) {}

  size_t size() const {return m_size;}
  T& operator[](size_t idx) {return *(m_data+idx);}
  const T& operator[](size_t idx) const {return *(m_data+idx);}

  iterator begin() {return m_data;}
  iterator end() {return begin()+m_size;}
  const_iterator begin() const {return m_data;}
  const_iterator end() const {return begin()+m_size;}

  /// true when both views start at the same address
  bool same(const MultiArray& other) const {return other.m_data==m_data;}
};
/// classdesc's is_sequence trait reports every MultiArray, whatever its
/// rank, as a sequence type.
template <class T, int R>
struct is_sequence<classdesc::MultiArray<T,R> >: public true_type
{};
/// friended advance function for use in iterators
template <class T, int R>
void advance(MultiArray<T,R>& a, std::ptrdiff_t n)
{a.m_data+=n;}
// template <class T, int R>
// struct tn<MultiArray<T,R>>
// {
// static string name() {return "classdesc::MultiArray<"+typeName<T>()+","+std::to_string(R)+">";}
// };
//
// template <class T>
// struct tn<T,
// typename enable_if<
// And<is_base_of<MultiArrayIterator,T>, Not<is_same<MultiArrayIterator,T>>>,
// void>::T>
// {
// static string name() {
// return typeName<typename T::value_type>()+"::iterator";
// }
// };
}
#endif