@@ -29,7 +29,7 @@ namespace pybind11
2929 namespace detail
3030 {
3131 template <class T , std::size_t N>
32- struct handle_type_name <pytensor<T, std:: size_t N>>
32+ struct handle_type_name <xt:: pytensor<T, N>>
3333 {
3434 static PYBIND11_DESCR name ()
3535 {
@@ -38,9 +38,9 @@ namespace pybind11
3838 };
3939
4040 template <class T , std::size_t N>
41- struct pyobject_caster <pytensor<T, N>>
41+ struct pyobject_caster <xt:: pytensor<T, N>>
4242 {
43- using type = pytensor<T, N>;
43+ using type = xt:: pytensor<T, N>;
4444
4545 bool load (handle src, bool )
4646 {
@@ -50,7 +50,7 @@ namespace pybind11
5050
5151 static handle cast (const handle& src, return_value_policy, handle)
5252 {
53- src.inc_ref ();
53+ return src.inc_ref ();
5454 }
5555
5656 PYBIND11_TYPE_CASTER (type, handle_type_name<type>::name());
@@ -82,12 +82,17 @@ namespace xt
8282 using semantic_base = xcontainer_semantic<self_type>;
8383 using base_type = pycontainer<pytensor<T, N>>;
8484 using container_type = typename base_type::container_type;
85+ using pointer = typename base_type::pointer;
86+ using size_type = typename base_type::size_type;
87+ using shape_type = typename base_type::shape_type;
88+ using strides_type = typename base_type::strides_type;
89+ using backstrides_type = typename base_type::backstrides_type;
8590
8691 pytensor ();
8792
8893 pytensor (pybind11::handle h, borrowed_t );
8994 pytensor (pybind11::handle h, stolen_t );
90- pyarray (const pybind11::object &o);
95+ pytensor (const pybind11::object &o);
9196
9297 pytensor (const shape_type& shape, const strides_type& strides);
9398 explicit pytensor (const shape_type& shape);
@@ -101,6 +106,9 @@ namespace xt
101106 void reshape (const shape_type& shape);
102107 void reshape (const shape_type& shape, const strides_type& strides);
103108
109+ using base_type::begin;
110+ using base_type::end;
111+
104112 static self_type ensure (pybind11::handle h);
105113 static bool check_ (pybind11::handle h);
106114
@@ -132,10 +140,6 @@ namespace xt
132140 template <class T , std::size_t N>
133141 inline pytensor<T, N>::pytensor()
134142 {
135- std::fill (m_shape.begin (), m_shape.end (), T (0 ));
136- std::fill (m_strides.begin (), m_strides.end (), T (0 ));
137- std::fill (m_backstrides.begin (), m_backstrides.end (), T (0 ));
138- m_data = container_type (nullptr , 0 );
139143 }
140144
141145 template <class T , std::size_t N>
@@ -156,20 +160,24 @@ namespace xt
156160 inline pytensor<T, N>::pytensor(const pybind11::object& o)
157161 : pybind11::object(base_type::raw_array_t (o.ptr()), stolen)
158162 {
163+ // std::cout << "Object constructor" << std::endl;
159164 if (!this ->m_ptr )
160165 throw pybind11::error_already_set ();
166+ init_from_python ();
161167 }
162168
163169 template <class T , std::size_t N>
164170 inline pytensor<T, N>::pytensor(const shape_type& shape,
165171 const strides_type& strides)
166172 {
173+ // std::cout << "Shape + strides constructor" << std::endl;
167174 init_tensor (shape, strides);
168175 }
169176
170177 template <class T , std::size_t N>
171178 inline pytensor<T, N>::pytensor(const shape_type& shape)
172179 {
180+ // std::cout << "Shape constructor" << std::endl;
173181 base_type::fill_default_strides (shape, m_strides);
174182 init_tensor (shape, m_strides);
175183 }
@@ -178,6 +186,7 @@ namespace xt
178186 template <class E >
179187 inline pytensor<T, N>::pytensor(const xexpression<E>& e)
180188 {
189+ // std::cout << "Extended constructor" << std::endl;
181190 semantic_base::assign (e);
182191 }
183192
@@ -189,8 +198,9 @@ namespace xt
189198 }
190199
191200 template <class T , std::size_t N>
192- inline void reshape (const shape_type& shape)
201+ inline void pytensor<T, N>:: reshape(const shape_type& shape)
193202 {
203+ // std::cout << "Reshape(shape)" << std::endl;
194204 if (shape != m_shape)
195205 {
196206 strides_type strides;
@@ -200,17 +210,20 @@ namespace xt
200210 }
201211
202212 template <class T , std::size_t N>
203- inline void reshape (const shape_type& shape, const strides_type& strides)
213+ inline void pytensor<T, N>:: reshape(const shape_type& shape, const strides_type& strides)
204214 {
215+ // std::cout << "Reshape(shape, strides)" << std::endl;
205216 self_type tmp (shape, strides);
206217 *this = std::move (tmp);
207218 }
208219
209220 template <class T , std::size_t N>
210221 inline auto pytensor<T, N>::ensure(pybind11::handle h) -> self_type
211222 {
223+ // std::cout << "Ensure" << std::endl;
212224 auto result = pybind11::reinterpret_steal<self_type>(base_type::raw_array_t (h.ptr ()));
213- if (!result)
225+ // auto result = pybind11::reinterpret_steal<self_type>(h.ptr());
226+ if (result.ptr () == nullptr )
214227 PyErr_Clear ();
215228 return result;
216229 }
@@ -219,12 +232,13 @@ namespace xt
219232 inline bool pytensor<T, N>::check_(pybind11::handle h)
220233 {
221234 int type_num = detail::numpy_traits<T>::type_num;
222- return PyArray_Check (h.ptr ()) && PyArray_EquivTypenums (PyArray_Type (h.ptr ()), type_num);
235+ return PyArray_Check (h.ptr ()) && PyArray_EquivTypenums (PyArray_TYPE (h.ptr ()), type_num);
223236 }
224237
225238 template <class T , std::size_t N>
226239 inline void pytensor<T, N>::init_tensor(const shape_type& shape, const strides_type& strides)
227240 {
241+ // std::cout << "init tensor" << std::endl;
228242 npy_intp python_strides[N];
229243 std::transform (strides.beign (), strides.end (), python_strides,
230244 [](auto v) { return sizeof (T) * v; });
@@ -252,22 +266,24 @@ namespace xt
252266 template <class T , std::size_t N>
253267 inline void pytensor<T, N>::init_from_python()
254268 {
269+ // std::cout << "init from python" << std::endl;
255270 if (PyArray_NDIM (this ->m_ptr ) != N)
256271 throw std::runtime_error (" NumPy: ndarray has incorrect number of dimensions" );
257272
258273 std::copy (PyArray_DIMS (this ->m_ptr ), PyArray_DIMS (this ->m_ptr ) + N, m_shape.begin ());
259274 std::transform (PyArray_STRIDES (this ->m_ptr ), PyArray_STRIDES (this ->m_ptr ) + N, m_strides.begin (),
260275 [](auto v) { return v / sizeof (T); });
261276 adapt_strides ();
262- m_data = container_type (PyArray_DATA (this ->m_ptr ), PyArray_SIZE (this ->m_ptr ));
277+ m_data = container_type (reinterpret_cast <pointer>(PyArray_DATA (this ->m_ptr )),
278+ static_cast <size_type>(PyArray_SIZE (this ->m_ptr )));
263279 }
264280
265281 template <class T , std::size_t N>
266282 inline void pytensor<T, N>::adapt_strides()
267283 {
268284 for (size_type i = 0 ; i < m_shape.size (); ++i)
269285 {
270- if (m_shape_ [i] == 1 )
286+ if (m_shape [i] == 1 )
271287 {
272288 m_strides[i] = 0 ;
273289 m_backstrides[i] = 0 ;
@@ -298,13 +314,13 @@ namespace xt
298314 }
299315
300316 template <class T , std::size_t N>
301- inline auto pytensor<T, N>::data () -> container_type&
317+ inline auto pytensor<T, N>::data_impl () -> container_type&
302318 {
303319 return m_data;
304320 }
305321
306322 template <class T , std::size_t N>
307- inline auto pytensor<T, N>::data () const -> const container_type&
323+ inline auto pytensor<T, N>::data_impl () const -> const container_type&
308324 {
309325 return m_data;
310326 }
0 commit comments