@@ -41,8 +41,10 @@ namespace xt
41
41
using const_reverse_iterator = std::reverse_iterator<const_iterator>;
42
42
using reverse_iterator = const_reverse_iterator;
43
43
44
+ using shape_type = size_t *;
45
+
44
46
pystrides_adaptor () = default ;
45
- pystrides_adaptor (const_pointer data, size_type size);
47
+ pystrides_adaptor (const_pointer data, size_type size, shape_type shape );
46
48
47
49
bool empty () const noexcept ;
48
50
size_type size () const noexcept ;
@@ -66,6 +68,7 @@ namespace xt
66
68
67
69
const_pointer p_data;
68
70
size_type m_size;
71
+ shape_type p_shape;
69
72
};
70
73
71
74
/* *********************************
@@ -84,21 +87,23 @@ namespace xt
84
87
using reference = typename pystrides_adaptor<N>::const_reference;
85
88
using difference_type = typename pystrides_adaptor<N>::difference_type;
86
89
using iterator_category = std::random_access_iterator_tag;
90
+ using shape_pointer = typename pystrides_adaptor<N>::shape_type;
87
91
88
- inline pystrides_iterator (pointer current)
92
+ inline pystrides_iterator (pointer current, shape_pointer shape )
89
93
: p_current(current)
94
+ , p_shape(shape)
90
95
{
91
96
}
92
97
93
98
inline reference operator *() const
94
99
{
95
- return *p_current / N;
100
+ return *p_shape == size_t ( 1 ) ? 0 : * p_current / N;
96
101
}
97
102
98
103
inline pointer operator ->() const
99
104
{
100
105
// Returning the address of a temporary
101
- value_type res = *p_current / N ;
106
+ value_type res = this -> operator *() ;
102
107
return &res;
103
108
}
104
109
@@ -110,49 +115,55 @@ namespace xt
110
115
inline self_type& operator ++()
111
116
{
112
117
++p_current;
118
+ ++p_shape;
113
119
return *this ;
114
120
}
115
121
116
122
inline self_type& operator --()
117
123
{
118
124
--p_current;
125
+ --p_shape;
119
126
return *this ;
120
127
}
121
128
122
129
inline self_type operator ++(int )
123
130
{
124
131
self_type tmp (*this );
125
132
++p_current;
133
+ ++p_shape;
126
134
return tmp;
127
135
}
128
136
129
137
inline self_type operator --(int )
130
138
{
131
139
self_type tmp (*this );
132
140
--p_current;
141
+ --p_shape;
133
142
return tmp;
134
143
}
135
144
136
145
inline self_type& operator +=(difference_type n)
137
146
{
138
147
p_current += n;
148
+ p_shape += n;
139
149
return *this ;
140
150
}
141
151
142
152
inline self_type& operator -=(difference_type n)
143
153
{
144
154
p_current -= n;
155
+ p_shape -= n;
145
156
return *this ;
146
157
}
147
158
148
159
inline self_type operator +(difference_type n) const
149
160
{
150
- return self_type (p_current + n);
161
+ return self_type (p_current + n, p_shape + n );
151
162
}
152
163
153
164
inline self_type operator -(difference_type n) const
154
165
{
155
- return self_type (p_current - n);
166
+ return self_type (p_current - n, p_shape - n );
156
167
}
157
168
158
169
inline difference_type operator -(const self_type& rhs) const
@@ -166,6 +177,7 @@ namespace xt
166
177
private:
167
178
168
179
pointer p_current;
180
+ shape_pointer p_shape;
169
181
};
170
182
171
183
template <std::size_t N>
@@ -215,8 +227,8 @@ namespace xt
215
227
************************************/
216
228
217
229
template <std::size_t N>
218
- inline pystrides_adaptor<N>::pystrides_adaptor(const_pointer data, size_type size)
219
- : p_data(data), m_size(size)
230
+ inline pystrides_adaptor<N>::pystrides_adaptor(const_pointer data, size_type size, shape_type shape )
231
+ : p_data(data), m_size(size), p_shape(shape)
220
232
{
221
233
}
222
234
@@ -235,19 +247,19 @@ namespace xt
235
247
template <std::size_t N>
236
248
inline auto pystrides_adaptor<N>::operator [](size_type i) const -> const_reference
237
249
{
238
- return p_data[i] / N;
250
+ return p_shape[i] == size_t ( 1 ) ? 0 : p_data[i] / N;
239
251
}
240
252
241
253
template <std::size_t N>
242
254
inline auto pystrides_adaptor<N>::front() const -> const_reference
243
255
{
244
- return p_data[ 0 ] / N ;
256
+ return this -> operator []( 0 ) ;
245
257
}
246
258
247
259
template <std::size_t N>
248
260
inline auto pystrides_adaptor<N>::back() const -> const_reference
249
261
{
250
- return p_data[ m_size - 1 ] / N ;
262
+ return this -> operator []( m_size - 1 ) ;
251
263
}
252
264
253
265
template <std::size_t N>
@@ -265,13 +277,13 @@ namespace xt
265
277
template <std::size_t N>
266
278
inline auto pystrides_adaptor<N>::cbegin() const -> const_iterator
267
279
{
268
- return const_iterator (p_data);
280
+ return const_iterator (p_data, p_shape );
269
281
}
270
282
271
283
template <std::size_t N>
272
284
inline auto pystrides_adaptor<N>::cend() const -> const_iterator
273
285
{
274
- return const_iterator (p_data + m_size);
286
+ return const_iterator (p_data + m_size, p_shape + m_size );
275
287
}
276
288
277
289
template <std::size_t N>
0 commit comments