From e572043f7a1f4613dc3d2a9b287294d72baa5d9b Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Date: Thu, 14 May 2026 00:29:18 +0530 Subject: [PATCH 1/6] Fixing check_shape and scalar inputs Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> --- src/array_api_extra/_lib/_testing.py | 36 ++++++++++++++++------------ 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index cc2306bd..6e14c7a5 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -37,7 +37,7 @@ def _check_ns_shape_dtype( check_dtype: bool, check_shape: bool, check_scalar: bool, -) -> ModuleType: # numpydoc ignore=RT03 +) -> tuple[Array, Array, ModuleType]: # numpydoc ignore=RT03 """ Assert that namespace, shape and dtype of the two arrays match. @@ -55,7 +55,7 @@ def _check_ns_shape_dtype( Returns ------- - Arrays namespace. + Actual array, desired array, and their namespace. """ actual_xp = array_namespace(actual) # Raises on scalars and lists desired_xp = array_namespace(desired) @@ -63,6 +63,16 @@ def _check_ns_shape_dtype( msg = f"namespaces do not match: {actual_xp} != f{desired_xp}" assert actual_xp == desired_xp, msg + if is_numpy_namespace(actual_xp) and check_scalar: + # only NumPy distinguishes between scalars and arrays; we do if check_scalar. + _msg = ( + "array-ness does not match:\n Actual: " + f"{type(actual)}\n Desired: {type(desired)}" + ) + assert np.isscalar(actual) == np.isscalar(desired), _msg + + actual = desired_xp.asarray(actual) + desired = desired_xp.asarray(desired) # Dask uses nan instead of None for unknown shapes actual_shape = cast(tuple[float, ...], actual.shape) desired_shape = cast(tuple[float, ...], desired.shape) @@ -90,16 +100,8 @@ def _check_ns_shape_dtype( if check_dtype: msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}" assert actual.dtype == desired.dtype, msg - - if is_numpy_namespace(actual_xp) and check_scalar: - # only NumPy distinguishes between scalars and arrays; we do if check_scalar. - _msg = ( - "array-ness does not match:\n Actual: " - f"{type(actual)}\n Desired: {type(desired)}" - ) - assert np.isscalar(actual) == np.isscalar(desired), _msg - - return desired_xp + desired = desired_xp.broadcast_to(desired, actual.shape) + return actual, desired, desired_xp def _is_materializable(x: Array) -> bool: @@ -169,7 +171,9 @@ def xp_assert_equal( xp_assert_close : Similar function for inexact equality checks. numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ - xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + actual, desired, xp = _check_ns_shape_dtype( + actual, desired, check_dtype, check_shape, check_scalar + ) if not _is_materializable(actual): return actual_np = as_numpy_array(actual, xp=xp) @@ -211,7 +215,7 @@ def xp_assert_less( xp_assert_close : Similar function for inexact equality checks. numpy.testing.assert_array_equal : Similar function for NumPy arrays. """ - xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar) + x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar) if not _is_materializable(x): return x_np = as_numpy_array(x, xp=xp) @@ -267,7 +271,9 @@ def xp_assert_close( ----- The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`. """ - xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar) + actual, desired, xp = _check_ns_shape_dtype( + actual, desired, check_dtype, check_shape, check_scalar + ) if not _is_materializable(actual): return From fc9cfddbe86855850978a735a2433127e44dd2ec Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+prady0t@users.noreply.github.com> Date: Thu, 14 May 2026 08:56:36 +0530 Subject: [PATCH 2/6] Update src/array_api_extra/_lib/_testing.py Co-authored-by: Lucas Colley --- src/array_api_extra/_lib/_testing.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 6e14c7a5..3d7f770e 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -71,8 +71,6 @@ def _check_ns_shape_dtype( ) assert np.isscalar(actual) == np.isscalar(desired), _msg - actual = desired_xp.asarray(actual) - desired = desired_xp.asarray(desired) # Dask uses nan instead of None for unknown shapes actual_shape = cast(tuple[float, ...], actual.shape) desired_shape = cast(tuple[float, ...], desired.shape) From 1b89cfa02b5d125d5020064c58ba6ca278e1616f Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+prady0t@users.noreply.github.com> Date: Thu, 14 May 2026 08:56:48 +0530 Subject: [PATCH 3/6] Update src/array_api_extra/_lib/_testing.py Co-authored-by: Lucas Colley --- src/array_api_extra/_lib/_testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 3d7f770e..bababcf4 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -57,7 +57,7 @@ def _check_ns_shape_dtype( ------- Actual array, desired array, and their namespace. """ - actual_xp = array_namespace(actual) # Raises on scalars and lists + actual_xp = array_namespace(actual) # Raises on Python scalars and lists desired_xp = array_namespace(desired) msg = f"namespaces do not match: {actual_xp} != f{desired_xp}" From c2ee68350e7ab9065daee874aeebde8d2d877538 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 May 2026 08:52:57 +0100 Subject: [PATCH 4/6] Update _testing.py --- src/array_api_extra/_lib/_testing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index bababcf4..d26fbebe 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -98,7 +98,7 @@ def _check_ns_shape_dtype( if check_dtype: msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}" assert actual.dtype == desired.dtype, msg - desired = desired_xp.broadcast_to(desired, actual.shape) + desired = desired_xp.broadcast_to(desired, actual_shape) return actual, desired, desired_xp From 8119b38f84ad49a335438339b29670d318dbe85d Mon Sep 17 00:00:00 2001 From: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> Date: Thu, 14 May 2026 17:06:53 +0530 Subject: [PATCH 5/6] special case for dask arrays Signed-off-by: Pradyot Ranjan <99216956+pradyotRanjan@users.noreply.github.com> --- src/array_api_extra/_lib/_testing.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index d26fbebe..7fc05292 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -78,9 +78,11 @@ def _check_ns_shape_dtype( assert None not in desired_shape if is_dask_namespace(desired_xp): if any(math.isnan(i) for i in actual_shape): - actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + actual.compute_chunk_sizes() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + actual_shape = actual.shape if any(math.isnan(i) for i in desired_shape): - desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + desired.compute_chunk_sizes() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] + desired_shape = desired.shape if check_shape: msg = f"shapes do not match: {actual_shape} != f{desired_shape}" From 97b67a80d76cc6a1acdbd06b60fd78db07af9311 Mon Sep 17 00:00:00 2001 From: Lucas Colley Date: Thu, 14 May 2026 12:51:13 +0100 Subject: [PATCH 6/6] typing --- src/array_api_extra/_lib/_testing.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/array_api_extra/_lib/_testing.py b/src/array_api_extra/_lib/_testing.py index 7fc05292..d2a57c86 100644 --- a/src/array_api_extra/_lib/_testing.py +++ b/src/array_api_extra/_lib/_testing.py @@ -76,13 +76,14 @@ def _check_ns_shape_dtype( desired_shape = cast(tuple[float, ...], desired.shape) assert None not in actual_shape # Requires explicit support assert None not in desired_shape + if is_dask_namespace(desired_xp): if any(math.isnan(i) for i in actual_shape): actual.compute_chunk_sizes() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - actual_shape = actual.shape + actual_shape = cast(tuple[float, ...], actual.shape) if any(math.isnan(i) for i in desired_shape): desired.compute_chunk_sizes() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue] - desired_shape = desired.shape + desired_shape = cast(tuple[float, ...], desired.shape) if check_shape: msg = f"shapes do not match: {actual_shape} != f{desired_shape}" @@ -92,8 +93,8 @@ def _check_ns_shape_dtype( # np.testing.assert_array_equal etc even when strict=False, but not for # non-materializable arrays. # This check excludes 0d arrays as they are special-cased in NumPy. - actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType] - desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType] + actual_size = math.prod(actual_shape) + desired_size = math.prod(desired_shape) msg = f"sizes do not match: {actual_size} != f{desired_size}" assert actual_size == desired_size, msg