From e91911af4d33d6a650641f9a69c6669b1525f834 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Mon, 2 Mar 2026 15:21:29 +0100 Subject: [PATCH 1/6] fix: use parallelized `numba` functions if possible --- pyproject.toml | 6 -- src/fast_array_utils/conv/scipy/_to_dense.py | 6 +- src/fast_array_utils/stats/_is_constant.py | 4 +- src/fast_array_utils/stats/_mean_var.py | 6 +- src/fast_array_utils/utils.py | 90 ++++++++++++++++++++ src/testing/fast_array_utils/pytest.py | 17 ++++ 6 files changed, 118 insertions(+), 11 deletions(-) create mode 100644 src/fast_array_utils/utils.py diff --git a/pyproject.toml b/pyproject.toml index 60fa867..f525e5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -177,12 +177,6 @@ doctest_subpackage_requires = [ "src/fast_array_utils/_plugins/dask.py = dask", "src/fast_array_utils/_plugins/numba_sparse.py = numba;scipy", ] -filterwarnings = [ - "error", - # codspeed seems to break this dtype added by h5py - "ignore:.*numpy[.]longdouble:UserWarning", - "ignore:FNV hashing is not implemented in Numba:UserWarning", -] markers = [ "benchmark: marks tests as benchmark (to run with `--codspeed`)", ] diff --git a/src/fast_array_utils/conv/scipy/_to_dense.py b/src/fast_array_utils/conv/scipy/_to_dense.py index b699a2e..2ebe6ef 100644 --- a/src/fast_array_utils/conv/scipy/_to_dense.py +++ b/src/fast_array_utils/conv/scipy/_to_dense.py @@ -5,6 +5,8 @@ import numba +from fast_array_utils.utils import njit + if TYPE_CHECKING: from typing import Any @@ -18,14 +20,14 @@ __all__ = ["_to_dense_csc_numba", "_to_dense_csr_numba"] -@numba.njit(cache=True) +@njit def _to_dense_csc_numba(x: CSBase, out: NDArray[np.number[Any]]) -> None: for c in numba.prange(out.shape[1]): for i in range(x.indptr[c], x.indptr[c + 1]): out[x.indices[i], c] = x.data[i] -@numba.njit(cache=True) +@njit def _to_dense_csr_numba(x: CSBase, out: NDArray[np.number[Any]]) -> None: for r in numba.prange(out.shape[0]): for i in range(x.indptr[r], x.indptr[r + 1]): diff --git 
a/src/fast_array_utils/stats/_is_constant.py b/src/fast_array_utils/stats/_is_constant.py index 1ac95d3..abb72fb 100644 --- a/src/fast_array_utils/stats/_is_constant.py +++ b/src/fast_array_utils/stats/_is_constant.py @@ -7,6 +7,8 @@ import numba import numpy as np +from fast_array_utils.utils import njit + from .. import types @@ -64,7 +66,7 @@ def _is_constant_cs(a: types.CSBase, /, *, axis: Literal[0, 1] | None = None) -> return _is_constant_cs_major(a, shape) -@numba.njit(cache=True) +@njit def _is_constant_cs_major(a: types.CSBase, shape: tuple[int, int]) -> NDArray[np.bool]: n = len(a.indptr) - 1 result = np.ones(n, dtype=np.bool) diff --git a/src/fast_array_utils/stats/_mean_var.py b/src/fast_array_utils/stats/_mean_var.py index 9037567..4c6b460 100644 --- a/src/fast_array_utils/stats/_mean_var.py +++ b/src/fast_array_utils/stats/_mean_var.py @@ -6,6 +6,8 @@ import numba import numpy as np +from fast_array_utils.utils import njit + from .. import types from ._power import power @@ -79,7 +81,7 @@ def _sparse_mean_var(mtx: types.CSBase, /, *, axis: Literal[0, 1]) -> tuple[NDAr ) -@numba.njit(cache=True) +@njit def sparse_mean_var_minor_axis( x: types.CSBase, *, @@ -109,7 +111,7 @@ def sparse_mean_var_minor_axis( return means, variances -@numba.njit(cache=True) +@njit def sparse_mean_var_major_axis( x: types.CSBase, *, diff --git a/src/fast_array_utils/utils.py b/src/fast_array_utils/utils.py new file mode 100644 index 0000000..512927f --- /dev/null +++ b/src/fast_array_utils/utils.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import sys +import warnings +from functools import cache, wraps +from typing import TYPE_CHECKING, Literal, cast, overload + + +if TYPE_CHECKING: + from collections.abc import Callable + +type LayerType = Literal["default", "safe", "threadsafe", "forksafe"] +type Layer = Literal["tbb", "omp", "workqueue"] + + +LAYERS: dict[LayerType, set[Layer]] = { + "default": {"tbb", "omp", "workqueue"}, + "safe": {"tbb"}, + "threadsafe": 
{"tbb", "omp"}, + "forksafe": {"tbb", "workqueue", *(() if sys.platform == "linux" else {"omp"})}, +} + + +@cache +def _numba_threading_layer() -> Layer: + """Get numba’s threading layer. + + This function implements the algorithm as described in + + """ + import importlib + + import numba + + if (available := LAYERS.get(numba.config.THREADING_LAYER)) is None: + # given by direct name + return numba.config.THREADING_LAYER + + # given by layer type (safe, …) + for layer in cast("list[Layer]", numba.config.THREADING_LAYER_PRIORITY): + if layer not in available: + continue + if layer != "workqueue": + try: # `importlib.util.find_spec` doesn’t work here + importlib.import_module(f"numba.np.ufunc.{layer}pool") + except ImportError: + continue + # the layer has been found + return layer + msg = f"No loadable threading layer: {numba.config.THREADING_LAYER=} ({available=}, {numba.config.THREADING_LAYER_PRIORITY=})" + raise ValueError(msg) + + +def _is_in_unsafe_thread_pool() -> bool: + import threading + + current_thread = threading.current_thread() + # ThreadPoolExecutor threads typically have names like 'ThreadPoolExecutor-0_1' + return current_thread.name.startswith("ThreadPoolExecutor") and _numba_threading_layer() not in LAYERS["threadsafe"] + + +@overload +def njit[**P, R](fn: Callable[P, R], /) -> Callable[P, R]: ... +@overload +def njit[**P, R]() -> Callable[[Callable[P, R]], Callable[P, R]]: ... +def njit[**P, R](fn: Callable[P, R] | None = None, /) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: + """Jit-compile a function using numba. + + On call, this function dispatches to a parallel or sequential numba function, + depending on if it has been called from a thread pool. 
+ + See + """ + + def decorator(f: Callable[P, R], /) -> Callable[P, R]: + import numba + + fns: dict[bool, Callable[P, R]] = {parallel: numba.njit(f, cache=True, parallel=parallel) for parallel in (True, False)} + + @wraps(f) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + parallel = not _is_in_unsafe_thread_pool() + if not parallel: + msg = f"Detected unsupported threading environment. Trying to run {f.__name__} in serial mode. In case of problems, install `tbb`." + warnings.warn(msg, UserWarning, stacklevel=2) + return fns[parallel](*args, **kwargs) + + return wrapper + + return decorator if fn is None else decorator(fn) diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py index 2500922..370cfe4 100644 --- a/src/testing/fast_array_utils/pytest.py +++ b/src/testing/fast_array_utils/pytest.py @@ -89,6 +89,23 @@ def _skip_if_unimportable(array_type: ArrayType) -> pytest.MarkDecorator: SUPPORTED_TYPE_PARAMS = [pytest.param(t, id=str(t), marks=_skip_if_unimportable(t)) for t in SUPPORTED_TYPES] +@pytest.fixture(autouse=True) +def dask_single_threaded() -> Generator[None]: + """Switch to a single-threaded scheduler for tests since numba crashes otherwise.""" + if not find_spec("dask"): + yield + return + + import dask.config + + prev_scheduler = dask.config.get("scheduler", "threads") + dask.config.set(scheduler="single-threaded") + try: + yield + finally: + dask.config.set(scheduler=prev_scheduler) + + @pytest.fixture(scope="session", params=SUPPORTED_TYPE_PARAMS) def array_type(request: pytest.FixtureRequest) -> ArrayType: """Fixture for a supported :class:`~testing.fast_array_utils.ArrayType`. 
From ba7da131d6c7d3fc183ba905f185a2ef59887934 Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Fri, 20 Mar 2026 14:33:06 +0100 Subject: [PATCH 2/6] fix: no dask threaded --- src/testing/fast_array_utils/pytest.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py index 370cfe4..2500922 100644 --- a/src/testing/fast_array_utils/pytest.py +++ b/src/testing/fast_array_utils/pytest.py @@ -89,23 +89,6 @@ def _skip_if_unimportable(array_type: ArrayType) -> pytest.MarkDecorator: SUPPORTED_TYPE_PARAMS = [pytest.param(t, id=str(t), marks=_skip_if_unimportable(t)) for t in SUPPORTED_TYPES] -@pytest.fixture(autouse=True) -def dask_single_threaded() -> Generator[None]: - """Switch to a single-threaded scheduler for tests since numba crashes otherwise.""" - if not find_spec("dask"): - yield - return - - import dask.config - - prev_scheduler = dask.config.get("scheduler", "threads") - dask.config.set(scheduler="single-threaded") - try: - yield - finally: - dask.config.set(scheduler=prev_scheduler) - - @pytest.fixture(scope="session", params=SUPPORTED_TYPE_PARAMS) def array_type(request: pytest.FixtureRequest) -> ArrayType: """Fixture for a supported :class:`~testing.fast_array_utils.ArrayType`. 
From 9bbae70ed1fd8a091f22104a56797d70a8219a46 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 23 Mar 2026 15:22:33 +0100 Subject: [PATCH 3/6] hopefully fix tests --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5dd7a7a..2598a4c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ ] dynamic = [ "description", "readme", "version" ] dependencies = [ "numpy>=2" ] -optional-dependencies.accel = [ "numba>=0.57" ] +optional-dependencies.accel = [ "numba>=0.57", "tbb" ] optional-dependencies.dask = [ "dask>=2023.6.1" ] optional-dependencies.full = [ "fast-array-utils[accel,dask,sparse]", "h5py", "zarr" ] optional-dependencies.sparse = [ "scipy>=1.13" ] From e5ee035a1431d67704c13a43709e21f35dd4f86b Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 23 Mar 2026 15:31:14 +0100 Subject: [PATCH 4/6] only on linux and win --- pyproject.toml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2598a4c..b21a207 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ ] dynamic = [ "description", "readme", "version" ] dependencies = [ "numpy>=2" ] -optional-dependencies.accel = [ "numba>=0.57", "tbb" ] +optional-dependencies.accel = [ "numba>=0.57" ] optional-dependencies.dask = [ "dask>=2023.6.1" ] optional-dependencies.full = [ "fast-array-utils[accel,dask,sparse]", "h5py", "zarr" ] optional-dependencies.sparse = [ "scipy>=1.13" ] @@ -50,11 +50,12 @@ doc = [ ] test-min = [ "coverage[toml]", - "fast-array-utils[sparse,testing]", # include sparse for testing numba-less to_dense + "fast-array-utils[sparse,testing]", # include sparse for testing numba-less to_dense "pytest", "pytest-codspeed", "pytest-doctestplus", "pytest-xdist", + "tbb; sys_platform=='linux' or sys_platform=='win32'", # robust tests ] [tool.hatch] From 4db019b941a2f99fa4ff1fc4ac13eac88d0f3681 Mon Sep 17 00:00:00 2001 From:
Phil Schaf Date: Mon, 23 Mar 2026 16:09:43 +0100 Subject: [PATCH 5/6] re-introduce workaround conditionally --- pyproject.toml | 3 +++ src/fast_array_utils/{utils.py => _numba.py} | 15 ++++++++----- src/fast_array_utils/conv/scipy/_to_dense.py | 2 +- src/fast_array_utils/stats/_is_constant.py | 2 +- src/fast_array_utils/stats/_mean_var.py | 2 +- src/testing/fast_array_utils/pytest.py | 23 ++++++++++++++++++++ typings/numba/__init__.pyi | 15 ++++++++++++- typings/numba/core/config.pyi | 5 +++++ 8 files changed, 58 insertions(+), 9 deletions(-) rename src/fast_array_utils/{utils.py => _numba.py} (87%) create mode 100644 typings/numba/core/config.pyi diff --git a/pyproject.toml b/pyproject.toml index b21a207..1e46d12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -175,3 +175,6 @@ mypy_path = [ "$MYPY_CONFIG_FILE_DIR/typings", "$MYPY_CONFIG_FILE_DIR/src" ] [tool.pyright] stubPath = "./typings" reportPrivateUsage = false + +[tool.ty] +environment.extra-paths = [ "./typings" ] diff --git a/src/fast_array_utils/utils.py b/src/fast_array_utils/_numba.py similarity index 87% rename from src/fast_array_utils/utils.py rename to src/fast_array_utils/_numba.py index 512927f..4d0ebe0 100644 --- a/src/fast_array_utils/utils.py +++ b/src/fast_array_utils/_numba.py @@ -1,3 +1,6 @@ +# SPDX-License-Identifier: MPL-2.0 +"""Numba utilities.""" + from __future__ import annotations import sys @@ -22,7 +25,7 @@ @cache -def _numba_threading_layer() -> Layer: +def _numba_threading_layer(layer_name: Layer | LayerType | None = None) -> Layer: """Get numba’s threading layer. 
This function implements the algorithm as described in @@ -32,12 +35,14 @@ def _numba_threading_layer() -> Layer: import numba - if (available := LAYERS.get(numba.config.THREADING_LAYER)) is None: - # given by direct name - return numba.config.THREADING_LAYER + if layer_name is None: + layer_name = numba.config.THREADING_LAYER + + if (available := LAYERS.get(layer_name)) is None: # type: ignore[arg-type] + return cast("Layer", layer_name) # given by direct name # given by layer type (safe, …) - for layer in cast("list[Layer]", numba.config.THREADING_LAYER_PRIORITY): + for layer in numba.config.THREADING_LAYER_PRIORITY: if layer not in available: continue if layer != "workqueue": diff --git a/src/fast_array_utils/conv/scipy/_to_dense.py b/src/fast_array_utils/conv/scipy/_to_dense.py index 2ebe6ef..af07246 100644 --- a/src/fast_array_utils/conv/scipy/_to_dense.py +++ b/src/fast_array_utils/conv/scipy/_to_dense.py @@ -5,7 +5,7 @@ import numba -from fast_array_utils.utils import njit +from fast_array_utils._numba import njit if TYPE_CHECKING: diff --git a/src/fast_array_utils/stats/_is_constant.py b/src/fast_array_utils/stats/_is_constant.py index abb72fb..ba648b3 100644 --- a/src/fast_array_utils/stats/_is_constant.py +++ b/src/fast_array_utils/stats/_is_constant.py @@ -7,7 +7,7 @@ import numba import numpy as np -from fast_array_utils.utils import njit +from fast_array_utils._numba import njit from .. import types diff --git a/src/fast_array_utils/stats/_mean_var.py b/src/fast_array_utils/stats/_mean_var.py index 4c6b460..e3b7073 100644 --- a/src/fast_array_utils/stats/_mean_var.py +++ b/src/fast_array_utils/stats/_mean_var.py @@ -6,7 +6,7 @@ import numba import numpy as np -from fast_array_utils.utils import njit +from fast_array_utils._numba import njit from .. 
import types from ._power import power diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py index 2500922..b14685e 100644 --- a/src/testing/fast_array_utils/pytest.py +++ b/src/testing/fast_array_utils/pytest.py @@ -15,6 +15,7 @@ import pytest from fast_array_utils import types +from fast_array_utils._numba import _numba_threading_layer from testing.fast_array_utils import SUPPORTED_TYPES, ArrayType, ConversionContext, Flags @@ -89,6 +90,28 @@ def _skip_if_unimportable(array_type: ArrayType) -> pytest.MarkDecorator: SUPPORTED_TYPE_PARAMS = [pytest.param(t, id=str(t), marks=_skip_if_unimportable(t)) for t in SUPPORTED_TYPES] +@pytest.fixture(autouse=True) +def dask_single_threaded() -> Generator[None]: + """Switch to a single-threaded scheduler for tests on macOS since numba crashes otherwise.""" + if not find_spec("dask"): + yield + return + try: # if a safe threading layer is available, we use that + _numba_threading_layer("threadsafe") + except ValueError: + yield + return + + import dask.config + + prev_scheduler = dask.config.get("scheduler", "threads") + dask.config.set(scheduler="single-threaded") + try: + yield + finally: + dask.config.set(scheduler=prev_scheduler) + + @pytest.fixture(scope="session", params=SUPPORTED_TYPE_PARAMS) def array_type(request: pytest.FixtureRequest) -> ArrayType: """Fixture for a supported :class:`~testing.fast_array_utils.ArrayType`. diff --git a/typings/numba/__init__.pyi b/typings/numba/__init__.pyi index c3a1395..1f3bfa8 100644 --- a/typings/numba/__init__.pyi +++ b/typings/numba/__init__.pyi @@ -2,6 +2,7 @@ from collections.abc import Callable, Iterable from typing import Literal, SupportsIndex, overload +from .core import config as config from .core.types import * type __Signature = str | Type @@ -9,7 +10,19 @@ type _Signature = str | Type | tuple[__Signature, ...] 
# https://numba.readthedocs.io/en/stable/reference/jit-compilation.html#numba.jit @overload -def njit[F: Callable[..., object]](f: F) -> F: ... +def njit[F: Callable[..., object]]( + f: F, + *, + nopython: bool = True, + nogil: bool = False, + cache: bool = False, + forceobj: bool = False, + parallel: bool = False, + error_model: Literal["python", "numpy"] = "python", + fastmath: bool = False, + locals: dict[str, object] = {}, + boundscheck: bool = False, +) -> F: ... @overload def njit[F: Callable[..., object]]( signature: _Signature | list[_Signature] | None = None, diff --git a/typings/numba/core/config.pyi b/typings/numba/core/config.pyi new file mode 100644 index 0000000..702dd1e --- /dev/null +++ b/typings/numba/core/config.pyi @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: MPL-2.0 +from fast_array_utils._numba import Layer, LayerType + +THREADING_LAYER: Layer | LayerType +THREADING_LAYER_PRIORITY: list[Layer] From a8256eae3f78f50dbc0176927c247b30bf162850 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Mon, 23 Mar 2026 16:26:22 +0100 Subject: [PATCH 6/6] whoops --- src/testing/fast_array_utils/pytest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/testing/fast_array_utils/pytest.py b/src/testing/fast_array_utils/pytest.py index b14685e..dec4b79 100644 --- a/src/testing/fast_array_utils/pytest.py +++ b/src/testing/fast_array_utils/pytest.py @@ -96,9 +96,11 @@ def dask_single_threaded() -> Generator[None]: if not find_spec("dask"): yield return - try: # if a safe threading layer is available, we use that + try: _numba_threading_layer("threadsafe") except ValueError: + pass + else: # if a safe threading layer is available, we use that yield return