Skip to content

Commit 9d2289e

Browse files
committed
layout
1 parent a331f6d commit 9d2289e

File tree

3 files changed

+50
-26
lines changed

3 files changed

+50
-26
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ namespace xt
122122
pyarray(pybind11::handle h, pybind11::object::stolen_t);
123123
pyarray(const pybind11::object &o);
124124

125-
explicit pyarray(const shape_type& shape);
125+
explicit pyarray(const shape_type& shape, layout l = layout::row_major);
126126
pyarray(const shape_type& shape, const strides_type& strides);
127127

128128
template <class E>
@@ -201,10 +201,10 @@ namespace xt
201201
}
202202

203203
template <class T>
204-
inline pyarray<T>::pyarray(const shape_type& shape)
204+
inline pyarray<T>::pyarray(const shape_type& shape, layout l)
205205
{
206206
strides_type strides;
207-
base_type::fill_default_strides(shape, strides);
207+
base_type::fill_default_strides(shape, l, strides);
208208
init_array(shape, strides);
209209
}
210210

include/xtensor-python/pycontainer.hpp

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
#include <cmath>
1515
#include "pybind11/pybind11.h"
1616
#include "pybind11/common.h"
17-
#include "xtensor/xtensor_forward.hpp"
18-
#include "xtensor/xiterator.hpp"
17+
// Because of layout, else xiterator and xtensor_forward are sufficient
18+
#include "xtensor/xcontainer.hpp"
1919

2020
namespace xt
2121
{
@@ -61,6 +61,7 @@ namespace xt
6161
const backstrides_type& backstrides() const;
6262

6363
void reshape(const shape_type& shape);
64+
void reshape(const shape_type& shape, layout l);
6465
void reshape(const shape_type& shape, const strides_type& strides);
6566

6667
template <class... Args>
@@ -143,6 +144,7 @@ namespace xt
143144
pycontainer& operator=(pycontainer&&) = default;
144145

145146
void fill_default_strides(const shape_type& shape,
147+
layout l,
146148
strides_type& strides);
147149

148150
static derived_type ensure(pybind11::handle h);
@@ -232,16 +234,31 @@ namespace xt
232234
}
233235

234236
template <class D>
235-
inline void pycontainer<D>::fill_default_strides(const shape_type& shape, strides_type& strides)
237+
inline void pycontainer<D>::fill_default_strides(const shape_type& shape, layout l, strides_type& strides)
236238
{
237239
typename strides_type::value_type data_size = 1;
238-
for(size_type i = strides.size(); i != 0; --i)
240+
if(l == layout::row_major)
239241
{
240-
strides[i - 1] = data_size;
241-
data_size = strides[i - 1] * shape[i - 1];
242-
if(shape[i - 1] == 1)
242+
for(size_type i = strides.size(); i != 0; --i)
243243
{
244-
strides[i - 1] = 0;
244+
strides[i - 1] = data_size;
245+
data_size = strides[i - 1] * shape[i - 1];
246+
if(shape[i - 1] == 1)
247+
{
248+
strides[i - 1] = 0;
249+
}
250+
}
251+
}
252+
else
253+
{
254+
for(size_type i = 0; i < strides.size(); ++i)
255+
{
256+
strides[i] = data_size;
257+
data_size = strides[i] * shape[i];
258+
if(shape[i] == 1)
259+
{
260+
strides[i] = 0;
261+
}
245262
}
246263
}
247264
}
@@ -259,7 +276,8 @@ namespace xt
259276
inline bool pycontainer<D>::check_(pybind11::handle h)
260277
{
261278
int type_num = detail::numpy_traits<value_type>::type_num;
262-
return PyArray_Check(h.ptr()) && PyArray_EquivTypenums(PyArray_TYPE(python_array()), type_num);
279+
return PyArray_Check(h.ptr()) &&
280+
PyArray_EquivTypenums(PyArray_TYPE(reinterpret_cast<PyArrayObject*>(h.ptr())), type_num);
263281
}
264282

265283
template <class D>
@@ -338,12 +356,18 @@ namespace xt
338356
{
339357
if(shape.size() != dimension() || !std::equal(shape.begin(), shape.end(), this->shape().begin()))
340358
{
341-
strides_type strides(shape.size());
342-
fill_default_strides(shape, strides);
343-
reshape(shape, strides);
359+
reshape(shape, layout::row_major);
344360
}
345361
}
346362

363+
template <class D>
364+
inline void pycontainer<D>::reshape(const shape_type& shape, layout l)
365+
{
366+
strides_type strides(shape.size());
367+
fill_default_strides(shape, l, strides);
368+
reshape(shape, strides);
369+
}
370+
347371
template <class D>
348372
inline void pycontainer<D>::reshape(const shape_type& shape, const strides_type& strides)
349373
{
@@ -460,29 +484,29 @@ namespace xt
460484
template <class D>
461485
inline auto pycontainer<D>::xbegin() -> broadcast_iterator
462486
{
463-
const inner_shape_type& shape = shape();
464-
return broadcast_iterator(stepper_begin(shape), shape);
487+
const inner_shape_type& sh = shape();
488+
return broadcast_iterator(stepper_begin(sh), sh);
465489
}
466490

467491
template <class D>
468492
inline auto pycontainer<D>::xend() -> broadcast_iterator
469493
{
470-
const inner_shape_type& shape = shape();
471-
return broadcast_iterator(stepper_end(shape), shape);
494+
const inner_shape_type& sh = shape();
495+
return broadcast_iterator(stepper_end(sh), sh);
472496
}
473497

474498
template <class D>
475499
inline auto pycontainer<D>::xbegin() const -> const_broadcast_iterator
476500
{
477-
const inner_shape_type& shape = shape();
478-
return const_broadcast_iterator(stepper_begin(shape), shape);
501+
const inner_shape_type& sh = shape();
502+
return const_broadcast_iterator(stepper_begin(sh), sh);
479503
}
480504

481505
template <class D>
482506
inline auto pycontainer<D>::xend() const -> const_broadcast_iterator
483507
{
484-
const inner_shape_type& shape = shape();
485-
return const_broadcast_iterator(stepper_end(shape), shape);
508+
const inner_shape_type& sh = shape();
509+
return const_broadcast_iterator(stepper_end(sh), sh);
486510
}
487511

488512
template <class D>

include/xtensor-python/pytensor.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ namespace xt
9797
pytensor(pybind11::handle h, pybind11::object::stolen_t);
9898
pytensor(const pybind11::object &o);
9999

100-
explicit pytensor(const shape_type& shape);
100+
explicit pytensor(const shape_type& shape, layout l = layout::row_major);
101101
pytensor(const shape_type& shape, const strides_type& strides);
102102

103103
template <class E>
@@ -159,9 +159,9 @@ namespace xt
159159
}
160160

161161
template <class T, std::size_t N>
162-
inline pytensor<T, N>::pytensor(const shape_type& shape)
162+
inline pytensor<T, N>::pytensor(const shape_type& shape, layout l)
163163
{
164-
base_type::fill_default_strides(shape, m_strides);
164+
base_type::fill_default_strides(shape, l, m_strides);
165165
init_tensor(shape, m_strides);
166166
}
167167

0 commit comments

Comments
 (0)