Skip to content

Commit 98de977

Browse files
committed
broadcast and operator[] implementation
1 parent 5df3e05 commit 98de977

File tree

1 file changed

+40
-6
lines changed

1 file changed

+40
-6
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,20 @@ namespace xt
129129
template<typename... Args>
130130
const_reference operator()(Args... args) const;
131131

132+
reference operator[](const xindex& index);
133+
const_reference operator[](const xindex& index) const;
134+
132135
template<typename... Args>
133136
pointer data(Args... args);
134137

135138
template<typename... Args>
136139
const_pointer data(Args... args) const;
137140

138-
bool broadcast_shape(shape_type& shape) const;
139-
bool is_trivial_broadcast(const strides_type& strides) const;
141+
template <class S>
142+
bool broadcast_shape(S& shape) const;
143+
144+
template <class S>
145+
bool is_trivial_broadcast(const S& strides) const;
140146

141147
iterator begin();
142148
iterator end();
@@ -175,9 +181,11 @@ namespace xt
175181
private:
176182

177183
template<typename... Args>
178-
auto index_at(Args... args) const -> size_type;
184+
size_type index_at(Args... args) const;
185+
186+
size_type data_offset(const xindex& index) const;
179187

180-
static constexpr auto itemsize() -> size_type;
188+
static constexpr size_type itemsize();
181189

182190
static bool is_non_null(PyObject* ptr);
183191

@@ -340,6 +348,18 @@ namespace xt
340348
return *(static_cast<const_pointer>(pybind_array::data()) + pybind_array::byte_offset(args...) / itemsize());
341349
}
342350

351+
template <class T, int ExtraFlags>
352+
inline auto pyarray<T, ExtraFlags>::operator[](const xindex& index) -> reference
353+
{
354+
return *(static_cast<pointer>(pybind_array::mutable_data()) + data_offset(index));
355+
}
356+
357+
template <class T, int ExtraFlags>
358+
inline auto pyarray<T, ExtraFlags>::operator[](const xindex& index) const -> const_reference
359+
{
360+
return *(static_cast<const_pointer>(pybind_array::data()) + data_offset(index));
361+
}
362+
343363
template <class T, int ExtraFlags>
344364
template<typename... Args>
345365
inline auto pyarray<T, ExtraFlags>::data(Args... args) -> pointer
@@ -355,13 +375,15 @@ namespace xt
355375
}
356376

357377
template <class T, int ExtraFlags>
358-
bool pyarray<T, ExtraFlags>::broadcast_shape(shape_type& shape) const
378+
template <class S>
379+
bool pyarray<T, ExtraFlags>::broadcast_shape(S& shape) const
359380
{
360381
return xt::broadcast_shape(this->shape(), shape);
361382
}
362383

363384
template <class T, int ExtraFlags>
364-
bool pyarray<T, ExtraFlags>::is_trivial_broadcast(const strides_type& strides) const
385+
template <class S>
386+
bool pyarray<T, ExtraFlags>::is_trivial_broadcast(const S& strides) const
365387
{
366388
return strides.size() == dimension() &&
367389
std::equal(strides.begin(), strides.end(), this->strides().begin());
@@ -515,6 +537,18 @@ namespace xt
515537
return pybind_array::byte_offset(args...) / itemsize();
516538
}
517539

540+
template <class T, int ExtraFlags>
541+
inline auto pyarray<T, ExtraFlags>::data_offset(const xindex& index) const -> size_type
542+
{
543+
const strides_type& str = strides();
544+
auto iter = index.begin();
545+
iter += index.size() - str.size();
546+
return std::inner_product(str.begin(), str.end(), iter, size_type(0)) / itemsize();
547+
}
548+
549+
template <class T, int ExtraFlags>
550+
inline auto pyarray<T, ExtraFlags>::data
551+
518552
template <class T, int ExtraFlags>
519553
constexpr auto pyarray<T, ExtraFlags>::itemsize() -> size_type
520554
{

0 commit comments

Comments
 (0)