Skip to content

Commit a6e0512

Browse files
committed
complex benchmark
1 parent 3285e44 commit a6e0512

File tree

5 files changed

+37
-0
lines changed

5 files changed

+37
-0
lines changed

benchmark/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ endif()
7878
configure_file(benchmark_pyarray.py benchmark_pyarray.py COPYONLY)
7979
configure_file(benchmark_pytensor.py benchmark_pytensor.py COPYONLY)
8080
configure_file(benchmark_pybind_array.py benchmark_pybind_array.py COPYONLY)
81+
configure_file(benchmark_pyvectorize.py benchmark_pyvectorize.py COPYONLY)
82+
configure_file(benchmark_pybind_vectorize.py benchmark_pybind_vectorize.py COPYONLY)
8183

8284
add_custom_target(xbenchmark DEPENDS ${XTENSOR_PYTHON_BENCHMARK_TARGET})
8385

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from benchmark_xtensor_python import pybind_rect_to_polar
2+
import numpy as np
3+
4+
from timeit import timeit
5+
w = np.ones(100000, dtype=complex)
6+
print (timeit('pybind_rect_to_polar(w[::2])', 'from __main__ import w, pybind_rect_to_polar', number=1000))

benchmark/benchmark_pyvectorize.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from benchmark_xtensor_python import rect_to_polar
2+
import numpy as np
3+
4+
from timeit import timeit
5+
w = np.ones(100000, dtype=complex)
6+
print (timeit('rect_to_polar(w[::2])', 'from __main__ import w, rect_to_polar', number=1000))

benchmark/main.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@
66
#include "xtensor/xarray.hpp"
77
#include "xtensor-python/pyarray.hpp"
88
#include "xtensor-python/pytensor.hpp"
9+
#include "xtensor-python/pyvectorize.hpp"
910

1011
#include <complex>
12+
using complex_t = std::complex<double>;
1113

1214
namespace py = pybind11;
1315

@@ -47,5 +49,18 @@ PYBIND11_PLUGIN(benchmark_xtensor_python)
4749
}
4850
);
4951

52+
m.def("rect_to_polar", [](xt::pyarray<complex_t> const& a) {
53+
return py::make_tuple(xt::pyvectorize([](complex_t x) { return std::abs(x); })(a),
54+
xt::pyvectorize([](complex_t x) { return std::arg(x); })(a));
55+
});
56+
57+
m.def("pybind_rect_to_polar", [](py::array a) {
58+
if (py::isinstance<py::array_t<complex_t>>(a))
59+
return py::make_tuple(py::vectorize([](complex_t x) { return std::abs(x); })(a),
60+
py::vectorize([](complex_t x) { return std::arg(x); })(a));
61+
else
62+
throw py::type_error("rect_to_polar unhandled type");
63+
});
64+
5065
return m.ptr();
5166
}

include/xtensor-python/pycontainer.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <cmath>
1515
#include "pybind11/pybind11.h"
1616
#include "pybind11/common.h"
17+
#include "pybind11/complex.h"
1718
// Because of layout, else xiterator and xtensor_forward are sufficient
1819
#include "xtensor/xcontainer.hpp"
1920

@@ -188,6 +189,13 @@ namespace xt
188189
std::is_same<T, double>::value ? 1 : std::is_same<T, long double>::value ? 2 : 0));
189190
};
190191

192+
template <class T>
193+
struct is_fmt_numeric<std::complex<T>>
194+
{
195+
static constexpr bool value = true;
196+
static constexpr int index = is_fmt_numeric<T>::index + 3;
197+
};
198+
191199
template <class T>
192200
struct numpy_traits
193201
{

0 commit comments

Comments
 (0)