Skip to content

Commit 5df3e05

Browse files
Merge pull request #7 from SylvainCorlay/fast-pybind
Simplify byte_offset
2 parents cfb000e + 007a721 commit 5df3e05

File tree

2 files changed

+10
-51
lines changed

2 files changed

+10
-51
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -330,24 +330,14 @@ namespace xt
330330
template<typename... Args>
331331
inline auto pyarray<T, ExtraFlags>::operator()(Args... args) -> reference
332332
{
333-
if (sizeof...(args) != dimension())
334-
{
335-
pybind_array::fail_dim_check(sizeof...(args), "index dimension mismatch");
336-
}
337-
// not using pybind_array::offset_at() / index_at() here so as to avoid another dimension check.
338-
return *(static_cast<pointer>(pybind_array::mutable_data()) + pybind_array::get_byte_offset(args...) / itemsize());
333+
return *(static_cast<pointer>(pybind_array::mutable_data()) + pybind_array::byte_offset(args...) / itemsize());
339334
}
340335

341336
template <class T, int ExtraFlags>
342337
template<typename... Args>
343338
inline auto pyarray<T, ExtraFlags>::operator()(Args... args) const -> const_reference
344339
{
345-
if (sizeof...(args) != dimension())
346-
{
347-
pybind_array::fail_dim_check(sizeof...(args), "index dimension mismatch");
348-
}
349-
// not using pybind_array::offset_at() / index_at() here so as to avoid another dimension check.
350-
return *(static_cast<const_pointer>(pybind_array::data()) + pybind_array::get_byte_offset(args...) / itemsize());
340+
return *(static_cast<const_pointer>(pybind_array::data()) + pybind_array::byte_offset(args...) / itemsize());
351341
}
352342

353343
template <class T, int ExtraFlags>
@@ -522,7 +512,7 @@ namespace xt
522512
template<typename... Args>
523513
inline auto pyarray<T, ExtraFlags>::index_at(Args... args) const -> size_type
524514
{
525-
return pybind_array::offset_at(args...) / itemsize();
515+
return pybind_array::byte_offset(args...) / itemsize();
526516
}
527517

528518
template <class T, int ExtraFlags>

include/xtensor-python/pybind11_backport.hpp

Lines changed: 7 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ namespace pybind11
9191

9292
size_type size() const
9393
{
94-
return std::accumulate(shape(), shape() + ndim(), size_type{1}, std::multiplies<size_type>());
94+
return std::accumulate(shape(), shape() + ndim(), size_type(1), std::multiplies<size_type>());
9595
}
9696

9797
size_type itemsize() const
@@ -114,61 +114,30 @@ namespace pybind11
114114
return reinterpret_cast<const size_type*>(PyArray_GET_(m_ptr, strides));
115115
}
116116

117-
template<typename... Ix>
118117
void* data()
119118
{
120119
return static_cast<void*>(PyArray_GET_(m_ptr, data));
121120
}
122121

123-
template<typename... Ix>
124122
void* mutable_data()
125123
{
126-
// check_writeable();
127124
return static_cast<void *>(PyArray_GET_(m_ptr, data));
128125
}
129126

130-
template<typename... Ix>
131-
size_type offset_at(Ix... index) const
132-
{
133-
if (sizeof...(index) > ndim())
134-
{
135-
fail_dim_check(sizeof...(index), "too many indices for an array");
136-
}
137-
return get_byte_offset(index...);
138-
}
127+
protected:
139128

140-
size_type offset_at() const
129+
template<size_t dim = 0>
130+
inline size_type byte_offset() const
141131
{
142132
return 0;
143133
}
144134

145-
protected:
146-
147-
void fail_dim_check(size_type dim, const std::string& msg) const
135+
template <size_t dim = 0, class... Args>
136+
inline size_type byte_offset(size_type i, Args... args) const
148137
{
149-
throw index_error(msg + ": " + std::to_string(dim) +
150-
" (ndim = " + std::to_string(ndim()) + ")");
138+
return i * strides()[dim] + byte_offset<dim + 1>(args...);
151139
}
152140

153-
template<typename... Ix>
154-
size_type get_byte_offset(Ix... index) const
155-
{
156-
const size_type idx[] = { static_cast<size_type>(index)... };
157-
if (!std::equal(idx + 0, idx + sizeof...(index), shape(), std::less<size_type>{}))
158-
{
159-
auto mismatch = std::mismatch(idx + 0, idx + sizeof...(index), shape(), std::less<size_type>{});
160-
throw index_error(std::string("index ") + std::to_string(*mismatch.first) +
161-
" is out of bounds for axis " + std::to_string(mismatch.first - idx) +
162-
" with size " + std::to_string(*mismatch.second));
163-
}
164-
return std::inner_product(idx + 0, idx + sizeof...(index), strides(), size_type{0});
165-
}
166-
167-
size_type get_byte_offset() const
168-
{
169-
return 0;
170-
}
171-
172141
static std::vector<size_type>
173142
default_strides(const std::vector<size_type>& shape, size_type itemsize)
174143
{

0 commit comments

Comments
 (0)