diff --git a/python/src/indexing.cpp b/python/src/indexing.cpp index 95c6e2a185..5853621667 100644 --- a/python/src/indexing.cpp +++ b/python/src/indexing.cpp @@ -16,10 +16,13 @@ bool is_none_slice(const nb::slice& in_slice) { int get_slice_int(nb::object obj, int default_val) { if (!obj.is_none()) { - if (!nb::isinstance(obj)) { + // Try to cast to int - this handles Python int, numpy scalars, and other + // int-like types + try { + return nb::cast(obj); + } catch (...) { throw std::invalid_argument("Slice indices must be integers or None."); } - return nb::cast(nb::cast(obj)); } return default_val; } @@ -50,10 +53,58 @@ mx::array get_int_index(nb::object idx, int axis_size) { return mx::array(idx_, mx::uint32); } +// Convert boolean mask to integer indices +// Returns a packed array of indices where mask is True +// Uses a simple sort-based algorithm +std::pair boolean_mask_to_indices_and_count( + const mx::array& mask) { + // Flatten the boolean mask if it's multi-dimensional + auto flat_mask = (mask.ndim() > 1) ? flatten(mask) : mask; + + auto size = flat_mask.size(); + + // Count total True values using sum + auto mask_int = astype(flat_mask, mx::int32); + auto num_true_arr = sum(mask_int); + num_true_arr.eval(); // Force evaluation to get the count + int num_true = num_true_arr.item(); + + if (num_true == 0) { + // Return empty array + return {mx::array({}, mx::uint32), 0}; + } + + // Create array of all indices [0, 1, 2, ..., size-1] + auto all_indices = arange(0, size, 1, mx::int32); + + // Use where to assign indices or large sentinel value, then sort + auto large_value = size; // Use size as sentinel for False positions + auto indexed = + where(flat_mask, all_indices, mx::array(large_value, mx::int32)); + auto sorted_result = sort(indexed); + + // Slice to get only valid indices (first num_true elements after sorting) + auto result = slice(sorted_result, {0}, {num_true}, {1}); + + return {astype(result, mx::uint32), num_true}; +} + bool is_valid_index_type(const nb::object& obj) { - return nb::isinstance(obj) || nb::isinstance(obj) || + // Fast path: check common types first + if (nb::isinstance(obj) || nb::isinstance(obj) || nb::isinstance(obj) || obj.is_none() || - nb::ellipsis().is(obj) || nb::isinstance(obj); + nb::ellipsis().is(obj) || nb::isinstance(obj)) { + return true; + } + + // Fallback: try to cast to int (handles numpy scalars and other int-like + // types) + try { + nb::cast(obj); + return true; + } catch (...) { + return false; + } } mx::array mlx_get_item_slice(const mx::array& src, const nb::slice& in_slice) { @@ -84,8 +135,33 @@ mx::array mlx_get_item_array(const mx::array& src, const mx::array& indices) { "too many indices for array: array is 0-dimensional"); } + // Handle boolean indexing if (indices.dtype() == mx::bool_) { - throw std::invalid_argument("boolean indices are not yet supported"); + // Boolean indexing: convert boolean mask to integer indices + auto [int_indices, count] = boolean_mask_to_indices_and_count(indices); + + if (count == 0) { + // Empty selection - return empty array with appropriate shape + mx::Shape out_shape = {0}; + out_shape.insert( + out_shape.end(), src.shape().begin() + 1, src.shape().end()); + return zeros(out_shape, src.dtype()); + } + + // Flatten source if mask is multi-dimensional or doesn't match first dim + if (indices.size() == src.size()) { + // Mask covers entire array - flatten both + auto flat_src = flatten(src); + return take(flat_src, int_indices, 0); + } else if (indices.size() == src.shape(0)) { + // Mask is for first dimension only + return take(src, int_indices, 0); + } else { + throw std::invalid_argument( + "boolean index did not match indexed array; size is " + + std::to_string(indices.size()) + " but corresponding dimension is " + + std::to_string(src.shape(0))); + } } // If only one input array is mentioned, we set axis=0 in take @@ -442,7 +518,16 @@ mx::array mlx_get_item(const mx::array& src, const nb::object& obj) { return mlx_get_item_array( src, array_from_list(nb::cast(obj), {})); } - throw std::invalid_argument("Cannot index mlx array using the given type."); + + // Fallback: try to treat as integer index (handles numpy scalars and other + // int-like types) + try { + // Convert to Python int first to handle numpy scalars + nb::int_ idx = nb::int_(obj); + return mlx_get_item_int(src, idx); + } catch (...) { + throw std::invalid_argument("Cannot index mlx array using the given type."); + } } std::tuple, mx::array, std::vector> @@ -489,6 +574,27 @@ mlx_scatter_args_array( "too many indices for array: array is 0-dimensional"); } + // Handle boolean indexing for scatter + if (indices.dtype() == mx::bool_) { + auto [int_indices, count] = boolean_mask_to_indices_and_count(indices); + + if (count == 0) { + // No elements to update - return empty scatter args + return {{}, src, {}}; + } + + auto up = squeeze_leading_singletons(update); + + // The update shape must broadcast with int_indices.shape + src.shape[1:] + auto up_shape = int_indices.shape(); + up_shape.insert(up_shape.end(), src.shape().begin() + 1, src.shape().end()); + up = broadcast_to(up, up_shape); + up_shape.insert(up_shape.begin() + int_indices.ndim(), 1); + up = reshape(up, up_shape); + + return {{int_indices}, up, {0}}; + } + auto up = squeeze_leading_singletons(update); // The update shape must broadcast with indices.shape + [1] + src.shape[1:] @@ -757,7 +863,15 @@ mlx_compute_scatter_args( src, array_from_list(nb::cast(obj), {}), vals); } - throw std::invalid_argument("Cannot index mlx array using the given type."); + // Fallback: try to treat as integer index (handles numpy scalars and other + // int-like types) + try { + // Convert to Python int first to handle numpy scalars + nb::int_ idx = nb::int_(obj); + return mlx_scatter_args_int(src, idx, vals); + } catch (...) { + throw std::invalid_argument("Cannot index mlx array using the given type."); + } } auto mlx_slice_update( diff --git a/python/tests/test_array.py b/python/tests/test_array.py index a8f9474ef1..97a4d16fc8 100644 --- a/python/tests/test_array.py +++ b/python/tests/test_array.py @@ -1105,6 +1105,108 @@ def index_fn(x, ind): self.assertTrue(mx.array_equal(grad_x, expected)) self.assertTrue(mx.array_equal(grad_ind, mx.zeros(ind.shape))) + def test_numpy_scalar_indexing(self): + """Test indexing with numpy scalar types""" + # Basic numpy scalar indexing + x = mx.array([1, 2, 3, 4, 5]) + result = x[np.int64(1)] + self.assertEqual(result.item(), 2) + + # Numpy scalar in slice start + result = x[np.int64(1) :] + self.assertTrue(np.array_equal(np.array(result), np.array([2, 3, 4, 5]))) + + # Numpy scalar in slice stop + result = x[: np.int64(3)] + self.assertTrue(np.array_equal(np.array(result), np.array([1, 2, 3]))) + + # Other numpy scalar types + result = x[np.int32(2)] + self.assertEqual(result.item(), 3) + + # Negative numpy scalar indexing + result = x[np.int64(-1)] + self.assertEqual(result.item(), 5) + + # Numpy scalar assignment + x_copy = mx.array([1, 2, 3, 4, 5]) + x_copy[np.int64(2)] = 99 + self.assertTrue(np.array_equal(np.array(x_copy), np.array([1, 2, 99, 4, 5]))) + + # Numpy scalar in both slice start and stop + x = mx.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) + result = x[np.int64(2) : np.int64(8)] + expected = np.arange(1, 11)[np.int64(2) : np.int64(8)] + self.assertTrue(np.array_equal(np.array(result), expected)) + + # Test with 2D array + x_2d = mx.array([[1, 2], [3, 4], [5, 6]]) + result = x_2d[np.int32(1)] + self.assertTrue(np.array_equal(np.array(result), np.array([3, 4]))) + + def test_boolean_mask_indexing(self): + """Test boolean mask indexing""" + # Basic boolean indexing + x = mx.array([1, 2, 3, 4, 5]) + mask = x > 2 + result = x[mask] + self.assertTrue(np.array_equal(np.array(result), np.array([3, 4, 5]))) + + # Boolean indexing with all True + x = mx.array([1, 2, 3]) + mask = mx.array([True, True, True]) + result = x[mask] + self.assertTrue(np.array_equal(np.array(result), np.array([1, 2, 3]))) + + # Boolean indexing with all False + x = mx.array([1, 2, 3]) + mask = mx.array([False, False, False]) + result = x[mask] + self.assertEqual(result.size, 0) + + # Boolean indexing with alternating pattern + x = mx.array([10, 20, 30, 40, 50]) + mask = mx.array([True, False, True, False, True]) + result = x[mask] + self.assertTrue(np.array_equal(np.array(result), np.array([10, 30, 50]))) + + # Boolean assignment + x = mx.array([1, 2, 3, 4, 5]) + mask = x > 2 + x[mask] = 99 + self.assertTrue(np.array_equal(np.array(x), np.array([1, 2, 99, 99, 99]))) + + # Boolean indexing with 2D array (flatten behavior) + x = mx.array([[1, 2], [3, 4], [5, 6]]) + mask = x > 3 + result = x[mask] + expected = np.array([4, 5, 6]) + self.assertTrue(np.array_equal(np.array(result), expected)) + + # Boolean indexing with negative values + x = mx.array([-3, -1, 0, 2, 4]) + mask = x < 0 + result = x[mask] + self.assertTrue(np.array_equal(np.array(result), np.array([-3, -1]))) + + # Complex boolean condition + x = mx.array([0, 1, 2, 3, 4, 5, 6]) + mask = (x > 1) & (x < 5) + result = x[mask] + self.assertTrue(np.array_equal(np.array(result), np.array([2, 3, 4]))) + + # Empty result from boolean indexing + x = mx.array([1, 2, 3]) + mask = x > 10 + result = x[mask] + self.assertEqual(result.size, 0) + + # Single element from boolean indexing + x = mx.array([1, 2, 3, 4, 5]) + mask = x == 3 + result = x[mask] + self.assertTrue(np.array_equal(np.array(result), np.array([3]))) + def test_setitem(self): a = mx.array(0) a[None] = 1