1414#include < cmath>
1515#include " pybind11/pybind11.h"
1616#include " pybind11/common.h"
17- # include " xtensor/ xtensor_forward.hpp "
18- #include " xtensor/xiterator .hpp"
17+ // Because of layout, else xiterator and xtensor_forward are sufficient
18+ #include " xtensor/xcontainer .hpp"
1919
2020namespace xt
2121{
@@ -61,6 +61,7 @@ namespace xt
6161 const backstrides_type& backstrides () const ;
6262
6363 void reshape (const shape_type& shape);
64+ void reshape (const shape_type& shape, layout l);
6465 void reshape (const shape_type& shape, const strides_type& strides);
6566
6667 template <class ... Args>
@@ -143,6 +144,7 @@ namespace xt
143144 pycontainer& operator =(pycontainer&&) = default ;
144145
145146 void fill_default_strides (const shape_type& shape,
147+ layout l,
146148 strides_type& strides);
147149
148150 static derived_type ensure (pybind11::handle h);
@@ -232,16 +234,31 @@ namespace xt
232234 }
233235
234236 template <class D >
235- inline void pycontainer<D>::fill_default_strides(const shape_type& shape, strides_type& strides)
237+ inline void pycontainer<D>::fill_default_strides(const shape_type& shape, layout l, strides_type& strides)
236238 {
237239 typename strides_type::value_type data_size = 1 ;
238- for (size_type i = strides. size (); i != 0 ; --i )
240+ if (l == layout::row_major )
239241 {
240- strides[i - 1 ] = data_size;
241- data_size = strides[i - 1 ] * shape[i - 1 ];
242- if (shape[i - 1 ] == 1 )
242+ for (size_type i = strides.size (); i != 0 ; --i)
243243 {
244- strides[i - 1 ] = 0 ;
244+ strides[i - 1 ] = data_size;
245+ data_size = strides[i - 1 ] * shape[i - 1 ];
246+ if (shape[i - 1 ] == 1 )
247+ {
248+ strides[i - 1 ] = 0 ;
249+ }
250+ }
251+ }
252+ else
253+ {
254+ for (size_type i = 0 ; i < strides.size (); ++i)
255+ {
256+ strides[i] = data_size;
257+ data_size = strides[i] * shape[i];
258+ if (shape[i] == 1 )
259+ {
260+ strides[i] = 0 ;
261+ }
245262 }
246263 }
247264 }
@@ -259,7 +276,8 @@ namespace xt
259276 inline bool pycontainer<D>::check_(pybind11::handle h)
260277 {
261278 int type_num = detail::numpy_traits<value_type>::type_num;
262- return PyArray_Check (h.ptr ()) && PyArray_EquivTypenums (PyArray_TYPE (python_array ()), type_num);
279+ return PyArray_Check (h.ptr ()) &&
280+ PyArray_EquivTypenums (PyArray_TYPE (reinterpret_cast <PyArrayObject*>(h.ptr ())), type_num);
263281 }
264282
265283 template <class D >
@@ -338,12 +356,18 @@ namespace xt
338356 {
339357 if (shape.size () != dimension () || !std::equal (shape.begin (), shape.end (), this ->shape ().begin ()))
340358 {
341- strides_type strides (shape.size ());
342- fill_default_strides (shape, strides);
343- reshape (shape, strides);
359+ reshape (shape, layout::row_major);
344360 }
345361 }
346362
363+ template <class D >
364+ inline void pycontainer<D>::reshape(const shape_type& shape, layout l)
365+ {
366+ strides_type strides (shape.size ());
367+ fill_default_strides (shape, l, strides);
368+ reshape (shape, strides);
369+ }
370+
347371 template <class D >
348372 inline void pycontainer<D>::reshape(const shape_type& shape, const strides_type& strides)
349373 {
@@ -460,29 +484,29 @@ namespace xt
460484 template <class D >
461485 inline auto pycontainer<D>::xbegin() -> broadcast_iterator
462486 {
463- const inner_shape_type& shape = shape ();
464- return broadcast_iterator (stepper_begin (shape ), shape );
487+ const inner_shape_type& sh = shape ();
488+ return broadcast_iterator (stepper_begin (sh ), sh );
465489 }
466490
467491 template <class D >
468492 inline auto pycontainer<D>::xend() -> broadcast_iterator
469493 {
470- const inner_shape_type& shape = shape ();
471- return broadcast_iterator (stepper_end (shape ), shape );
494+ const inner_shape_type& sh = shape ();
495+ return broadcast_iterator (stepper_end (sh ), sh );
472496 }
473497
474498 template <class D >
475499 inline auto pycontainer<D>::xbegin() const -> const_broadcast_iterator
476500 {
477- const inner_shape_type& shape = shape ();
478- return const_broadcast_iterator (stepper_begin (shape ), shape );
501+ const inner_shape_type& sh = shape ();
502+ return const_broadcast_iterator (stepper_begin (sh ), sh );
479503 }
480504
481505 template <class D >
482506 inline auto pycontainer<D>::xend() const -> const_broadcast_iterator
483507 {
484- const inner_shape_type& shape = shape ();
485- return const_broadcast_iterator (stepper_end (shape ), shape );
508+ const inner_shape_type& sh = shape ();
509+ return const_broadcast_iterator (stepper_end (sh ), sh );
486510 }
487511
488512 template <class D >
0 commit comments