@@ -129,14 +129,20 @@ namespace xt
129129 template <typename ... Args>
130130 const_reference operator ()(Args... args) const ;
131131
132+ reference operator [](const xindex& index);
133+ const_reference operator [](const xindex& index) const ;
134+
132135 template <typename ... Args>
133136 pointer data (Args... args);
134137
135138 template <typename ... Args>
136139 const_pointer data (Args... args) const ;
137140
138- bool broadcast_shape (shape_type& shape) const ;
139- bool is_trivial_broadcast (const strides_type& strides) const ;
141+ template <class S >
142+ bool broadcast_shape (S& shape) const ;
143+
144+ template <class S >
145+ bool is_trivial_broadcast (const S& strides) const ;
140146
141147 iterator begin ();
142148 iterator end ();
@@ -175,9 +181,11 @@ namespace xt
175181 private:
176182
177183 template <typename ... Args>
178- auto index_at (Args... args) const -> size_type;
184+ size_type index_at (Args... args) const ;
185+
186+ size_type data_offset (const xindex& index) const ;
179187
180- static constexpr auto itemsize () -> size_type ;
188+ static constexpr size_type itemsize ();
181189
182190 static bool is_non_null (PyObject* ptr);
183191
@@ -340,6 +348,18 @@ namespace xt
340348 return *(static_cast <const_pointer>(pybind_array::data ()) + pybind_array::byte_offset (args...) / itemsize ());
341349 }
342350
351+ template <class T , int ExtraFlags>
352+ inline auto pyarray<T, ExtraFlags>::operator [](const xindex& index) -> reference
353+ {
354+ return *(static_cast <pointer>(pybind_array::mutable_data ()) + data_offset (index));
355+ }
356+
357+ template <class T , int ExtraFlags>
358+ inline auto pyarray<T, ExtraFlags>::operator [](const xindex& index) const -> const_reference
359+ {
360+ return *(static_cast <const_pointer>(pybind_array::data ()) + data_offset (index));
361+ }
362+
343363 template <class T , int ExtraFlags>
344364 template <typename ... Args>
345365 inline auto pyarray<T, ExtraFlags>::data(Args... args) -> pointer
@@ -355,13 +375,15 @@ namespace xt
355375 }
356376
357377 template <class T , int ExtraFlags>
358- bool pyarray<T, ExtraFlags>::broadcast_shape(shape_type& shape) const
378+ template <class S >
379+ bool pyarray<T, ExtraFlags>::broadcast_shape(S& shape) const
359380 {
360381 return xt::broadcast_shape (this ->shape (), shape);
361382 }
362383
363384 template <class T , int ExtraFlags>
364- bool pyarray<T, ExtraFlags>::is_trivial_broadcast(const strides_type& strides) const
385+ template <class S >
386+ bool pyarray<T, ExtraFlags>::is_trivial_broadcast(const S& strides) const
365387 {
366388 return strides.size () == dimension () &&
367389 std::equal (strides.begin (), strides.end (), this ->strides ().begin ());
@@ -515,6 +537,18 @@ namespace xt
515537 return pybind_array::byte_offset (args...) / itemsize ();
516538 }
517539
540+ template <class T , int ExtraFlags>
541+ inline auto pyarray<T, ExtraFlags>::data_offset(const xindex& index) const -> size_type
542+ {
543+ const strides_type& str = strides ();
544+ auto iter = index.begin ();
545+ iter += index.size () - str.size ();
546+ return std::inner_product (str.begin (), str.end (), iter, size_type (0 )) / itemsize ();
547+ }
548+
549+ template <class T , int ExtraFlags>
550+ inline auto pyarray<T, ExtraFlags>::data
551+
518552 template <class T , int ExtraFlags>
519553 constexpr auto pyarray<T, ExtraFlags>::itemsize() -> size_type
520554 {
0 commit comments