Skip to content

Commit aa02e58

Browse files
committed
reshape and bug fixes
1 parent a13a53d commit aa02e58

File tree

2 files changed

+148
-23
lines changed

2 files changed

+148
-23
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 142 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,25 @@ namespace xt
3737
using temporary_type = pyarray<T, ExtraFlags>;
3838
};
3939

40+
template <class A>
41+
class pyarray_backstrides
42+
{
43+
44+
public:
45+
46+
using array_type = A;
47+
using value_type = typename array_type::size_type;
48+
using size_type = typename array_type::size_type;
49+
50+
pyarray_backstrides(const A& a);
51+
52+
value_type operator[](size_type i) const;
53+
54+
private:
55+
56+
const pybind_array* p_a;
57+
};
58+
4059
/**
4160
* @class pyarray
4261
* @brief Wrapper on the Python buffer protocol.
@@ -60,8 +79,8 @@ namespace xt
6079
using size_type = std::size_t;
6180
using difference_type = std::ptrdiff_t;
6281

63-
using stepper = xstepper<T>;
64-
using const_stepper = xstepper<const T>;
82+
using stepper = xstepper<self_type>;
83+
using const_stepper = xstepper<const self_type>;
6584

6685
using iterator = xiterator<stepper>;
6786
using const_iterator = xiterator<const_stepper>;
@@ -71,6 +90,7 @@ namespace xt
7190

7291
using shape_type = xshape<size_type>;
7392
using strides_type = xstrides<size_type>;
93+
using backstrides_type = pyarray_backstrides<self_type>;
7494

7595
using closure_type = const self_type&;
7696

@@ -93,19 +113,26 @@ namespace xt
93113
const T* ptr = nullptr,
94114
handle base = handle());
95115

96-
auto dimension() const -> size_type;
116+
size_type dimension() const;
117+
shape_type shape() const;
118+
strides_type strides() const;
119+
backstrides_type backstrides() const;
120+
121+
void reshape(const shape_type& shape);
122+
void reshape(const shape_type& shape, layout l);
123+
void reshape(const shape_type& shape, const strides_type& strides);
97124

98125
template<typename... Args>
99-
auto operator()(Args... args) -> reference;
126+
reference operator()(Args... args);
100127

101128
template<typename... Args>
102-
auto operator()(Args... args) const -> const_reference;
129+
const_reference operator()(Args... args) const;
103130

104131
template<typename... Args>
105-
auto data(Args... args) -> pointer;
132+
pointer data(Args... args);
106133

107134
template<typename... Args>
108-
auto data(Args... args) const -> const_pointer;
135+
const_pointer data(Args... args) const;
109136

110137
bool broadcast_shape(shape_type& shape) const;
111138
bool is_trivial_broadcast(const strides_type& strides) const;
@@ -138,8 +165,6 @@ namespace xt
138165
const_storage_iterator storage_begin() const;
139166
const_storage_iterator storage_end() const;
140167

141-
shape_type shape() const;
142-
143168
template <class E>
144169
pyarray(const xexpression<E>& e);
145170

@@ -158,6 +183,24 @@ namespace xt
158183
static PyObject *ensure_(PyObject* ptr);
159184

160185
};
186+
187+
/**************************************
188+
* pyarray_backstrides implementation *
189+
**************************************/
190+
191+
template <class A>
192+
inline pyarray_backstrides<A>::pyarray_backstrides(const A& a)
193+
: p_a(&a)
194+
{
195+
}
196+
197+
template <class A>
198+
inline auto pyarray_backstrides<A>::operator[](size_type i) const -> value_type
199+
{
200+
value_type sh = p_a->shape()[i];
201+
value_type res = sh == 1 ? 0 : sh * p_a->strides()[i] / sizeof(typename A::value_type);
202+
return res;
203+
}
161204

162205
/**************************
163206
* pyarray implementation *
@@ -206,6 +249,79 @@ namespace xt
206249
return pybind_array::ndim();
207250
}
208251

252+
template <class T, int ExtraFlags>
253+
inline auto pyarray<T, ExtraFlags>::shape() const -> shape_type
254+
{
255+
// Until we have the CRTP on shape types, we copy the shape.
256+
shape_type shape(dimension());
257+
std::copy(pybind_array::shape(), pybind_array::shape() + dimension(), shape.begin());
258+
return shape;
259+
}
260+
261+
template <class T, int ExtraFlags>
262+
inline auto pyarray<T, ExtraFlags>::strides() const -> strides_type
263+
{
264+
strides_type strides(dimension());
265+
std::transform(pybind_array::strides(), pybind_array::strides() + dimension(), strides.begin(),
266+
[](size_type str) { return str / sizeof(value_type); });
267+
return strides;
268+
}
269+
270+
template <class T, int ExtraFlags>
271+
inline auto pyarray<T, ExtraFlags>::backstrides() const -> backstrides_type
272+
{
273+
backstrides_type tmp(*this);
274+
return tmp;
275+
}
276+
277+
template <class T, int ExtraFlags>
278+
void pyarray<T, ExtraFlags>::reshape(const shape_type& shape)
279+
{
280+
if (!m_ptr || shape.size() != dimension() || !std::equal(shape.begin(), shape.end(), pybind_array::shape()))
281+
{
282+
reshape(shape, layout::row_major);
283+
}
284+
}
285+
286+
template <class T, int ExtraFlags>
287+
void pyarray<T, ExtraFlags>::reshape(const shape_type& shape, layout l)
288+
{
289+
strides_type strides(shape.size());
290+
size_type data_size = sizeof(value_type);
291+
if (l == layout::row_major)
292+
{
293+
for (size_type i = strides.size(); i != 0; --i)
294+
{
295+
strides[i - 1] = data_size;
296+
data_size = strides[i - 1] * shape[i - 1];
297+
if (shape[i - 1] == 1)
298+
{
299+
strides[i - 1] = 0;
300+
}
301+
}
302+
}
303+
else
304+
{
305+
for (size_type i = 0; i < strides.size(); ++i)
306+
{
307+
strides[i] = data_size;
308+
data_size = strides[i] * shape[i];
309+
if (shape[i] == 1)
310+
{
311+
strides[i] = 0;
312+
}
313+
}
314+
}
315+
reshape(shape, strides);
316+
}
317+
318+
template <class T, int ExtraFlags>
319+
void pyarray<T, ExtraFlags>::reshape(const shape_type& shape, const strides_type& strides)
320+
{
321+
self_type tmp(shape, strides);
322+
*this = std::move(tmp);
323+
}
324+
209325
template <class T, int ExtraFlags>
210326
template<typename... Args>
211327
inline auto pyarray<T, ExtraFlags>::operator()(Args... args) -> reference
@@ -243,7 +359,20 @@ namespace xt
243359
{
244360
return static_cast<const T*>(pybind_array::data(args...));
245361
}
246-
362+
363+
template <class T, int ExtraFlags>
364+
bool pyarray<T, ExtraFlags>::broadcast_shape(shape_type& shape) const
365+
{
366+
return xt::broadcast_shape(this->shape(), shape);
367+
}
368+
369+
template <class T, int ExtraFlags>
370+
bool pyarray<T, ExtraFlags>::is_trivial_broadcast(const strides_type& strides) const
371+
{
372+
return strides.size() == dimension() &&
373+
std::equal(strides.begin(), strides.end(), this->strides().begin());
374+
}
375+
247376
template <class T, int ExtraFlags>
248377
inline auto pyarray<T, ExtraFlags>::begin() -> iterator
249378
{
@@ -347,7 +476,7 @@ namespace xt
347476
template <class T, int ExtraFlags>
348477
inline auto pyarray<T, ExtraFlags>::storage_begin() -> storage_iterator
349478
{
350-
return static_cast<storage_iterator>(PyArray_GET_(m_ptr, data));
479+
return reinterpret_cast<storage_iterator>(PyArray_GET_(m_ptr, data));
351480
}
352481

353482
template <class T, int ExtraFlags>
@@ -359,7 +488,7 @@ namespace xt
359488
template <class T, int ExtraFlags>
360489
inline auto pyarray<T, ExtraFlags>::storage_begin() const -> const_storage_iterator
361490
{
362-
return static_cast<const_storage_iterator>(PyArray_GET_(m_ptr, data));
491+
return reinterpret_cast<const_storage_iterator>(PyArray_GET_(m_ptr, data));
363492
}
364493

365494
template <class T, int ExtraFlags>
@@ -368,21 +497,12 @@ namespace xt
368497
return storage_begin() + pybind_array::size();
369498
}
370499

371-
template <class T, int ExtraFlags>
372-
inline auto pyarray<T, ExtraFlags>::shape() const -> shape_type
373-
{
374-
// Until we have the CRTP on shape types, we copy the shape.
375-
shape_type shape(dimension());
376-
std::copy(pybind_array::shape(), pybind_array::shape() + dimension(), shape.begin());
377-
return shape;
378-
}
379-
380500
template <class T, int ExtraFlags>
381501
template <class E>
382502
inline pyarray<T, ExtraFlags>::pyarray(const xexpression<E>& e)
383503
: pybind_array()
384504
{
385-
semantic_base::operator=(e);
505+
semantic_base::assign(e);
386506
}
387507

388508
template <class T, int ExtraFlags>

test/main.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,21 @@
44

55
namespace py = pybind11;
66

7-
int test0(xt::pyarray<double> &m)
7+
double test0(xt::pyarray<double> &m)
88
{
99
return m(0);
1010
}
1111

12+
xt::pyarray<double> test1(xt::pyarray<double> &m) {
13+
return m + 2;
14+
}
15+
1216
PYBIND11_PLUGIN(xtensor_python_test)
1317
{
1418
py::module m("xtensor_python_test", "Test module for xtensor python bindings");
1519

1620
m.def("test0", test0, "");
21+
m.def("test1", test1, "");
1722

1823
return m.ptr();
1924
}

0 commit comments

Comments
 (0)