Skip to content

Commit 96524b0

Browse files
authored
Merge pull request #6 from SylvainCorlay/dynamic-access
Dynamic access operator
2 parents 5df3e05 + 8534ce7 commit 96524b0

File tree

6 files changed

+51
-64
lines changed

6 files changed

+51
-64
lines changed

.appveyor.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ install:
2424
- conda info -a
2525
- conda install pytest -c conda-forge
2626
- cd test
27-
- conda install xtensor==0.1.1 pytest numpy pybind11==1.8.1 -c conda-forge
27+
- conda install xtensor==0.2.0 pytest numpy pybind11==1.8.1 -c conda-forge
2828
- xcopy /S %APPVEYOR_BUILD_FOLDER%\include %MINICONDA%\include
2929

3030
build_script:

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ install:
5959
# Useful for debugging any issues with conda
6060
- conda info -a
6161
- cd test
62-
- conda install xtensor==0.1.1 pytest numpy pybind11==1.8.1 -c conda-forge
62+
- conda install xtensor==0.2.0 pytest numpy pybind11==1.8.1 -c conda-forge
6363
- cp -r $TRAVIS_BUILD_DIR/include/* $HOME/miniconda/include/
6464

6565
script:

conda.recipe/bld.bat

Lines changed: 0 additions & 2 deletions
This file was deleted.

conda.recipe/build.sh

Lines changed: 0 additions & 2 deletions
This file was deleted.

conda.recipe/meta.yaml

Lines changed: 0 additions & 41 deletions
This file was deleted.

include/xtensor-python/pyarray.hpp

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include <cstddef>
1313
#include <algorithm>
14+
#include <vector>
1415

1516
#include "pybind11/numpy.h"
1617
#include "pybind11_backport.hpp"
@@ -33,7 +34,7 @@ namespace xt
3334
class pyarray;
3435

3536
template <class T, int ExtraFlags>
36-
struct array_inner_types<pyarray<T, ExtraFlags>>
37+
struct xcontainer_inner_types<pyarray<T, ExtraFlags>>
3738
{
3839
using temporary_type = pyarray<T, ExtraFlags>;
3940
};
@@ -63,14 +64,14 @@ namespace xt
6364
*/
6465
template <class T, int ExtraFlags = pybind_array::forcecast>
6566
class pyarray : public pybind_array,
66-
public xarray_semantic<pyarray<T, ExtraFlags>>
67+
public xcontainer_semantic<pyarray<T, ExtraFlags>>
6768
{
6869

6970
public:
7071

7172
using self_type = pyarray<T, ExtraFlags>;
7273
using base_type = pybind_array;
73-
using semantic_base = xarray_semantic<self_type>;
74+
using semantic_base = xcontainer_semantic<self_type>;
7475
using value_type = T;
7576
using reference = T&;
7677
using const_reference = const T&;
@@ -89,8 +90,8 @@ namespace xt
8990
using storage_iterator = T*;
9091
using const_storage_iterator = const T*;
9192

92-
using shape_type = xshape<size_type>;
93-
using strides_type = xstrides<size_type>;
93+
using shape_type = std::vector<size_type>;
94+
using strides_type = std::vector<size_type>;
9495
using backstrides_type = pyarray_backstrides<self_type>;
9596

9697
using closure_type = const self_type&;
@@ -101,12 +102,12 @@ namespace xt
101102

102103
explicit pyarray(const buffer_info& info);
103104

104-
pyarray(const xshape<size_type>& shape,
105-
const xstrides<size_type>& strides,
105+
pyarray(const shape_type& shape,
106+
const strides_type& strides,
106107
const T* ptr = nullptr,
107108
handle base = handle());
108109

109-
explicit pyarray(const xshape<size_type>& shape,
110+
explicit pyarray(const shape_type& shape,
110111
const T* ptr = nullptr,
111112
handle base = handle());
112113

@@ -129,14 +130,20 @@ namespace xt
129130
template<typename... Args>
130131
const_reference operator()(Args... args) const;
131132

133+
reference operator[](const xindex& index);
134+
const_reference operator[](const xindex& index) const;
135+
132136
template<typename... Args>
133137
pointer data(Args... args);
134138

135139
template<typename... Args>
136140
const_pointer data(Args... args) const;
137141

138-
bool broadcast_shape(shape_type& shape) const;
139-
bool is_trivial_broadcast(const strides_type& strides) const;
142+
template <class S>
143+
bool broadcast_shape(S& shape) const;
144+
145+
template <class S>
146+
bool is_trivial_broadcast(const S& strides) const;
140147

141148
iterator begin();
142149
iterator end();
@@ -175,9 +182,11 @@ namespace xt
175182
private:
176183

177184
template<typename... Args>
178-
auto index_at(Args... args) const -> size_type;
185+
size_type index_at(Args... args) const;
179186

180-
static constexpr auto itemsize() -> size_type;
187+
size_type data_offset(const xindex& index) const;
188+
189+
static constexpr size_type itemsize();
181190

182191
static bool is_non_null(PyObject* ptr);
183192

@@ -223,16 +232,16 @@ namespace xt
223232
}
224233

225234
template <class T, int ExtraFlags>
226-
inline pyarray<T, ExtraFlags>::pyarray(const xshape<size_type>& shape,
227-
const xstrides<size_type>& strides,
235+
inline pyarray<T, ExtraFlags>::pyarray(const shape_type& shape,
236+
const strides_type& strides,
228237
const T *ptr,
229238
handle base)
230239
: pybind_array(shape, strides, ptr, base)
231240
{
232241
}
233242

234243
template <class T, int ExtraFlags>
235-
inline pyarray<T, ExtraFlags>::pyarray(const xshape<size_type>& shape,
244+
inline pyarray<T, ExtraFlags>::pyarray(const shape_type& shape,
236245
const T* ptr,
237246
handle base)
238247
: pybind_array(shape, ptr, base)
@@ -340,6 +349,18 @@ namespace xt
340349
return *(static_cast<const_pointer>(pybind_array::data()) + pybind_array::byte_offset(args...) / itemsize());
341350
}
342351

352+
template <class T, int ExtraFlags>
353+
inline auto pyarray<T, ExtraFlags>::operator[](const xindex& index) -> reference
354+
{
355+
return *(static_cast<pointer>(pybind_array::mutable_data()) + data_offset(index));
356+
}
357+
358+
template <class T, int ExtraFlags>
359+
inline auto pyarray<T, ExtraFlags>::operator[](const xindex& index) const -> const_reference
360+
{
361+
return *(static_cast<const_pointer>(pybind_array::data()) + data_offset(index));
362+
}
363+
343364
template <class T, int ExtraFlags>
344365
template<typename... Args>
345366
inline auto pyarray<T, ExtraFlags>::data(Args... args) -> pointer
@@ -355,13 +376,15 @@ namespace xt
355376
}
356377

357378
template <class T, int ExtraFlags>
358-
bool pyarray<T, ExtraFlags>::broadcast_shape(shape_type& shape) const
379+
template <class S>
380+
bool pyarray<T, ExtraFlags>::broadcast_shape(S& shape) const
359381
{
360382
return xt::broadcast_shape(this->shape(), shape);
361383
}
362384

363385
template <class T, int ExtraFlags>
364-
bool pyarray<T, ExtraFlags>::is_trivial_broadcast(const strides_type& strides) const
386+
template <class S>
387+
bool pyarray<T, ExtraFlags>::is_trivial_broadcast(const S& strides) const
365388
{
366389
return strides.size() == dimension() &&
367390
std::equal(strides.begin(), strides.end(), this->strides().begin());
@@ -515,6 +538,15 @@ namespace xt
515538
return pybind_array::byte_offset(args...) / itemsize();
516539
}
517540

541+
template <class T, int ExtraFlags>
542+
inline auto pyarray<T, ExtraFlags>::data_offset(const xindex& index) const -> size_type
543+
{
544+
const strides_type& str = strides();
545+
auto iter = index.begin();
546+
iter += index.size() - str.size();
547+
return std::inner_product(str.begin(), str.end(), iter, size_type(0)) / itemsize();
548+
}
549+
518550
template <class T, int ExtraFlags>
519551
constexpr auto pyarray<T, ExtraFlags>::itemsize() -> size_type
520552
{

0 commit comments

Comments
 (0)