Skip to content

Commit b0ac9d1

Browse files
wolfvJohanMabille
authored andcommitted
add test for complex overload (#87)
* add test for complex overload * add non failing complex overload test * fix complex overload loading
1 parent de8f064 commit b0ac9d1

File tree

4 files changed

+68
-2
lines changed

4 files changed

+68
-2
lines changed

include/xtensor-python/pyarray.hpp

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,20 @@ namespace pybind11
4242
{
4343
using type = xt::pyarray<T>;
4444

45-
bool load(handle src, bool)
45+
bool load(handle src, bool convert)
4646
{
47+
if (!convert)
48+
{
49+
if (!PyArray_Check(src.ptr()))
50+
{
51+
return false;
52+
}
53+
int type_num = xt::detail::numpy_traits<T>::type_num;
54+
if (PyArray_TYPE(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
55+
{
56+
return false;
57+
}
58+
}
4759
value = type::ensure(src);
4860
return static_cast<bool>(value);
4961
}

include/xtensor-python/pytensor.hpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,21 @@ namespace pybind11
4343
{
4444
using type = xt::pytensor<T, N>;
4545

46-
bool load(handle src, bool)
46+
bool load(handle src, bool convert)
4747
{
48+
if (!convert)
49+
{
50+
if (!PyArray_Check(src.ptr()))
51+
{
52+
return false;
53+
}
54+
int type_num = xt::detail::numpy_traits<T>::type_num;
55+
if (PyArray_TYPE(reinterpret_cast<PyArrayObject*>(src.ptr())) != type_num)
56+
{
57+
return false;
58+
}
59+
}
60+
4861
value = type::ensure(src);
4962
return static_cast<bool>(value);
5063
}

test_python/main.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,25 @@ double readme_example2(double i, double j)
4242
return std::sin(i) - std::cos(j);
4343
}
4444

45+
auto complex_overload(const xt::pyarray<std::complex<double>>& a)
46+
{
47+
return a;
48+
}
49+
auto no_complex_overload(const xt::pyarray<double>& a)
50+
{
51+
return a;
52+
}
53+
54+
auto complex_overload_reg(const std::complex<double>& a)
55+
{
56+
return a;
57+
}
58+
59+
auto no_complex_overload_reg(const double& a)
60+
{
61+
return a;
62+
}
63+
4564
// Vectorize Examples
4665

4766
int add(int i, int j)
@@ -58,6 +77,11 @@ PYBIND11_PLUGIN(xtensor_python_test)
5877
m.def("example1", example1);
5978
m.def("example2", example2);
6079

80+
m.def("complex_overload", no_complex_overload);
81+
m.def("complex_overload", complex_overload);
82+
m.def("complex_overload_reg", no_complex_overload_reg);
83+
m.def("complex_overload_reg", complex_overload_reg);
84+
6185
m.def("readme_example1", readme_example1);
6286
m.def("readme_example2", xt::pyvectorize(readme_example2));
6387

test_python/test_pyarray.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,23 @@ def test_readme_example1(self):
3636
y = xt.readme_example1(v)
3737
np.testing.assert_allclose(y, 1.2853996391883833, 1e-12)
3838

39+
def test_complex_overload_reg(self):
40+
a = 23.23
41+
c = 2.0 + 3.1j
42+
self.assertEqual(xt.complex_overload_reg(a), a)
43+
self.assertEqual(xt.complex_overload_reg(c), c)
44+
45+
def test_complex_overload(self):
46+
a = np.random.rand(3, 3)
47+
b = np.random.rand(3, 3)
48+
c = a + b * 1j
49+
y = xt.complex_overload(c)
50+
np.testing.assert_allclose(np.imag(y), np.imag(c))
51+
np.testing.assert_allclose(np.real(y), np.real(c))
52+
x = xt.complex_overload(b)
53+
self.assertEqual(x.dtype, b.dtype)
54+
np.testing.assert_allclose(x, b)
55+
3956
def test_readme_example2(self):
4057
x = np.arange(15).reshape(3, 5)
4158
y = [1, 2, 3, 4, 5]

0 commit comments

Comments
 (0)