feat: add support for specifying a tuple of axis positions in expand_dims#988
feat: add support for specifying a tuple of axis positions in expand_dims#988kgryte wants to merge 5 commits intodata-apis:mainfrom
expand_dims#988Conversation
In data-apis#354, a regression was introduced which reverted a change to the signature of `expand_dims`. Namely, the `axis` argument should not have been made optional and should not have had a default value. Ref: data-apis#331 Ref: data-apis#354
ev-br
left a comment
There was a problem hiding this comment.
It would be very useful to add a comment from #760 (comment)
This behavior is semantically equivalent to calling expand_dims repeatedly with a single axis, only when the axes tuple is normalized to positive values using the final shape, is sorted, and contains no duplicates.
| If ``axis`` is a tuple, | ||
|
|
||
| - each entry of ``axis`` must resolve to a unique axis position. If an entry is a negative integer, the entry **must** resolve to a positive axis position according to the rules described above. | ||
| - if provided an invalid axis position, the function **must** raise an exception. |
There was a problem hiding this comment.
numpy raises AxisError, which derives from IndexError (which torch.unsqueeze raises) and ValueError (which jax.numpy raises). So short of adding AxisError with a prescribed inheritance hierarchy we cannot be more specific on what exception to raise.
|
@ev-br Added the desired note. I believe this is ready for another review. |
|
|
||
|
|
||
| def expand_dims(x: array, /, *, axis: int = 0) -> array: | ||
| def expand_dims(x: array, /, axis: int) -> array: |
There was a problem hiding this comment.
SciPy doesn't look too badly impacted by the reversion, I think just https://github.com/scipy/scipy/blob/341152d40c3274c0e37068321cccfb08733e2707/scipy/signal/_filter_design.py#L87
|
Let's open an issue on merge of this to plan a deprecation over at https://data-apis.org/array-api-extra/generated/array_api_extra.expand_dims.html. I'm not sure exactly what strategy is appropriate, maybe good to discuss. |
| axis: Union[int, Tuple[int, ...]] | ||
| axis position(s) (zero-based). If ``axis`` is an integer, | ||
|
|
||
| - a valid axis position **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of dimensions in ``x``. |
There was a problem hiding this comment.
One idea: would it be clearer here to talk about valid indices in terms of the output dimensions? Then this would change to
- a valid axis position **must** reside on the semi-open interval ``[-M, M)` where
`M = x.ndim + 1` is the number of dimensions of the *output* array.
then the tuple version of this would be identical, except it would say M = ndim(x) + len(axis)
This PR:
expand_dims#760expand_dims. Namely, theaxisargument should not have been made optional and should not have had a default value. This regression had gone unnoticed until working on this PR and a patch has been backported to prior revisions of the standard.expand_dims, thus addressing RFC: add support for a tuple of axes inexpand_dims#760. The added guidance follows the steps outlined in RFC: add support for a tuple of axes inexpand_dims#760 (comment).Notes
array-api-compat.