1111
1212#include < cstddef>
1313#include < algorithm>
14+ #include < vector>
1415
1516#include " pybind11/numpy.h"
1617#include " pybind11_backport.hpp"
@@ -33,7 +34,7 @@ namespace xt
3334 class pyarray ;
3435
3536 template <class T , int ExtraFlags>
36- struct array_inner_types <pyarray<T, ExtraFlags>>
37+ struct xcontainer_inner_types <pyarray<T, ExtraFlags>>
3738 {
3839 using temporary_type = pyarray<T, ExtraFlags>;
3940 };
@@ -63,14 +64,14 @@ namespace xt
6364 */
6465 template <class T , int ExtraFlags = pybind_array::forcecast>
6566 class pyarray : public pybind_array ,
66- public xarray_semantic <pyarray<T, ExtraFlags>>
67+ public xcontainer_semantic <pyarray<T, ExtraFlags>>
6768 {
6869
6970 public:
7071
7172 using self_type = pyarray<T, ExtraFlags>;
7273 using base_type = pybind_array;
73- using semantic_base = xarray_semantic <self_type>;
74+ using semantic_base = xcontainer_semantic <self_type>;
7475 using value_type = T;
7576 using reference = T&;
7677 using const_reference = const T&;
@@ -89,8 +90,8 @@ namespace xt
8990 using storage_iterator = T*;
9091 using const_storage_iterator = const T*;
9192
92- using shape_type = xshape <size_type>;
93- using strides_type = xstrides <size_type>;
93+ using shape_type = std::vector <size_type>;
94+ using strides_type = std::vector <size_type>;
9495 using backstrides_type = pyarray_backstrides<self_type>;
9596
9697 using closure_type = const self_type&;
@@ -101,12 +102,12 @@ namespace xt
101102
102103 explicit pyarray (const buffer_info& info);
103104
104- pyarray (const xshape<size_type> & shape,
105- const xstrides<size_type> & strides,
105+ pyarray (const shape_type & shape,
106+ const strides_type & strides,
106107 const T* ptr = nullptr ,
107108 handle base = handle());
108109
109- explicit pyarray (const xshape<size_type> & shape,
110+ explicit pyarray (const shape_type & shape,
110111 const T* ptr = nullptr ,
111112 handle base = handle());
112113
@@ -129,14 +130,20 @@ namespace xt
129130 template <typename ... Args>
130131 const_reference operator ()(Args... args) const ;
131132
133+ reference operator [](const xindex& index);
134+ const_reference operator [](const xindex& index) const ;
135+
132136 template <typename ... Args>
133137 pointer data (Args... args);
134138
135139 template <typename ... Args>
136140 const_pointer data (Args... args) const ;
137141
138- bool broadcast_shape (shape_type& shape) const ;
139- bool is_trivial_broadcast (const strides_type& strides) const ;
142+ template <class S >
143+ bool broadcast_shape (S& shape) const ;
144+
145+ template <class S >
146+ bool is_trivial_broadcast (const S& strides) const ;
140147
141148 iterator begin ();
142149 iterator end ();
@@ -175,9 +182,11 @@ namespace xt
175182 private:
176183
177184 template <typename ... Args>
178- auto index_at (Args... args) const -> size_type ;
185+ size_type index_at (Args... args) const ;
179186
180- static constexpr auto itemsize () -> size_type;
187+ size_type data_offset (const xindex& index) const ;
188+
189+ static constexpr size_type itemsize ();
181190
182191 static bool is_non_null (PyObject* ptr);
183192
@@ -223,16 +232,16 @@ namespace xt
223232 }
224233
225234 template <class T , int ExtraFlags>
226- inline pyarray<T, ExtraFlags>::pyarray(const xshape<size_type> & shape,
227- const xstrides<size_type> & strides,
235+ inline pyarray<T, ExtraFlags>::pyarray(const shape_type & shape,
236+ const strides_type & strides,
228237 const T *ptr,
229238 handle base)
230239 : pybind_array(shape, strides, ptr, base)
231240 {
232241 }
233242
234243 template <class T , int ExtraFlags>
235- inline pyarray<T, ExtraFlags>::pyarray(const xshape<size_type> & shape,
244+ inline pyarray<T, ExtraFlags>::pyarray(const shape_type & shape,
236245 const T* ptr,
237246 handle base)
238247 : pybind_array(shape, ptr, base)
@@ -340,6 +349,18 @@ namespace xt
340349 return *(static_cast <const_pointer>(pybind_array::data ()) + pybind_array::byte_offset (args...) / itemsize ());
341350 }
342351
352+ template <class T , int ExtraFlags>
353+ inline auto pyarray<T, ExtraFlags>::operator [](const xindex& index) -> reference
354+ {
355+ return *(static_cast <pointer>(pybind_array::mutable_data ()) + data_offset (index));
356+ }
357+
358+ template <class T , int ExtraFlags>
359+ inline auto pyarray<T, ExtraFlags>::operator [](const xindex& index) const -> const_reference
360+ {
361+ return *(static_cast <const_pointer>(pybind_array::data ()) + data_offset (index));
362+ }
363+
343364 template <class T , int ExtraFlags>
344365 template <typename ... Args>
345366 inline auto pyarray<T, ExtraFlags>::data(Args... args) -> pointer
@@ -355,13 +376,15 @@ namespace xt
355376 }
356377
357378 template <class T , int ExtraFlags>
358- bool pyarray<T, ExtraFlags>::broadcast_shape(shape_type& shape) const
379+ template <class S >
380+ bool pyarray<T, ExtraFlags>::broadcast_shape(S& shape) const
359381 {
360382 return xt::broadcast_shape (this ->shape (), shape);
361383 }
362384
363385 template <class T , int ExtraFlags>
364- bool pyarray<T, ExtraFlags>::is_trivial_broadcast(const strides_type& strides) const
386+ template <class S >
387+ bool pyarray<T, ExtraFlags>::is_trivial_broadcast(const S& strides) const
365388 {
366389 return strides.size () == dimension () &&
367390 std::equal (strides.begin (), strides.end (), this ->strides ().begin ());
@@ -515,6 +538,15 @@ namespace xt
515538 return pybind_array::byte_offset (args...) / itemsize ();
516539 }
517540
541+ template <class T , int ExtraFlags>
542+ inline auto pyarray<T, ExtraFlags>::data_offset(const xindex& index) const -> size_type
543+ {
544+ const strides_type& str = strides ();
545+ auto iter = index.begin ();
546+ iter += index.size () - str.size ();
547+ return std::inner_product (str.begin (), str.end (), iter, size_type (0 )) / itemsize ();
548+ }
549+
518550 template <class T , int ExtraFlags>
519551 constexpr auto pyarray<T, ExtraFlags>::itemsize() -> size_type
520552 {
0 commit comments