1414#include < vector>
1515
1616#include " pybind11/numpy.h"
17- #include " pybind11_backport.hpp"
1817
1918#include " xtensor/xexpression.hpp"
2019#include " xtensor/xsemantic.hpp"
2120#include " xtensor/xiterator.hpp"
2221
2322namespace xt
2423{
24+ template <class T , int ExtraFlags>
25+ class pyarray ;
26+ }
2527
26- using pybind_array = pybind11::backport::array;
27- using buffer_info = pybind11::buffer_info;
28+ namespace pybind11
29+ {
30+ namespace detail
31+ {
32+ template <typename T, int ExtraFlags>
33+ struct pyobject_caster <xt::pyarray<T, ExtraFlags>>
34+ {
35+ using type = xt::pyarray<T, ExtraFlags>;
36+
37+ bool load (handle src, bool )
38+ {
39+ value = type::ensure (src);
40+ return static_cast <bool >(value);
41+ }
42+
43+ static handle cast (const handle &src, return_value_policy, handle)
44+ {
45+ return src.inc_ref ();
46+ }
47+
48+ PYBIND11_TYPE_CASTER (type, handle_type_name<type>::name());
49+ };
50+ }
51+ }
52+
53+ namespace xt
54+ {
55+
56+ using pybind_array = pybind11::array;
2857
2958 /* **********************
3059 * pyarray declaration *
@@ -95,11 +124,11 @@ namespace xt
95124
96125 using closure_type = const self_type&;
97126
98- PYBIND11_OBJECT_CVT (pyarray, pybind_array, is_non_null, m_ptr = ensure_(m_ptr));
99-
100127 pyarray ();
101128
102- explicit pyarray (const buffer_info& info);
129+ pyarray (pybind11::handle h, borrowed_t );
130+ pyarray (pybind11::handle h, stolen_t );
131+ pyarray (const pybind11::object &o);
103132
104133 pyarray (const shape_type& shape,
105134 const strides_type& strides,
@@ -188,6 +217,9 @@ namespace xt
188217 template <class E >
189218 pyarray& operator =(const xexpression<E>& e);
190219
220+ static pyarray ensure (pybind11::handle h);
221+ static bool _check (pybind11::handle h);
222+
191223 private:
192224
193225 template <typename ... Args>
@@ -199,11 +231,10 @@ namespace xt
199231
200232 static bool is_non_null (PyObject* ptr);
201233
202- static PyObject *ensure_ (PyObject* ptr);
203-
204234 mutable shape_type m_shape;
205235 mutable strides_type m_strides;
206236
237+ static PyObject* raw_array_t (PyObject* ptr);
207238 };
208239
209240 /* *************************************
@@ -230,16 +261,29 @@ namespace xt
230261
231262 template <class T , int ExtraFlags>
232263 inline pyarray<T, ExtraFlags>::pyarray()
233- : pybind_array()
264+ : pybind_array(0 , static_cast <const_pointer>(nullptr ))
265+ {
266+ }
267+
268+ template <class T , int ExtraFlags>
269+ inline pyarray<T, ExtraFlags>::pyarray(pybind11::handle h, borrowed_t ) : pybind_array(h, borrowed)
234270 {
235271 }
236272
237273 template <class T , int ExtraFlags>
238- inline pyarray<T, ExtraFlags>::pyarray(const buffer_info& info)
239- : pybind_array(info)
274+ inline pyarray<T, ExtraFlags>::pyarray(pybind11::handle h, stolen_t ) : pybind_array(h, stolen)
240275 {
241276 }
242277
278+ template <class T , int ExtraFlags>
279+ inline pyarray<T, ExtraFlags>::pyarray(const pybind11::object &o) : pybind_array(raw_array_t (o.ptr()), stolen)
280+ {
281+ if (!m_ptr)
282+ {
283+ throw pybind11::error_already_set ();
284+ }
285+ }
286+
243287 template <class T , int ExtraFlags>
244288 inline pyarray<T, ExtraFlags>::pyarray(const shape_type& shape,
245289 const strides_type& strides,
@@ -512,7 +556,7 @@ namespace xt
512556 template <class T , int ExtraFlags>
513557 inline auto pyarray<T, ExtraFlags>::storage_begin() -> storage_iterator
514558 {
515- return reinterpret_cast <storage_iterator>(pybind11::backport ::array_proxy (m_ptr)->data );
559+ return reinterpret_cast <storage_iterator>(pybind11::detail ::array_proxy (m_ptr)->data );
516560 }
517561
518562 template <class T , int ExtraFlags>
@@ -524,7 +568,7 @@ namespace xt
524568 template <class T , int ExtraFlags>
525569 inline auto pyarray<T, ExtraFlags>::storage_begin() const -> const_storage_iterator
526570 {
527- return reinterpret_cast <const_storage_iterator>(pybind11::backport ::array_proxy (m_ptr)->data );
571+ return reinterpret_cast <const_storage_iterator>(pybind11::detail ::array_proxy (m_ptr)->data );
528572 }
529573
530574 template <class T , int ExtraFlags>
@@ -536,7 +580,7 @@ namespace xt
536580 template <class T , int ExtraFlags>
537581 inline auto pyarray<T, ExtraFlags>::storage_cbegin() const -> const_storage_iterator
538582 {
539- return reinterpret_cast <const_storage_iterator>(pybind11::backport ::array_proxy (m_ptr)->data );
583+ return reinterpret_cast <const_storage_iterator>(pybind11::detail ::array_proxy (m_ptr)->data );
540584 }
541585
542586 template <class T , int ExtraFlags>
@@ -560,6 +604,25 @@ namespace xt
560604 return semantic_base::operator =(e);
561605 }
562606
607+ template <class T , int ExtraFlags>
608+ inline pyarray<T, ExtraFlags> pyarray<T, ExtraFlags>::ensure(pybind11::handle h)
609+ {
610+ auto result = pybind11::reinterpret_steal<pyarray>(raw_array_t (h.ptr ()));
611+ if (!pybind11::handle (result))
612+ {
613+ PyErr_Clear ();
614+ }
615+ return result;
616+ }
617+
618+ template <class T , int ExtraFlags>
619+ inline bool pyarray<T, ExtraFlags>::_check(pybind11::handle h)
620+ {
621+ const auto &api = pybind11::detail::npy_api::get ();
622+ return api.PyArray_Check_ (h.ptr ())
623+ && api.PyArray_EquivTypes_ (pybind11::detail::array_proxy (h.ptr ())->descr , pybind11::dtype::of<T>().ptr ());
624+ }
625+
563626 // Private methods
564627
565628 template <class T , int ExtraFlags>
@@ -591,23 +654,17 @@ namespace xt
591654 }
592655
593656 template <class T , int ExtraFlags>
594- inline PyObject* pyarray<T, ExtraFlags>::ensure_ (PyObject* ptr)
657+ inline PyObject* pyarray<T, ExtraFlags>::raw_array_t (PyObject* ptr)
595658 {
596659 if (ptr == nullptr )
597660 {
598661 return nullptr ;
599662 }
600- API& api = lookup_api ();
601- PyObject* descr = api.PyArray_DescrFromType_ (pybind11::detail::npy_format_descriptor<T>::value);
602- PyObject* result = api.PyArray_FromAny_ (ptr, descr, 0 , 0 , API::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr );
603- if (!result)
604- {
605- PyErr_Clear ();
606- }
607- Py_DECREF (ptr);
608- return result;
663+ return pybind11::detail::npy_api::get ().PyArray_FromAny_ (
664+ ptr, pybind11::dtype::of<T>().release ().ptr (), 0 , 0 ,
665+ pybind11::detail::npy_api::NPY_ENSURE_ARRAY_ | ExtraFlags, nullptr
666+ );
609667 }
610-
611668}
612669
613670#endif
0 commit comments