1515#include " pybind11/pybind11.h"
1616#include " pybind11/common.h"
1717#include " pybind11/complex.h"
18- // Because of layout, else xiterator and xtensor_forward are sufficient
18+
19+ // Because of layout, otherwise xiterator and xtensor_forward are sufficient
1920#include " xtensor/xcontainer.hpp"
2021
2122namespace xt
2223{
2324
2425 template <class D >
25- class pycontainer : public pybind11 ::object
26+ class pycontainer : public pybind11 ::object, public xiterable<D>
2627 {
2728
2829 public:
@@ -45,14 +46,16 @@ namespace xt
4546 using inner_shape_type = typename inner_types::inner_shape_type;
4647 using inner_strides_type = typename inner_types::inner_strides_type;
4748
48- using iterator = typename container_type::iterator;
49- using const_iterator = typename container_type::const_iterator;
49+ using iterable_base = xiterable<D>;
5050
51- using stepper = xstepper<D> ;
52- using const_stepper = xstepper< const D> ;
51+ using iterator = typename iterable_base::iterator ;
52+ using const_iterator = typename iterable_base::const_iterator ;
5353
54- using broadcast_iterator = xiterator<stepper, inner_shape_type*>;
55- using const_broadcast_iterator = xiterator<const_stepper, inner_shape_type*>;
54+ using stepper = typename iterable_base::stepper;
55+ using const_stepper = typename iterable_base::const_stepper;
56+
57+ using broadcast_iterator = typename iterable_base::broadcast_iterator;
58+ using const_broadcast_iterator = typename iterable_base::broadcast_iterator;
5659
5760 size_type size () const ;
5861 size_type dimension () const ;
@@ -97,28 +100,6 @@ namespace xt
97100 const_iterator cbegin () const ;
98101 const_iterator cend () const ;
99102
100- broadcast_iterator xbegin ();
101- broadcast_iterator xend ();
102-
103- const_broadcast_iterator xbegin () const ;
104- const_broadcast_iterator xend () const ;
105- const_broadcast_iterator cxbegin () const ;
106- const_broadcast_iterator cxend () const ;
107-
108- template <class S >
109- xiterator<stepper, S> xbegin (const S& shape);
110- template <class S >
111- xiterator<stepper, S> xend (const S& shape);
112-
113- template <class S >
114- xiterator<const_stepper, S> xbegin (const S& shape) const ;
115- template <class S >
116- xiterator<const_stepper, S> xend (const S& shape) const ;
117- template <class S >
118- xiterator<const_stepper, S> cxbegin (const S& shape) const ;
119- template <class S >
120- xiterator<const_stepper, S> cxend (const S& shape) const ;
121-
122103 template <class S >
123104 stepper stepper_begin (const S& shape);
124105 template <class S >
@@ -144,26 +125,22 @@ namespace xt
144125 pycontainer (pycontainer&&) = default ;
145126 pycontainer& operator =(pycontainer&&) = default ;
146127
147- void fill_default_strides (const shape_type& shape,
148- layout l,
149- strides_type& strides);
150-
151128 static derived_type ensure (pybind11::handle h);
152129 static bool check_ (pybind11::handle h);
153130 static PyObject* raw_array_t (PyObject* ptr);
154131
155132 PyArrayObject* python_array ();
133+ };
156134
157- private:
158-
159- template <size_t dim = 0 >
160- size_type data_offset (const inner_strides_type&) const ;
161-
162- template <size_t dim = 0 , class ... Args>
163- size_type data_offset (const inner_strides_type& strides, size_type i, Args... args) const ;
164-
165- template <class It >
166- size_type element_offset (const inner_strides_type& strides, It first, It last) const ;
135+ template <class D >
136+ struct pycontainer_iterable_types
137+ : xcontainer_iterable_types<D>
138+ {
139+ using stepper = xstepper<D>;
140+ using const_stepper = xstepper<const D>;
141+ using inner_shape_type = typename xcontainer_inner_types<D>::shape_type;
142+ using broadcast_iterator = xiterator<stepper, inner_shape_type*>;
143+ using const_broadcast_iterator = xiterator<const_stepper, inner_shape_type*>;
167144 };
168145
169146 namespace detail
@@ -241,36 +218,6 @@ namespace xt
241218 throw pybind11::error_already_set ();
242219 }
243220
244- template <class D >
245- inline void pycontainer<D>::fill_default_strides(const shape_type& shape, layout l, strides_type& strides)
246- {
247- typename strides_type::value_type data_size = 1 ;
248- if (l == layout::row_major)
249- {
250- for (size_type i = strides.size (); i != 0 ; --i)
251- {
252- strides[i - 1 ] = data_size;
253- data_size = strides[i - 1 ] * shape[i - 1 ];
254- if (shape[i - 1 ] == 1 )
255- {
256- strides[i - 1 ] = 0 ;
257- }
258- }
259- }
260- else
261- {
262- for (size_type i = 0 ; i < strides.size (); ++i)
263- {
264- strides[i] = data_size;
265- data_size = strides[i] * shape[i];
266- if (shape[i] == 1 )
267- {
268- strides[i] = 0 ;
269- }
270- }
271- }
272- }
273-
274221 template <class D >
275222 inline auto pycontainer<D>::ensure(pybind11::handle h) -> derived_type
276223 {
@@ -306,29 +253,6 @@ namespace xt
306253 return reinterpret_cast <PyArrayObject*>(this ->m_ptr );
307254 }
308255
309- template <class D >
310- template <size_t dim>
311- inline auto pycontainer<D>::data_offset(const inner_strides_type&) const -> size_type
312- {
313- return 0 ;
314- }
315-
316- template <class D >
317- template <size_t dim, class ... Args>
318- inline auto pycontainer<D>::data_offset(const inner_strides_type& strides, size_type i, Args... args) const -> size_type
319- {
320- return i * strides[dim] + data_offset<dim + 1 >(strides, args...);
321- }
322-
323- template <class D >
324- template <class It >
325- inline auto pycontainer<D>::element_offset(const inner_strides_type& strides, It, It last) const -> size_type
326- {
327- It first = last;
328- first -= strides.size ();
329- return std::inner_product (strides.begin (), strides.end (), first, size_type (0 ));
330- }
331-
332256 template <class D >
333257 inline auto pycontainer<D>::size() const -> size_type
334258 {
@@ -344,7 +268,7 @@ namespace xt
344268 template <class D >
345269 inline auto pycontainer<D>::shape() const -> const inner_shape_type&
346270 {
347- return static_cast <const derived_type*>(this )-> shape_impl ();
271+ return static_cast <const derived_type*>(this )->shape_impl ();
348272 }
349273
350274 template <class D >
@@ -372,7 +296,7 @@ namespace xt
372296 inline void pycontainer<D>::reshape(const shape_type& shape, layout l)
373297 {
374298 strides_type strides (shape.size ());
375- fill_default_strides (shape, l, strides);
299+ compute_strides (shape, l, strides);
376300 reshape (shape, strides);
377301 }
378302
@@ -387,15 +311,15 @@ namespace xt
387311 template <class ... Args>
388312 inline auto pycontainer<D>::operator ()(Args... args) -> reference
389313 {
390- size_type index = data_offset (strides (), static_cast <size_type>(args)...);
314+ size_type index = data_offset<size_type> (strides (), static_cast <size_type>(args)...);
391315 return data ()[index];
392316 }
393317
394318 template <class D >
395319 template <class ... Args>
396320 inline auto pycontainer<D>::operator ()(Args... args) const -> const_reference
397321 {
398- size_type index = data_offset (strides (), static_cast <size_type>(args)...);
322+ size_type index = data_offset<size_type> (strides (), static_cast <size_type>(args)...);
399323 return data ()[index];
400324 }
401325
@@ -415,14 +339,14 @@ namespace xt
415339 template <class It >
416340 inline auto pycontainer<D>::element(It first, It last) -> reference
417341 {
418- return data ()[element_offset (strides (), first, last)];
342+ return data ()[element_offset<size_type> (strides (), first, last)];
419343 }
420344
421345 template <class D >
422346 template <class It >
423347 inline auto pycontainer<D>::element(It first, It last) const -> const_reference
424348 {
425- return data ()[element_offset (strides (), first, last)];
349+ return data ()[element_offset<size_type> (strides (), first, last)];
426350 }
427351
428352 template <class D >
@@ -468,107 +392,25 @@ namespace xt
468392 template <class D >
469393 inline auto pycontainer<D>::begin() const -> const_iterator
470394 {
471- return data (). cbegin ();
395+ return cbegin ();
472396 }
473397
474398 template <class D >
475399 inline auto pycontainer<D>::end() const -> const_iterator
476400 {
477- return data (). cend ();
401+ return cend ();
478402 }
479403
480404 template <class D >
481405 inline auto pycontainer<D>::cbegin() const -> const_iterator
482406 {
483- return begin ();
407+ return data (). cbegin ();
484408 }
485409
486410 template <class D >
487411 inline auto pycontainer<D>::cend() const -> const_iterator
488412 {
489- return end ();
490- }
491-
492- template <class D >
493- inline auto pycontainer<D>::xbegin() -> broadcast_iterator
494- {
495- const inner_shape_type& sh = shape ();
496- return broadcast_iterator (stepper_begin (sh), sh);
497- }
498-
499- template <class D >
500- inline auto pycontainer<D>::xend() -> broadcast_iterator
501- {
502- const inner_shape_type& sh = shape ();
503- return broadcast_iterator (stepper_end (sh), sh);
504- }
505-
506- template <class D >
507- inline auto pycontainer<D>::xbegin() const -> const_broadcast_iterator
508- {
509- const inner_shape_type& sh = shape ();
510- return const_broadcast_iterator (stepper_begin (sh), sh);
511- }
512-
513- template <class D >
514- inline auto pycontainer<D>::xend() const -> const_broadcast_iterator
515- {
516- const inner_shape_type& sh = shape ();
517- return const_broadcast_iterator (stepper_end (sh), sh);
518- }
519-
520- template <class D >
521- inline auto pycontainer<D>::cxbegin() const -> const_broadcast_iterator
522- {
523- return xbegin ();
524- }
525-
526- template <class D >
527- inline auto pycontainer<D>::cxend() const -> const_broadcast_iterator
528- {
529- return xend ();
530- }
531-
532- template <class D >
533- template <class S >
534- inline auto pycontainer<D>::xbegin(const S& shape) -> xiterator<stepper, S>
535- {
536- return xiterator<stepper, S>(stepper_begin (shape), shape);
537- }
538-
539- template <class D >
540- template <class S >
541- inline auto pycontainer<D>::xend(const S& shape) -> xiterator<stepper, S>
542- {
543- return xiterator<stepper, S>(stepper_end (shape), shape);
544- }
545-
546- template <class D >
547- template <class S >
548- inline auto pycontainer<D>::xbegin(const S& shape) const -> xiterator<const_stepper, S>
549- {
550- return xiterator<const_stepper, S>(stepper_begin (shape), shape);
551- }
552-
553- template <class D >
554- template <class S >
555- inline auto pycontainer<D>::xend(const S& shape) const -> xiterator<const_stepper, S>
556- {
557- return xiterator<const_stepper, S>(stepper_end (shape), shape);
558- }
559-
560- template <class D >
561- template <class S >
562- inline auto pycontainer<D>::cxbegin(const S& shape) const -> xiterator<const_stepper, S>
563- {
564- return xbegin (shape);
565- }
566-
567- template <class D >
568- template <class S >
569- inline auto pycontainer<D>::cxend(const S& shape) const -> xiterator<const_stepper, S>
570- {
571- return xend (shape);
413+ return data ().cend ();
572414 }
573415
574416 template <class D >
@@ -602,7 +444,6 @@ namespace xt
602444 size_type offset = shape.size () - dimension ();
603445 return const_stepper (static_cast <const derived_type*>(this ), data ().end (), offset);
604446 }
605-
606447}
607448
608449#endif
0 commit comments