Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/array_api_stubs/_2021_12/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def concat(
"""


def expand_dims(x: array, /, *, axis: int = 0) -> array:
def expand_dims(x: array, /, axis: int) -> array:
"""
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.

Expand Down
2 changes: 1 addition & 1 deletion src/array_api_stubs/_2022_12/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def concat(
"""


def expand_dims(x: array, /, *, axis: int = 0) -> array:
def expand_dims(x: array, /, axis: int) -> array:
"""
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.

Expand Down
2 changes: 1 addition & 1 deletion src/array_api_stubs/_2023_12/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def concat(
"""


def expand_dims(x: array, /, *, axis: int = 0) -> array:
def expand_dims(x: array, /, axis: int) -> array:
"""
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.

Expand Down
2 changes: 1 addition & 1 deletion src/array_api_stubs/_2024_12/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def concat(
"""


def expand_dims(x: array, /, *, axis: int = 0) -> array:
def expand_dims(x: array, /, axis: int) -> array:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"""
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.

Expand Down
32 changes: 27 additions & 5 deletions src/array_api_stubs/_draft/manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,26 +79,48 @@ def concat(
"""


def expand_dims(x: array, /, *, axis: int = 0) -> array:
def expand_dims(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
"""
Expands the shape of an array by inserting a new axis of size one at the position specified by ``axis``.
Expands the shape of an array by inserting a new axis of size one at the position (or positions) specified by ``axis``.

Parameters
----------
x: array
input array.
axis: int
axis position (zero-based). A valid ``axis`` **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of axes in ``x``. If an axis is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). If provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). If provided an invalid axis, the function **must** raise an exception. Default: ``0``.
axis: Union[int, Tuple[int, ...]]
axis position(s) (zero-based). If ``axis`` is an integer,

- a valid axis position **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of dimensions in ``x``.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One idea: would it be clearer here to talk about valid indices in terms of the output dimensions? Then this would change to

- a valid axis position **must** reside on the semi-open interval ``[-M, M)` where
  `M = x.ndim + 1` is the number of dimensions of the *output* array.

then the tuple version of this would be identical, except it would say M = ndim(x) + len(axis)

- if an axis position is specified as a negative integer, the axis position of the inserted singleton dimension in the output array **must** be computed as ``N + axis + 1``. For example, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). Similarly, if provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``).
- if provided an invalid axis position, the function **must** raise an exception.

If ``axis`` is a tuple,

- a valid axis position **must** reside on the closed-interval ``[-M-1, M]``, where ``M = N + len(axis) - 1`` and ``N`` is the number of dimensions in ``x``.
- if an entry is a negative integer, the axis position of the inserted singleton dimension in the output array **must** be computed as ``M + axis + 1``.
- each entry of ``axis`` must resolve to a unique positive axis position.
- for each entry of ``axis``, the corresponding dimension in the expanded output array **must** be a singleton dimension.
- for the remaining dimensions of the expanded output array, the output array dimensions **must** correspond to the dimensions of ``x`` in order.
- if provided an invalid axis position, the function **must** raise an exception.

Returns
-------
out: array
an expanded output array. **Must** have the same data type as ``x``.
an expanded output array. **Must** have the same data type as ``x``. If ``axis`` is an integer, the output array must have ``N + 1`` dimensions. If ``axis`` is a tuple, the output array must have ``N + len(axis)`` dimensions.

Raises
------
IndexError
If provided an invalid ``axis``, an ``IndexError`` **should** be raised.

Notes
-----

- Calling this function with a tuple of axis positions **must** be semantically equivalent to calling this function repeatedly with a single axis position only when the following three conditions are met:

- each entry of the tuple is normalized to positive axis positions according to the number of dimensions in the expanded output array.
- the normalized positive axis positions are sorted in ascending order.
- the normalized positive axis positions are unique.
"""


Expand Down