From 35b631f718e47985b82119a0b02b6342f0304bc1 Mon Sep 17 00:00:00 2001 From: Evgeni Burovski Date: Thu, 5 Feb 2026 16:10:13 +0100 Subject: [PATCH] WIP: add axis tuple support to torch.expand_dims --- array_api_compat/torch/_aliases.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index 4b232f84..512c1060 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -690,9 +690,24 @@ def triu(x: Array, /, *, k: int = 0) -> Array: return torch.triu(x, k) # Functions that aren't in torch https://github.com/pytorch/pytorch/issues/58742 -def expand_dims(x: Array, /, *, axis: int = 0) -> Array: - return torch.unsqueeze(x, axis) +def expand_dims(x: Array, /, *, axis: int | tuple[int, ...]) -> Array: + if isinstance(axis, int): + return torch.unsqueeze(x, axis) + else: + # follow https://github.com/numpy/numpy/blob/maintenance/2.4.x/numpy/lib/_shape_base_impl.py#L596-L602 + y_ndim = x.ndim + len(axis) + + # normalize + n_axis = tuple(ax + y_ndim if ax < 0 else ax for ax in axis) + if (len(n_axis) != len(set(n_axis)) or + _builtin_any(ax < 0 or ax >= y_ndim for ax in n_axis) + ): + raise ValueError(f"{axis=} not allowed for {x.shape = }") + + shape_it = iter(x.shape) + shape = [1 if ax in n_axis else next(shape_it) for ax in range(y_ndim)] + return torch.reshape(x, shape) def astype( x: Array,