Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 27 additions & 20 deletions src/array_api_extra/_lib/_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _check_ns_shape_dtype(
check_dtype: bool,
check_shape: bool,
check_scalar: bool,
) -> ModuleType: # numpydoc ignore=RT03
) -> tuple[Array, Array, ModuleType]: # numpydoc ignore=RT03
"""
Assert that namespace, shape and dtype of the two arrays match.

Expand All @@ -55,24 +55,35 @@ def _check_ns_shape_dtype(

Returns
-------
Arrays namespace.
Actual array, desired array, and their namespace.
"""
actual_xp = array_namespace(actual) # Raises on scalars and lists
actual_xp = array_namespace(actual) # Raises on Python scalars and lists
desired_xp = array_namespace(desired)

msg = f"namespaces do not match: {actual_xp} != f{desired_xp}"
assert actual_xp == desired_xp, msg

if is_numpy_namespace(actual_xp) and check_scalar:
# only NumPy distinguishes between scalars and arrays; we do if check_scalar.
_msg = (
"array-ness does not match:\n Actual: "
f"{type(actual)}\n Desired: {type(desired)}"
)
assert np.isscalar(actual) == np.isscalar(desired), _msg
Comment thread
lucascolley marked this conversation as resolved.

# Dask uses nan instead of None for unknown shapes
actual_shape = cast(tuple[float, ...], actual.shape)
desired_shape = cast(tuple[float, ...], desired.shape)
assert None not in actual_shape # Requires explicit support
assert None not in desired_shape

if is_dask_namespace(desired_xp):
if any(math.isnan(i) for i in actual_shape):
actual_shape = actual.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
actual.compute_chunk_sizes() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
actual_shape = cast(tuple[float, ...], actual.shape)
if any(math.isnan(i) for i in desired_shape):
desired_shape = desired.compute().shape # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
desired.compute_chunk_sizes() # type: ignore[attr-defined] # pyright: ignore[reportAttributeAccessIssue]
desired_shape = cast(tuple[float, ...], desired.shape)

if check_shape:
msg = f"shapes do not match: {actual_shape} != f{desired_shape}"
Expand All @@ -82,24 +93,16 @@ def _check_ns_shape_dtype(
# np.testing.assert_array_equal etc even when strict=False, but not for
# non-materializable arrays.
# This check excludes 0d arrays as they are special-cased in NumPy.
actual_size = math.prod(actual_shape) # pyright: ignore[reportUnknownArgumentType]
desired_size = math.prod(desired_shape) # pyright: ignore[reportUnknownArgumentType]
actual_size = math.prod(actual_shape)
desired_size = math.prod(desired_shape)
msg = f"sizes do not match: {actual_size} != f{desired_size}"
assert actual_size == desired_size, msg

if check_dtype:
msg = f"dtypes do not match: {actual.dtype} != {desired.dtype}"
assert actual.dtype == desired.dtype, msg

if is_numpy_namespace(actual_xp) and check_scalar:
# only NumPy distinguishes between scalars and arrays; we do if check_scalar.
_msg = (
"array-ness does not match:\n Actual: "
f"{type(actual)}\n Desired: {type(desired)}"
)
assert np.isscalar(actual) == np.isscalar(desired), _msg

return desired_xp
desired = desired_xp.broadcast_to(desired, actual_shape)
return actual, desired, desired_xp


def _is_materializable(x: Array) -> bool:
Expand Down Expand Up @@ -169,7 +172,9 @@ def xp_assert_equal(
xp_assert_close : Similar function for inexact equality checks.
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
actual, desired, xp = _check_ns_shape_dtype(
actual, desired, check_dtype, check_shape, check_scalar
)
if not _is_materializable(actual):
return
actual_np = as_numpy_array(actual, xp=xp)
Expand Down Expand Up @@ -211,7 +216,7 @@ def xp_assert_less(
xp_assert_close : Similar function for inexact equality checks.
numpy.testing.assert_array_equal : Similar function for NumPy arrays.
"""
xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
x, y, xp = _check_ns_shape_dtype(x, y, check_dtype, check_shape, check_scalar)
if not _is_materializable(x):
return
x_np = as_numpy_array(x, xp=xp)
Expand Down Expand Up @@ -267,7 +272,9 @@ def xp_assert_close(
-----
The default `atol` and `rtol` differ from `xp.all(xpx.isclose(a, b))`.
"""
xp = _check_ns_shape_dtype(actual, desired, check_dtype, check_shape, check_scalar)
actual, desired, xp = _check_ns_shape_dtype(
actual, desired, check_dtype, check_shape, check_scalar
)
if not _is_materializable(actual):
return

Expand Down