Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3,361 changes: 3,361 additions & 0 deletions 03_23_2026.log
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you accidentally committed this

Large diffs are not rendered by default.

12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@ classifiers = [
"Programming Language :: Python :: 3.14",
]
dynamic = [ "description", "readme", "version" ]
dependencies = [ "numpy>=2" ]
dependencies = [ "array-api-compat", "numpy>=2" ]
optional-dependencies.accel = [ "numba>=0.57" ]
optional-dependencies.dask = [ "dask>=2023.6.1" ]
optional-dependencies.full = [ "fast-array-utils[accel,dask,sparse]", "h5py", "zarr" ]
optional-dependencies.jax = [ "jax", "jaxlib" ]
optional-dependencies.sparse = [ "scipy>=1.13" ]
optional-dependencies.testing = [ "packaging" ]
urls."Issue Tracker" = "https://github.com/scverse/fast-array-utils/issues"
Expand All @@ -48,6 +49,15 @@ doc = [
"sphinx>=9.0.1",
"sphinx-autofixture>=0.4.1",
]
# for update-mypy-hook
mypy = [
"fast-array-utils[full]",
"scipy-stubs",
# TODO: replace sphinx with this: { include-group = "doc" },
"sphinx",
"types-docutils",
{ include-group = "test" },
]
test-min = [
"coverage[toml]",
"fast-array-utils[sparse,testing]", # include sparse for testing numba-less to_dense
Expand Down
7 changes: 7 additions & 0 deletions src/fast_array_utils/conv/_to_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ def to_dense_(
order: Literal["K", "A", "C", "F"] = "K",
to_cpu_memory: bool = False,
) -> NDArray[Any] | types.CupyArray | types.DaskArray:
import array_api_compat

if not isinstance(x, np.ndarray) and array_api_compat.is_array_api_obj(x):
if to_cpu_memory:
return np.asarray(x, order=order)
return x # already dense
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don’t think that follows, but I also don’t know if we can do better.


del to_cpu_memory # it already is
return np.asarray(x, order=order)

Expand Down
9 changes: 8 additions & 1 deletion src/fast_array_utils/stats/_generic_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,20 @@ def generic_op(
axis: Literal[0, 1] | None = None,
dtype: DTypeLike | None = None,
keep_cupy_as_array: bool = False,
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray:
) -> NDArray[Any] | np.number[Any] | types.CupyArray | types.DaskArray: # switch to Any later
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does the comment mean?

del keep_cupy_as_array
if TYPE_CHECKING:
# these are never passed to this fallback function, but `singledispatch` wants them
assert not isinstance(x, types.CSBase | types.DaskArray | types.CupyArray | types.CupyCSMatrix)
# np supports these, but doesn’t know it. (TODO: test cupy)
assert not isinstance(x, types.ZarrArray | types.H5Dataset)

# doing array_api_compat first
import array_api_compat

if array_api_compat.is_array_api_obj(x):
xp = array_api_compat.array_namespace(x)
return getattr(xp, op)(x, axis=axis, **_dtype_kw(dtype, op))
Comment on lines +53 to +58
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the code inside of _run_numpy_op looks almost identical, why not rename that and do xp = array_namespace(x) if is_array_api_obj(x) else np or so?

return cast("NDArray[Any] | np.number[Any]", _run_numpy_op(x, op, axis=axis, dtype=dtype))


Expand Down
17 changes: 16 additions & 1 deletion src/fast_array_utils/stats/_is_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,29 @@

from numpy.typing import NDArray

# Fallback implementations for checking whether all values in an array are equal.


@singledispatch
def is_constant_(
a: NDArray[Any] | types.CSBase | types.CupyArray | types.DaskArray,
/,
*,
axis: Literal[0, 1] | None = None,
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray: # pragma: no cover
) -> bool | NDArray[np.bool] | types.CupyArray | types.DaskArray: # switch to Any later

import array_api_compat

if array_api_compat.is_array_api_obj(a):
xp = array_api_compat.array_namespace(a)
match axis:
case None:
return bool((a == xp.reshape(a, (-1,))[0]).all())
case 0:
return is_constant_(a.T, axis=1) # reusing axis = 1
case 1:
b = xp.broadcast_to(a[:, 0:1], a.shape)
return (a == b).all(axis=1)
raise NotImplementedError


Expand Down
7 changes: 7 additions & 0 deletions src/fast_array_utils/stats/_power.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def power[Arr: Array](x: Arr, n: int, /, dtype: DTypeLike | None = None) -> Arr:
def _power(x: Array, n: int, /, dtype: DTypeLike | None = None) -> Array:
    """Elementwise ``x**n``, optionally after casting ``x`` to ``dtype``.

    Fallback for dense ndarray-likes and generic array-API objects
    (dask/sparse/cupy-sparse are handled by registered overloads).
    """
    if TYPE_CHECKING:
        assert not isinstance(x, types.DaskArray | types.CSBase | types.CupyCSMatrix)

    import array_api_compat

    if array_api_compat.is_array_api_obj(x):
        xp = array_api_compat.array_namespace(x)
        if dtype is None:
            return xp.pow(x, n)
        # `.astype` is not a method in the array API standard;
        # the portable spelling is the namespace function `xp.astype`.
        return xp.pow(xp.astype(x, dtype), n)

    return x**n if dtype is None else np.power(x, n, dtype=dtype)  # type: ignore[operator]


Expand Down
121 changes: 121 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: MPL-2.0
from __future__ import annotations

from importlib.util import find_spec
from typing import TYPE_CHECKING

import numpy as np
import pytest

from fast_array_utils import stats
from fast_array_utils.conv import to_dense


if TYPE_CHECKING:
from typing import Any, Literal

# Skip every test in this module when jax is not importable.
pytestmark = pytest.mark.skipif(not find_spec("jax"), reason="jax not installed")

if find_spec("jax"):
    # Enable 64-bit precision: JAX defaults to 32-bit only, and mean_var
    # passes dtype=np.float64 internally, which crashes without this flag.
    import jax

    jax.config.update("jax_enable_x64", True)


@pytest.fixture
def jax_arr() -> Any:
    """A small 3x2 float32 JAX array whose second column is constant zero."""
    import jax.numpy as jnp

    data = [[1, 0], [2, 0], [3, 0]]
    return jnp.array(data, dtype=jnp.float32)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_sum(jax_arr: Any, axis: Literal[0, 1] | None) -> None:
    """``stats.sum`` agrees with ``jax.numpy.sum`` for every axis choice."""
    import jax.numpy as jnp

    assert jnp.array_equal(stats.sum(jax_arr, axis=axis), jnp.sum(jax_arr, axis=axis))


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_min(jax_arr: Any, axis: Literal[0, 1] | None) -> None:
    """``stats.min`` agrees with ``jax.numpy.min`` for every axis choice."""
    import jax.numpy as jnp

    assert jnp.array_equal(stats.min(jax_arr, axis=axis), jnp.min(jax_arr, axis=axis))


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_max(jax_arr: Any, axis: Literal[0, 1] | None) -> None:
    """``stats.max`` agrees with ``jax.numpy.max`` for every axis choice."""
    import jax.numpy as jnp

    assert jnp.array_equal(stats.max(jax_arr, axis=axis), jnp.max(jax_arr, axis=axis))


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_mean(jax_arr: Any, axis: Literal[0, 1] | None) -> None:
    """``stats.mean`` is numerically close to ``jax.numpy.mean``."""
    import jax.numpy as jnp

    assert jnp.allclose(stats.mean(jax_arr, axis=axis), jnp.mean(jax_arr, axis=axis))


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_is_constant(axis: Literal[0, 1] | None) -> None:
    """``stats.is_constant`` flags constant rows/columns of a known matrix."""
    import jax.numpy as jnp

    rows = [
        [0, 0, 1, 1],
        [0, 0, 1, 1],
        [0, 0, 0, 0],
        [0, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 0, 0],
    ]
    mat = jnp.array(rows, dtype=jnp.float32)
    got = stats.is_constant(mat, axis=axis)

    if axis is None:
        # The matrix as a whole contains both 0s and 1s.
        assert bool(got) is False
        return

    expected = (
        jnp.array([True, True, False, False])  # per-column constancy
        if axis == 0
        else jnp.array([False, False, True, True, False, True])  # per-row constancy
    )
    assert jnp.array_equal(got, expected)


@pytest.mark.parametrize("axis", [None, 0, 1])
def test_mean_var(jax_arr: Any, axis: Literal[0, 1] | None) -> None:
    """``stats.mean_var(correction=1)`` matches jnp mean and Bessel-corrected var."""
    import jax.numpy as jnp

    got_mean, got_var = stats.mean_var(jax_arr, axis=axis, correction=1)

    want_mean = jnp.mean(jax_arr, axis=axis)
    # Convert jnp's population variance (ddof=0) to the sample variance (ddof=1).
    count = jax_arr.size if axis is None else jax_arr.shape[axis]
    want_var = jnp.var(jax_arr, axis=axis) * count / (count - 1)

    assert jnp.allclose(got_mean, want_mean)
    assert jnp.allclose(got_var, want_var)


def test_to_dense(jax_arr: Any) -> None:
    """``to_dense`` leaves an already-dense JAX array unchanged."""
    import jax.numpy as jnp

    assert jnp.array_equal(to_dense(jax_arr), jax_arr)


def test_to_dense_to_cpu(jax_arr: Any) -> None:
    """``to_dense(to_cpu_memory=True)`` yields a numpy array with equal contents."""
    out = to_dense(jax_arr, to_cpu_memory=True)
    assert isinstance(out, np.ndarray)
    np.testing.assert_array_equal(out, np.asarray(jax_arr))
Loading