File tree Expand file tree Collapse file tree 4 files changed +68
-2
lines changed
Expand file tree Collapse file tree 4 files changed +68
-2
lines changed Original file line number Diff line number Diff line change @@ -42,8 +42,20 @@ namespace pybind11
4242 {
4343 using type = xt::pyarray<T>;
4444
45- bool load (handle src, bool )
45+ bool load (handle src, bool convert )
4646 {
47+ if (!convert)
48+ {
49+ if (!PyArray_Check (src.ptr ()))
50+ {
51+ return false ;
52+ }
53+ int type_num = xt::detail::numpy_traits<T>::type_num;
54+ if (PyArray_TYPE (reinterpret_cast <PyArrayObject*>(src.ptr ())) != type_num)
55+ {
56+ return false ;
57+ }
58+ }
4759 value = type::ensure (src);
4860 return static_cast <bool >(value);
4961 }
Original file line number Diff line number Diff line change @@ -43,8 +43,21 @@ namespace pybind11
4343 {
4444 using type = xt::pytensor<T, N>;
4545
46- bool load (handle src, bool )
46+ bool load (handle src, bool convert )
4747 {
48+ if (!convert)
49+ {
50+ if (!PyArray_Check (src.ptr ()))
51+ {
52+ return false ;
53+ }
54+ int type_num = xt::detail::numpy_traits<T>::type_num;
55+ if (PyArray_TYPE (reinterpret_cast <PyArrayObject*>(src.ptr ())) != type_num)
56+ {
57+ return false ;
58+ }
59+ }
60+
4861 value = type::ensure (src);
4962 return static_cast <bool >(value);
5063 }
Original file line number Diff line number Diff line change @@ -42,6 +42,25 @@ double readme_example2(double i, double j)
4242 return std::sin (i) - std::cos (j);
4343}
4444
45+ auto complex_overload (const xt::pyarray<std::complex <double >>& a)
46+ {
47+ return a;
48+ }
49+ auto no_complex_overload (const xt::pyarray<double >& a)
50+ {
51+ return a;
52+ }
53+
54+ auto complex_overload_reg (const std::complex <double >& a)
55+ {
56+ return a;
57+ }
58+
59+ auto no_complex_overload_reg (const double & a)
60+ {
61+ return a;
62+ }
63+
4564// Vectorize Examples
4665
4766int add (int i, int j)
@@ -58,6 +77,11 @@ PYBIND11_PLUGIN(xtensor_python_test)
5877 m.def (" example1" , example1);
5978 m.def (" example2" , example2);
6079
80+ m.def (" complex_overload" , no_complex_overload);
81+ m.def (" complex_overload" , complex_overload);
82+ m.def (" complex_overload_reg" , no_complex_overload_reg);
83+ m.def (" complex_overload_reg" , complex_overload_reg);
84+
6185 m.def (" readme_example1" , readme_example1);
6286 m.def (" readme_example2" , xt::pyvectorize (readme_example2));
6387
Original file line number Diff line number Diff line change @@ -36,6 +36,23 @@ def test_readme_example1(self):
3636 y = xt .readme_example1 (v )
3737 np .testing .assert_allclose (y , 1.2853996391883833 , 1e-12 )
3838
39+ def test_complex_overload_reg (self ):
40+ a = 23.23
41+ c = 2.0 + 3.1j
42+ self .assertEqual (xt .complex_overload_reg (a ), a )
43+ self .assertEqual (xt .complex_overload_reg (c ), c )
44+
45+ def test_complex_overload (self ):
46+ a = np .random .rand (3 , 3 )
47+ b = np .random .rand (3 , 3 )
48+ c = a + b * 1j
49+ y = xt .complex_overload (c )
50+ np .testing .assert_allclose (np .imag (y ), np .imag (c ))
51+ np .testing .assert_allclose (np .real (y ), np .real (c ))
52+ x = xt .complex_overload (b )
53+ self .assertEqual (x .dtype , b .dtype )
54+ np .testing .assert_allclose (x , b )
55+
3956 def test_readme_example2 (self ):
4057 x = np .arange (15 ).reshape (3 , 5 )
4158 y = [1 , 2 , 3 , 4 , 5 ]
You can’t perform that action at this time.
0 commit comments