Skip to content

Commit 7a39162

Browse files
committed
complex arrays fixed
1 parent a6e0512 commit 7a39162

File tree

6 files changed

+20
-3
lines changed

6 files changed

+20
-3
lines changed

benchmark/main.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
#include "xtensor-python/pytensor.hpp"
99
#include "xtensor-python/pyvectorize.hpp"
1010

11-
#include <complex>
1211
using complex_t = std::complex<double>;
1312

1413
namespace py = pybind11;

include/xtensor-python/pybuffer_adaptor.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,12 @@ namespace xt
280280

281281
inline self_type operator+(difference_type n) const { return self_type(p_current + n); }
282282
inline self_type operator-(difference_type n) const { return self_type(p_current - n); }
283+
inline self_type operator-(const self_type& rhs) const
284+
{
285+
self_type tmp(*this);
286+
tmp -= (p_current - rhs.p_current);
287+
return tmp;
288+
}
283289

284290
pointer get_pointer() const { return p_current; }
285291

include/xtensor-python/pyvectorize.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace xt
2626
{
2727
}
2828

29-
pybind11::object operator()(const pyarray<Args>&... args)
29+
inline pyarray<R> operator()(const pyarray<Args>&... args)
3030
{
3131
pyarray<R> res = m_vectorizer(args...);
3232
return res;

test/main.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <numeric>
88

99
namespace py = pybind11;
10+
using complex_t = std::complex<double>;
1011

1112
// Examples
1213

@@ -58,5 +59,10 @@ PYBIND11_PLUGIN(xtensor_python_test)
5859

5960
m.def("vectorize_example1", xt::pyvectorize(add), "");
6061

62+
m.def("rect_to_polar", [](xt::pyarray<complex_t> const& a) {
63+
return py::make_tuple(xt::pyvectorize([](complex_t x) { return std::abs(x); })(a),
64+
xt::pyvectorize([](complex_t x) { return std::arg(x); })(a));
65+
});
66+
6167
return m.ptr();
6268
}

test/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def build_extensions(self):
107107
description='An example project using xtensor-python',
108108
long_description='',
109109
ext_modules=ext_modules,
110-
install_requires=['pybind11==2.0.1'],
110+
install_requires=['pybind11>=2.0.1'],
111111
cmdclass={'build_ext': BuildExt},
112112
zip_safe=False,
113113
)

test/test_pyarray.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,9 @@ def test_readme_example2(self):
5050
[-1.499227, 0.136731, 1.646979, 1.643002, 0.128456],
5151
[-1.084323, -0.583843, 0.45342 , 1.073811, 0.706945]], 1e-5)
5252

53+
def test_rect_to_polar(self):
54+
print("test6")
55+
x = np.ones(10, dtype=complex)
56+
z = xt.rect_to_polar(x[::2]);
57+
np.testing.assert_allclose(z, (np.ones(5, dtype=float), np.zeros(5, dtype=float)), 1e-5)
58+

0 commit comments

Comments
 (0)