@@ -1858,3 +1858,35 @@ def test_put_indices_oob_py_ssize_t(mode):
18581858
18591859 assert dpt .all (x [:- 1 ] == - 1 )
18601860 assert x [- 1 ] == i
1861+
1862+
1863+ def test_take_along_axis_uint64_indices ():
1864+ get_queue_or_skip ()
1865+
1866+ inds = dpt .arange (1 , 10 , 2 , dtype = "u8" )
1867+ x = dpt .tile (dpt .asarray ([0 , - 1 ], dtype = "i4" ), 5 )
1868+ res = dpt .take_along_axis (x , inds )
1869+ assert dpt .all (res == - 1 )
1870+
1871+ sh0 = 2
1872+ inds = dpt .broadcast_to (inds , (sh0 ,) + inds .shape )
1873+ x = dpt .broadcast_to (x , (sh0 ,) + x .shape )
1874+ res = dpt .take_along_axis (x , inds , axis = 1 )
1875+ assert dpt .all (res == - 1 )
1876+
1877+
1878+ def test_put_along_axis_uint64_indices ():
1879+ get_queue_or_skip ()
1880+
1881+ inds = dpt .arange (1 , 10 , 2 , dtype = "u8" )
1882+ x = dpt .zeros (10 , dtype = "i4" )
1883+ dpt .put_along_axis (x , inds , dpt .asarray (2 , dtype = x .dtype ))
1884+ expected = dpt .tile (dpt .asarray ([0 , 2 ], dtype = "i4" ), 5 )
1885+ assert dpt .all (x == expected )
1886+
1887+ sh0 = 2
1888+ inds = dpt .broadcast_to (inds , (sh0 ,) + inds .shape )
1889+ x = dpt .zeros ((sh0 ,) + x .shape , dtype = "i4" )
1890+ dpt .put_along_axis (x , inds , dpt .asarray (2 , dtype = x .dtype ), axis = 1 )
1891+ expected = dpt .tile (dpt .asarray ([0 , 2 ], dtype = "i4" ), (2 , 5 ))
1892+ assert dpt .all (expected == x )
0 commit comments