Skip to content

Commit 9b52975

Browse files
Merge pull request #37 from SylvainCorlay/shape-comparison
Add shape comparison operators
2 parents dc6446f + 3b7cea9 commit 9b52975

File tree

3 files changed

+76
-11
lines changed

3 files changed

+76
-11
lines changed

include/xtensor-python/pybuffer_adaptor.hpp

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111

1212
#include <cstddef>
1313
#include <iterator>
14+
#include <algorithm>
1415

1516
namespace xt
1617
{
1718

1819
template <class T>
1920
class pybuffer_adaptor
2021
{
21-
2222
public:
2323

2424
using value_type = T;
@@ -72,6 +72,24 @@ namespace xt
7272
size_type m_size;
7373
};
7474

75+
template<class T>
76+
bool operator==(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y);
77+
78+
template<class T>
79+
bool operator<(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y);
80+
81+
template<class T>
82+
bool operator!=(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y);
83+
84+
template<class T>
85+
bool operator>(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y);
86+
87+
template<class T>
88+
bool operator<=(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y);
89+
90+
template<class T>
91+
bool operator>=(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y);
92+
7593
template <std::size_t N>
7694
class pystrides_iterator;
7795

@@ -110,7 +128,7 @@ namespace xt
110128
const_pointer p_data;
111129
size_type m_size;
112130
};
113-
131+
114132
/***********************************
115133
* pybuffer_adaptor implementation *
116134
***********************************/
@@ -240,7 +258,43 @@ namespace xt
240258
{
241259
return rend();
242260
}
243-
261+
262+
template<class T>
263+
inline bool operator==(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y)
264+
{
265+
return (x.size() == y.size() && std::equal(x.begin(), x.end(), y.begin()));
266+
}
267+
268+
template<class T>
269+
inline bool operator<(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y)
270+
{
271+
return std::lexicographical_compare(x.begin(), x.end(), y.begin(), y.end());
272+
}
273+
274+
template<class T>
275+
inline bool operator!=(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y)
276+
{
277+
return !(x == y);
278+
}
279+
280+
template<class T>
281+
inline bool operator>(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y)
282+
{
283+
return y < x;
284+
}
285+
286+
template<class T>
287+
inline bool operator<=(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y)
288+
{
289+
return !(y < x);
290+
}
291+
292+
template<class T>
293+
inline bool operator>=(const pybuffer_adaptor<T>& x, const pybuffer_adaptor<T>& y)
294+
{
295+
return !(x < y);
296+
}
297+
244298
/*************************************
245299
* pystrides_iterator implementation *
246300
*************************************/

test/main.cpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,21 @@ PYBIND11_PLUGIN(xtensor_python_test)
5151

5252
py::module m("xtensor_python_test", "Test module for xtensor python bindings");
5353

54-
m.def("example1", example1, "");
55-
m.def("example2", example2, "");
54+
m.def("example1", example1);
55+
m.def("example2", example2);
5656

57-
m.def("readme_example1", readme_example1, "");
58-
m.def("readme_example2", xt::pyvectorize(readme_example2), "");
57+
m.def("readme_example1", readme_example1);
58+
m.def("readme_example2", xt::pyvectorize(readme_example2));
5959

60-
m.def("vectorize_example1", xt::pyvectorize(add), "");
60+
m.def("vectorize_example1", xt::pyvectorize(add));
6161

62-
m.def("rect_to_polar", [](xt::pyarray<complex_t> const& a) {
63-
return py::make_tuple(xt::pyvectorize([](complex_t x) { return std::abs(x); })(a),
64-
xt::pyvectorize([](complex_t x) { return std::arg(x); })(a));
62+
m.def("rect_to_polar", [](const xt::pyarray<complex_t>& a) {
63+
return py::make_tuple(xt::pyvectorize([](complex_t x) { return std::abs(x); })(a),
64+
xt::pyvectorize([](complex_t x) { return std::arg(x); })(a));
65+
});
66+
67+
m.def("compare_shapes", [](const xt::pyarray<double>& a, const xt::pyarray<double>& b) {
68+
return a.shape() == b.shape();
6569
});
6670

6771
return m.ptr();

test/test_pyarray.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,3 +50,10 @@ def test_rect_to_polar(self):
5050
z = xt.rect_to_polar(x[::2]);
5151
np.testing.assert_allclose(z, (np.ones(5, dtype=float), np.zeros(5, dtype=float)), 1e-5)
5252

53+
def test_shape_comparison(self):
54+
x = np.ones([4, 4])
55+
y = np.ones([5, 5])
56+
z = np.zeros([4, 4])
57+
self.assertFalse(xt.compare_shapes(x, y))
58+
self.assertTrue(xt.compare_shapes(x, z))
59+

0 commit comments

Comments
 (0)