|
15 | 15 | #include "xtensor-python/pyarray.hpp" |
16 | 16 | #include "xtensor-python/pytensor.hpp" |
17 | 17 | #include "xtensor-python/pyvectorize.hpp" |
| 18 | +#include "xtensor/xadapt.hpp" |
| 19 | +#include "xtensor/xstrided_view.hpp" |
18 | 20 |
|
19 | 21 | namespace py = pybind11; |
20 | 22 | using complex_t = std::complex<double>; |
@@ -133,6 +135,49 @@ class C |
133 | 135 | array_type m_array; |
134 | 136 | }; |
135 | 137 |
|
| 138 | +struct test_native_casters |
| 139 | +{ |
| 140 | + using array_type = xt::xarray<double>; |
| 141 | + array_type a = xt::ones<double>({50, 50}); |
| 142 | + |
| 143 | + const auto & get_array() |
| 144 | + { |
| 145 | + return a; |
| 146 | + } |
| 147 | + |
| 148 | + auto get_strided_view() |
| 149 | + { |
| 150 | + return xt::strided_view(a, {xt::range(0, 1), xt::range(0, 3, 2)}); |
| 151 | + } |
| 152 | + |
| 153 | + auto get_array_adapter() |
| 154 | + { |
| 155 | + using shape_type = std::vector<size_t>; |
| 156 | + shape_type shape = {2, 2}; |
| 157 | + shape_type stride = {3, 2}; |
| 158 | + return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride); |
| 159 | + } |
| 160 | + |
| 161 | + auto get_tensor_adapter() |
| 162 | + { |
| 163 | + using shape_type = std::array<size_t, 2>; |
| 164 | + shape_type shape = {2, 2}; |
| 165 | + shape_type stride = {3, 2}; |
| 166 | + return xt::adapt(a.data(), 4, xt::no_ownership(), shape, stride); |
| 167 | + } |
| 168 | + |
| 169 | + auto get_owning_array_adapter() |
| 170 | + { |
| 171 | + size_t size = 100; |
| 172 | + int * data = new int[size]; |
| 173 | + std::fill(data, data + size, 1); |
| 174 | + |
| 175 | + using shape_type = std::vector<size_t>; |
| 176 | + shape_type shape = {size}; |
| 177 | + return xt::adapt(std::move(data), size, xt::acquire_ownership(), shape); |
| 178 | + } |
| 179 | +}; |
| 180 | + |
136 | 181 | xt::pyarray<A> dtype_to_python() |
137 | 182 | { |
138 | 183 | A a1{123, 321, 'a', {1, 2, 3}}; |
@@ -257,4 +302,15 @@ PYBIND11_MODULE(xtensor_python_test, m) |
257 | 302 |
|
258 | 303 | m.def("diff_shape_overload", [](xt::pytensor<int, 1> a) { return 1; }); |
259 | 304 | m.def("diff_shape_overload", [](xt::pytensor<int, 2> a) { return 2; }); |
| 305 | + |
| 306 | + py::class_<test_native_casters>(m, "test_native_casters") |
| 307 | + .def(py::init<>()) |
| 308 | + .def("get_array", &test_native_casters::get_array, py::return_value_policy::reference_internal) // memory managed by the class instance |
| 309 | + .def("get_strided_view", &test_native_casters::get_strided_view, py::keep_alive<0, 1>()) // keep_alive<0, 1>() => do not free "self" before the returned view |
| 310 | + .def("get_array_adapter", &test_native_casters::get_array_adapter, py::keep_alive<0, 1>()) // keep_alive<0, 1>() => do not free "self" before the returned adapter |
| 311 | + .def("get_tensor_adapter", &test_native_casters::get_tensor_adapter, py::keep_alive<0, 1>()) // keep_alive<0, 1>() => do not free "self" before the returned adapter |
| 312 | + .def("get_owning_array_adapter", &test_native_casters::get_owning_array_adapter) // auto memory management as the adapter owns its memory |
| 313 | + .def("view_keep_alive_member_function", [](test_native_casters & self, xt::pyarray<double> & a) // keep_alive<0, 2>() => do not free second parameter before the returned view |
| 314 | + {return xt::reshape_view(a, {a.size(), });}, |
| 315 | + py::keep_alive<0, 2>()); |
260 | 316 | } |
0 commit comments