@@ -37,6 +37,25 @@ namespace xt
3737 using temporary_type = pyarray<T, ExtraFlags>;
3838 };
3939
40+ template <class A >
41+ class pyarray_backstrides
42+ {
43+
44+ public:
45+
46+ using array_type = A;
47+ using value_type = typename array_type::size_type;
48+ using size_type = typename array_type::size_type;
49+
50+ pyarray_backstrides (const A& a);
51+
52+ value_type operator [](size_type i) const ;
53+
54+ private:
55+
56+ const pybind_array* p_a;
57+ };
58+
4059 /* *
4160 * @class pyarray
4261 * @brief Wrapper on the Python buffer protocol.
@@ -60,8 +79,8 @@ namespace xt
6079 using size_type = std::size_t ;
6180 using difference_type = std::ptrdiff_t ;
6281
63- using stepper = xstepper<T >;
64- using const_stepper = xstepper<const T >;
82+ using stepper = xstepper<self_type >;
83+ using const_stepper = xstepper<const self_type >;
6584
6685 using iterator = xiterator<stepper>;
6786 using const_iterator = xiterator<const_stepper>;
@@ -71,6 +90,7 @@ namespace xt
7190
7291 using shape_type = xshape<size_type>;
7392 using strides_type = xstrides<size_type>;
93+ using backstrides_type = pyarray_backstrides<self_type>;
7494
7595 using closure_type = const self_type&;
7696
@@ -93,19 +113,26 @@ namespace xt
93113 const T* ptr = nullptr ,
94114 handle base = handle());
95115
96- auto dimension () const -> size_type;
116+ size_type dimension () const ;
117+ shape_type shape () const ;
118+ strides_type strides () const ;
119+ backstrides_type backstrides () const ;
120+
121+ void reshape (const shape_type& shape);
122+ void reshape (const shape_type& shape, layout l);
123+ void reshape (const shape_type& shape, const strides_type& strides);
97124
98125 template <typename ... Args>
99- auto operator ()(Args... args) -> reference ;
126+ reference operator ()(Args... args);
100127
101128 template <typename ... Args>
102- auto operator ()(Args... args) const -> const_reference ;
129+ const_reference operator ()(Args... args) const ;
103130
104131 template <typename ... Args>
105- auto data (Args... args) -> pointer ;
132+ pointer data (Args... args);
106133
107134 template <typename ... Args>
108- auto data (Args... args) const -> const_pointer ;
135+ const_pointer data (Args... args) const ;
109136
110137 bool broadcast_shape (shape_type& shape) const ;
111138 bool is_trivial_broadcast (const strides_type& strides) const ;
@@ -138,8 +165,6 @@ namespace xt
138165 const_storage_iterator storage_begin () const ;
139166 const_storage_iterator storage_end () const ;
140167
141- shape_type shape () const ;
142-
143168 template <class E >
144169 pyarray (const xexpression<E>& e);
145170
@@ -158,6 +183,24 @@ namespace xt
158183 static PyObject *ensure_ (PyObject* ptr);
159184
160185 };
186+
187+ /* *************************************
188+ * pyarray_backstrides implementation *
189+ **************************************/
190+
191+ template <class A >
192+ inline pyarray_backstrides<A>::pyarray_backstrides(const A& a)
193+ : p_a(&a)
194+ {
195+ }
196+
197+ template <class A >
198+ inline auto pyarray_backstrides<A>::operator [](size_type i) const -> value_type
199+ {
200+ value_type sh = p_a->shape ()[i];
201+ value_type res = sh == 1 ? 0 : sh * p_a->strides ()[i] / sizeof (typename A::value_type);
202+ return res;
203+ }
161204
162205 /* *************************
163206 * pyarray implementation *
@@ -206,6 +249,79 @@ namespace xt
206249 return pybind_array::ndim ();
207250 }
208251
252+ template <class T , int ExtraFlags>
253+ inline auto pyarray<T, ExtraFlags>::shape() const -> shape_type
254+ {
255+ // Until we have the CRTP on shape types, we copy the shape.
256+ shape_type shape (dimension ());
257+ std::copy (pybind_array::shape (), pybind_array::shape () + dimension (), shape.begin ());
258+ return shape;
259+ }
260+
261+ template <class T , int ExtraFlags>
262+ inline auto pyarray<T, ExtraFlags>::strides() const -> strides_type
263+ {
264+ strides_type strides (dimension ());
265+ std::transform (pybind_array::strides (), pybind_array::strides () + dimension (), strides.begin (),
266+ [](size_type str) { return str / sizeof (value_type); });
267+ return strides;
268+ }
269+
270+ template <class T , int ExtraFlags>
271+ inline auto pyarray<T, ExtraFlags>::backstrides() const -> backstrides_type
272+ {
273+ backstrides_type tmp (*this );
274+ return tmp;
275+ }
276+
277+ template <class T , int ExtraFlags>
278+ void pyarray<T, ExtraFlags>::reshape(const shape_type& shape)
279+ {
280+ if (!m_ptr || shape.size () != dimension () || !std::equal (shape.begin (), shape.end (), pybind_array::shape ()))
281+ {
282+ reshape (shape, layout::row_major);
283+ }
284+ }
285+
286+ template <class T , int ExtraFlags>
287+ void pyarray<T, ExtraFlags>::reshape(const shape_type& shape, layout l)
288+ {
289+ strides_type strides (shape.size ());
290+ size_type data_size = sizeof (value_type);
291+ if (l == layout::row_major)
292+ {
293+ for (size_type i = strides.size (); i != 0 ; --i)
294+ {
295+ strides[i - 1 ] = data_size;
296+ data_size = strides[i - 1 ] * shape[i - 1 ];
297+ if (shape[i - 1 ] == 1 )
298+ {
299+ strides[i - 1 ] = 0 ;
300+ }
301+ }
302+ }
303+ else
304+ {
305+ for (size_type i = 0 ; i < strides.size (); ++i)
306+ {
307+ strides[i] = data_size;
308+ data_size = strides[i] * shape[i];
309+ if (shape[i] == 1 )
310+ {
311+ strides[i] = 0 ;
312+ }
313+ }
314+ }
315+ reshape (shape, strides);
316+ }
317+
318+ template <class T , int ExtraFlags>
319+ void pyarray<T, ExtraFlags>::reshape(const shape_type& shape, const strides_type& strides)
320+ {
321+ self_type tmp (shape, strides);
322+ *this = std::move (tmp);
323+ }
324+
209325 template <class T , int ExtraFlags>
210326 template <typename ... Args>
211327 inline auto pyarray<T, ExtraFlags>::operator ()(Args... args) -> reference
@@ -243,7 +359,20 @@ namespace xt
243359 {
244360 return static_cast <const T*>(pybind_array::data (args...));
245361 }
246-
362+
363+ template <class T , int ExtraFlags>
364+ bool pyarray<T, ExtraFlags>::broadcast_shape(shape_type& shape) const
365+ {
366+ return xt::broadcast_shape (this ->shape (), shape);
367+ }
368+
369+ template <class T , int ExtraFlags>
370+ bool pyarray<T, ExtraFlags>::is_trivial_broadcast(const strides_type& strides) const
371+ {
372+ return strides.size () == dimension () &&
373+ std::equal (strides.begin (), strides.end (), this ->strides ().begin ());
374+ }
375+
247376 template <class T , int ExtraFlags>
248377 inline auto pyarray<T, ExtraFlags>::begin() -> iterator
249378 {
@@ -347,7 +476,7 @@ namespace xt
347476 template <class T , int ExtraFlags>
348477 inline auto pyarray<T, ExtraFlags>::storage_begin() -> storage_iterator
349478 {
350- return static_cast <storage_iterator>(PyArray_GET_ (m_ptr, data));
479+ return reinterpret_cast <storage_iterator>(PyArray_GET_ (m_ptr, data));
351480 }
352481
353482 template <class T , int ExtraFlags>
@@ -359,7 +488,7 @@ namespace xt
359488 template <class T , int ExtraFlags>
360489 inline auto pyarray<T, ExtraFlags>::storage_begin() const -> const_storage_iterator
361490 {
362- return static_cast <const_storage_iterator>(PyArray_GET_ (m_ptr, data));
491+ return reinterpret_cast <const_storage_iterator>(PyArray_GET_ (m_ptr, data));
363492 }
364493
365494 template <class T , int ExtraFlags>
@@ -368,21 +497,12 @@ namespace xt
368497 return storage_begin () + pybind_array::size ();
369498 }
370499
371- template <class T , int ExtraFlags>
372- inline auto pyarray<T, ExtraFlags>::shape() const -> shape_type
373- {
374- // Until we have the CRTP on shape types, we copy the shape.
375- shape_type shape (dimension ());
376- std::copy (pybind_array::shape (), pybind_array::shape () + dimension (), shape.begin ());
377- return shape;
378- }
379-
380500 template <class T , int ExtraFlags>
381501 template <class E >
382502 inline pyarray<T, ExtraFlags>::pyarray(const xexpression<E>& e)
383503 : pybind_array()
384504 {
385- semantic_base::operator = (e);
505+ semantic_base::assign (e);
386506 }
387507
388508 template <class T , int ExtraFlags>
0 commit comments