Skip to content

Commit e84008c

Browse files
committed
vectorize implementation
1 parent aa02e58 commit e84008c

File tree

4 files changed

+94
-13
lines changed

4 files changed

+94
-13
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,8 @@ namespace xt
114114
handle base = handle());
115115

116116
size_type dimension() const;
117-
shape_type shape() const;
118-
strides_type strides() const;
117+
const shape_type& shape() const;
118+
const strides_type& strides() const;
119119
backstrides_type backstrides() const;
120120

121121
void reshape(const shape_type& shape);
@@ -182,6 +182,9 @@ namespace xt
182182

183183
static PyObject *ensure_(PyObject* ptr);
184184

185+
mutable shape_type m_shape;
186+
mutable strides_type m_strides;
187+
185188
};
186189

187190
/**************************************
@@ -198,7 +201,7 @@ namespace xt
198201
inline auto pyarray_backstrides<A>::operator[](size_type i) const -> value_type
199202
{
200203
value_type sh = p_a->shape()[i];
201-
value_type res = sh == 1 ? 0 : sh * p_a->strides()[i] / sizeof(typename A::value_type);
204+
value_type res = sh == 1 ? 0 : (sh - 1) * p_a->strides()[i] / sizeof(typename A::value_type);
202205
return res;
203206
}
204207

@@ -250,21 +253,21 @@ namespace xt
250253
}
251254

252255
template <class T, int ExtraFlags>
253-
inline auto pyarray<T, ExtraFlags>::shape() const -> shape_type
256+
inline auto pyarray<T, ExtraFlags>::shape() const -> const shape_type&
254257
{
255258
// Until we have the CRTP on shape types, we copy the shape.
256-
shape_type shape(dimension());
257-
std::copy(pybind_array::shape(), pybind_array::shape() + dimension(), shape.begin());
258-
return shape;
259+
m_shape.resize(dimension());
260+
std::copy(pybind_array::shape(), pybind_array::shape() + dimension(), m_shape.begin());
261+
return m_shape;
259262
}
260263

261264
template <class T, int ExtraFlags>
262-
inline auto pyarray<T, ExtraFlags>::strides() const -> strides_type
265+
inline auto pyarray<T, ExtraFlags>::strides() const -> const strides_type&
263266
{
264-
strides_type strides(dimension());
265-
std::transform(pybind_array::strides(), pybind_array::strides() + dimension(), strides.begin(),
267+
m_strides.resize(dimension());
268+
std::transform(pybind_array::strides(), pybind_array::strides() + dimension(), m_strides.begin(),
266269
[](size_type str) { return str / sizeof(value_type); });
267-
return strides;
270+
return m_strides;
268271
}
269272

270273
template <class T, int ExtraFlags>
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/***************************************************************************
2+
* Copyright (c) 2016, Johan Mabille and Sylvain Corlay *
3+
* *
4+
* Distributed under the terms of the BSD 3-Clause License. *
5+
* *
6+
* The full license is in the file LICENSE, distributed with this software. *
7+
****************************************************************************/
8+
9+
#ifndef PY_VECTORIZE_HPP
10+
#define PY_VECTORIZE_HPP
11+
12+
#include "pyarray.hpp"
13+
#include "xtensor/xvectorize.hpp"
14+
15+
namespace xt
16+
{
17+
18+
template <class Func, class R, class... Args>
19+
struct pyvectorizer
20+
{
21+
xvectorizer<Func, R> m_vectorizer;
22+
23+
template <class F>
24+
pyvectorizer(F&& func)
25+
: m_vectorizer(std::forward<F>(func))
26+
{
27+
}
28+
29+
pybind11::object operator()(pyarray<Args, pybind_array::c_style | pybind_array::forcecast>... args)
30+
{
31+
pyarray<R> res = m_vectorizer(args...);
32+
return res;
33+
}
34+
};
35+
36+
template <class R, class... Args>
37+
inline pyvectorizer<R(*)(Args...), R, Args...> pyvectorize(R(*f) (Args...))
38+
{
39+
return pyvectorizer<R(*) (Args...), R, Args...>(f);
40+
}
41+
42+
template <class F, class R, class... Args>
43+
inline pyvectorizer<F, R, Args...> pyvectorize(F&& f, R(*) (Args...))
44+
{
45+
return pyvectorizer<F, R, Args...>(std::forward<F>(f));
46+
}
47+
48+
template <class F>
49+
inline auto pyvectorize(F&& f) -> decltype(pyvectorize(std::forward<F>(f), (detail::get_function_type<F>*)nullptr))
50+
{
51+
return pyvectorize(std::forward<F>(f), (detail::get_function_type<F>*)nullptr);
52+
}
53+
}
54+
55+
#endif

test/main.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
#include <pybind11/pybind11.h>
22
#include "xtensor/xarray.hpp"
33
#include "xtensor-python/pyarray.hpp"
4+
#include "xtensor-python/pyvectorize.hpp"
5+
#include <iostream>
46

57
namespace py = pybind11;
68

9+
int add(int i, int j) {
10+
return i + j;
11+
}
12+
713
double test0(xt::pyarray<double> &m)
814
{
915
return m(0);
1016
}
1117

12-
xt::pyarray<double> test1(xt::pyarray<double> &m) {
18+
xt::pyarray<double> test1(xt::pyarray<double> &m)
19+
{
1320
return m + 2;
1421
}
1522

@@ -19,6 +26,7 @@ PYBIND11_PLUGIN(xtensor_python_test)
1926

2027
m.def("test0", test0, "");
2128
m.def("test1", test1, "");
29+
m.def("vec_add", xt::pyvectorize(add), "");
2230

2331
return m.ptr();
2432
}

test/test_pyarray.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,23 @@
1111

1212
from unittest import TestCase
1313
import xtensor_python_test as xt
14+
import numpy as np
1415

1516
class ExampleTest(TestCase):
1617

17-
def test_example(self):
18+
def test_example0(self):
1819
self.assertEqual(4, xt.test0([4, 5, 6]))
20+
21+
def test_example1(self):
22+
x = np.array([[0., 1.], [2., 3.]])
23+
res = np.array([[2., 3.], [4., 5.]])
24+
y = xt.test1(x)
25+
np.testing.assert_allclose(y, res, 1e-12)
26+
27+
def test_vectorize(self):
28+
x1 = np.array([[0, 1], [2, 3]])
29+
x2 = np.array([0, 1])
30+
res = np.array([[0, 2], [2, 4]])
31+
y = xt.vec_add(x1, x2)
32+
np.testing.assert_array_equal(y, res)
33+

0 commit comments

Comments
 (0)