From d9f2a7a0a39cbea25fd51d693250c137a61ddaf0 Mon Sep 17 00:00:00 2001 From: Athan Date: Mon, 2 Feb 2026 00:23:07 -0800 Subject: [PATCH 1/5] fix: address signature regression in `expand_dims` In https://github.com/data-apis/array-api/pull/354, a regression was introduced which reverted a change to the signature of `expand_dims`. Namely, the `axis` argument should not have been made optional and should not have had a default value. Ref: https://github.com/data-apis/array-api/pull/331 Ref: https://github.com/data-apis/array-api/pull/354 --- src/array_api_stubs/_2021_12/manipulation_functions.py | 2 +- src/array_api_stubs/_2022_12/manipulation_functions.py | 2 +- src/array_api_stubs/_2023_12/manipulation_functions.py | 2 +- src/array_api_stubs/_2024_12/manipulation_functions.py | 2 +- src/array_api_stubs/_draft/manipulation_functions.py | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/array_api_stubs/_2021_12/manipulation_functions.py b/src/array_api_stubs/_2021_12/manipulation_functions.py index 8ae359a3a..9f2c12737 100644 --- a/src/array_api_stubs/_2021_12/manipulation_functions.py +++ b/src/array_api_stubs/_2021_12/manipulation_functions.py @@ -24,7 +24,7 @@ def concat( """ -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: array, /, axis: int) -> array: """ Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``. diff --git a/src/array_api_stubs/_2022_12/manipulation_functions.py b/src/array_api_stubs/_2022_12/manipulation_functions.py index 2d7179a8b..a035801b6 100644 --- a/src/array_api_stubs/_2022_12/manipulation_functions.py +++ b/src/array_api_stubs/_2022_12/manipulation_functions.py @@ -58,7 +58,7 @@ def concat( """ -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: array, /, axis: int) -> array: """ Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``. diff --git a/src/array_api_stubs/_2023_12/manipulation_functions.py b/src/array_api_stubs/_2023_12/manipulation_functions.py index 131b81eb3..311f815d7 100644 --- a/src/array_api_stubs/_2023_12/manipulation_functions.py +++ b/src/array_api_stubs/_2023_12/manipulation_functions.py @@ -76,7 +76,7 @@ def concat( """ -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: array, /, axis: int) -> array: """ Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``. diff --git a/src/array_api_stubs/_2024_12/manipulation_functions.py b/src/array_api_stubs/_2024_12/manipulation_functions.py index dd8d4cd69..cf9ec18ee 100644 --- a/src/array_api_stubs/_2024_12/manipulation_functions.py +++ b/src/array_api_stubs/_2024_12/manipulation_functions.py @@ -79,7 +79,7 @@ def concat( """ -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: array, /, axis: int) -> array: """ Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``. diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 33068febc..955542eed 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -79,7 +79,7 @@ def concat( """ -def expand_dims(x: array, /, *, axis: int = 0) -> array: +def expand_dims(x: array, /, axis: int) -> array: """ Expands the shape of an array by inserting a new axis of size one at the position specified by ``axis``. @@ -88,7 +88,7 @@ def expand_dims(x: array, /, *, axis: int = 0) -> array: x: array input array. axis: int - axis position (zero-based). A valid ``axis`` **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of axes in ``x``. If an axis is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). If provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). If provided an invalid axis, the function **must** raise an exception. Default: ``0``. + axis position (zero-based). A valid ``axis`` **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of axes in ``x``. If an axis is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). If provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). If provided an invalid axis, the function **must** raise an exception. Returns ------- From 5a5874cb69e3472ca5fbb1c3f27e02d4e321eaa6 Mon Sep 17 00:00:00 2001 From: Athan Date: Thu, 5 Feb 2026 11:20:54 +0100 Subject: [PATCH 2/5] feat: add support for specifying a tuple of axis positions Closes: https://github.com/data-apis/array-api/issues/760 --- .../_draft/manipulation_functions.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 955542eed..7a4e63e46 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -79,21 +79,28 @@ def concat( """ -def expand_dims(x: array, /, axis: int) -> array: +def expand_dims(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: """ - Expands the shape of an array by inserting a new axis of size one at the position specified by ``axis``. + Expands the shape of an array by inserting a new axis of size one at the position (or positions) specified by ``axis``. Parameters ---------- x: array input array. - axis: int - axis position (zero-based). A valid ``axis`` **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of axes in ``x``. If an axis is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). If provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). If provided an invalid axis, the function **must** raise an exception. + axis: Union[int, Tuple[int, ...]] + axis position(s) (zero-based). A valid axis position **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of dimensions in ``x``. If an axis position is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). If provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). If provided an invalid axis position, the function **must** raise an exception. + + If ``axis`` is a tuple, + + - each entry of ``axis`` must resolve to a unique axis position. If an entry is a negative integer, the entry **must** resolve to a positive axis position according to the rules described above. + - if provided an invalid axis position, the function **must** raise an exception. + - for each entry of ``axis``, the corresponding dimension in the expanded output array **must** be a singleton dimension. + - for the remaining dimensions of the expanded output array, the output array dimensions **must** correspond to the dimensions of ``x`` in order. Returns ------- out: array - an expanded output array. **Must** have the same data type as ``x``. + an expanded output array. **Must** have the same data type as ``x``. If ``axis`` is an integer, the output array must have ``N + 1`` dimensions. If ``axis`` is a tuple, the output array must have ``N + len(axis)`` dimensions. Raises ------ From dceea5e498dc156bdd096fc97f9a350c2b185b22 Mon Sep 17 00:00:00 2001 From: Athan Date: Thu, 5 Feb 2026 15:46:45 +0100 Subject: [PATCH 3/5] fix: clarify resolution rules --- .../_draft/manipulation_functions.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 7a4e63e46..3d2e34ab1 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -88,14 +88,22 @@ def expand_dims(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: x: array input array. axis: Union[int, Tuple[int, ...]] - axis position(s) (zero-based). A valid axis position **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of dimensions in ``x``. If an axis position is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). If provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). If provided an invalid axis position, the function **must** raise an exception. + axis position(s) (zero-based). If ``axis`` is an integer, + + - a valid axis position **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of dimensions in ``x``. + - if an axis position is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. + - if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). + - if provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). + - if provided an invalid axis position, the function **must** raise an exception. If ``axis`` is a tuple, - - each entry of ``axis`` must resolve to a unique axis position. If an entry is a negative integer, the entry **must** resolve to a positive axis position according to the rules described above. - - if provided an invalid axis position, the function **must** raise an exception. + - a valid axis position **must** reside on the closed-interval ``[-M-1, M]``, where ``M = N + len(axis) - 1`` and ``N`` is the number of dimensions in ``x``. + - if an entry is a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``M + axis + 1``. + - each entry of ``axis`` must resolve to a unique positive axis position. - for each entry of ``axis``, the corresponding dimension in the expanded output array **must** be a singleton dimension. - for the remaining dimensions of the expanded output array, the output array dimensions **must** correspond to the dimensions of ``x`` in order. + - if provided an invalid axis position, the function **must** raise an exception. Returns ------- @@ -106,6 +114,15 @@ def expand_dims(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: ------ IndexError If provided an invalid ``axis``, an ``IndexError`` **should** be raised. + + Notes + ----- + + - Calling this function with a tuple of axis positions **must** be semantically equivalent to calling this function repeatedly with a single axis position only when the following three conditions are met: + + - each entry of the tuple is normalized to positive axis positions according to the number of dimensions in the expanded output array. + - the normalized positive axis positions are sorted in ascending order. + - the normalized positive axis positions are unique. """ From aa0b8a1112406f664d123390b563fa754112e0ca Mon Sep 17 00:00:00 2001 From: Athan Date: Thu, 5 Feb 2026 18:42:39 +0100 Subject: [PATCH 4/5] docs: rephrase copy --- src/array_api_stubs/_draft/manipulation_functions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 3d2e34ab1..1e9cc833f 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -91,9 +91,7 @@ def expand_dims(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: axis position(s) (zero-based). If ``axis`` is an integer, - a valid axis position **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of dimensions in ``x``. - - if an axis position is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. - - if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). - - if provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). + - if an axis position is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. For example, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). Similarly, if provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). - if provided an invalid axis position, the function **must** raise an exception. If ``axis`` is a tuple, From 67eaae1039f36ac2c36a206d07cddf5dfc996f3e Mon Sep 17 00:00:00 2001 From: Athan Date: Thu, 5 Feb 2026 18:47:41 +0100 Subject: [PATCH 5/5] docs: update copy --- src/array_api_stubs/_draft/manipulation_functions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/array_api_stubs/_draft/manipulation_functions.py b/src/array_api_stubs/_draft/manipulation_functions.py index 1e9cc833f..d23216005 100644 --- a/src/array_api_stubs/_draft/manipulation_functions.py +++ b/src/array_api_stubs/_draft/manipulation_functions.py @@ -91,13 +91,13 @@ def expand_dims(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array: axis position(s) (zero-based). If ``axis`` is an integer, - a valid axis position **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of dimensions in ``x``. - - if an axis position is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. For example, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). Similarly, if provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). + - if an axis position is specified as a negative integer, the axis position of the inserted singleton dimension in the output array **must** be computed as ``N + axis + 1``. For example, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). Similarly, if provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). - if provided an invalid axis position, the function **must** raise an exception. If ``axis`` is a tuple, - a valid axis position **must** reside on the closed-interval ``[-M-1, M]``, where ``M = N + len(axis) - 1`` and ``N`` is the number of dimensions in ``x``. - - if an entry is a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``M + axis + 1``. + - if an entry is a negative integer, the axis position of the inserted singleton dimension in the output array **must** be computed as ``M + axis + 1``. - each entry of ``axis`` must resolve to a unique positive axis position. - for each entry of ``axis``, the corresponding dimension in the expanded output array **must** be a singleton dimension. - for the remaining dimensions of the expanded output array, the output array dimensions **must** correspond to the dimensions of ``x`` in order.