From 5075dfaa082a0e008e35c06c381f8fb10549b27a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Jan 2026 14:50:49 +0100 Subject: [PATCH 01/15] Fix exclude tag cxx -> cxx_only --- pytensor/link/basic.py | 2 +- pytensor/link/jax/linker.py | 2 +- pytensor/link/numba/linker.py | 2 +- pytensor/link/vm.py | 2 +- pytensor/tensor/rewriting/blas_c.py | 1 + tests/sparse/test_rewriting.py | 17 ++++------------- 6 files changed, 9 insertions(+), 17 deletions(-) diff --git a/pytensor/link/basic.py b/pytensor/link/basic.py index ae28fbf15c..6766c6bb8e 100644 --- a/pytensor/link/basic.py +++ b/pytensor/link/basic.py @@ -284,7 +284,7 @@ class PerformLinker(LocalLinker): """ required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only") - incompatible_rewrites: tuple[str, ...] = ("cxx",) + incompatible_rewrites: tuple[str, ...] = ("cxx_only",) def __init__( self, allow_gc: bool | None = None, schedule: Callable | None = None diff --git a/pytensor/link/jax/linker.py b/pytensor/link/jax/linker.py index 6abf467824..566a9f4469 100644 --- a/pytensor/link/jax/linker.py +++ b/pytensor/link/jax/linker.py @@ -14,7 +14,7 @@ class JAXLinker(JITLinker): "jax", ) # TODO: Distinguish between optional "jax" and "minimum_compile_jax" incompatible_rewrites = ( - "cxx", + "cxx_only", "BlasOpt", "local_careduce_fusion", "scan_save_mem_prealloc", diff --git a/pytensor/link/numba/linker.py b/pytensor/link/numba/linker.py index 068df4a95b..bf5c4304ad 100644 --- a/pytensor/link/numba/linker.py +++ b/pytensor/link/numba/linker.py @@ -7,7 +7,7 @@ class NumbaLinker(JITLinker): "numba", ) # TODO: Distinguish between optional "numba" and "minimum_compile_numba" incompatible_rewrites = ( - "cxx", + "cxx_only", "BlasOpt", "local_careduce_fusion", "scan_save_mem_prealloc", diff --git a/pytensor/link/vm.py b/pytensor/link/vm.py index 8c9a2dc90f..239f73df80 100644 --- a/pytensor/link/vm.py +++ b/pytensor/link/vm.py @@ -840,7 +840,7 @@ def __init__( c_thunks = bool(config.cxx) if not c_thunks: self.required_rewrites: tuple[str, ...] = ("minimum_compile", "py_only") - self.incompatible_rewrites: tuple[str, ...] = ("cxx",) + self.incompatible_rewrites: tuple[str, ...] = ("cxx_only",) self.c_thunks = c_thunks self.allow_partial_eval = allow_partial_eval self.updated_vars = {} diff --git a/pytensor/tensor/rewriting/blas_c.py b/pytensor/tensor/rewriting/blas_c.py index 827aa64077..d4220d6d7a 100644 --- a/pytensor/tensor/rewriting/blas_c.py +++ b/pytensor/tensor/rewriting/blas_c.py @@ -68,5 +68,6 @@ def make_c_gemv_destructive(fgraph, node): "fast_run", "inplace", "c_blas", + "cxx_only", position=70.0, ) diff --git a/tests/sparse/test_rewriting.py b/tests/sparse/test_rewriting.py index 759ee17f34..fe9befd160 100644 --- a/tests/sparse/test_rewriting.py +++ b/tests/sparse/test_rewriting.py @@ -69,13 +69,10 @@ def test_local_csm_grad_c(): not pytensor.config.cxx, reason="G++ not available, so we need to skip this test." 
) def test_local_mul_s_d(): - mode = get_default_mode() - mode = mode.including("specialize", "local_mul_s_d") - for sp_format in sparse.sparse_formats: inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), matrix()] - f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode=mode) + f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode="CVM") assert not any( isinstance(node.op, smath.MulSD) for node in f.maker.fgraph.toposort() @@ -92,7 +89,7 @@ def test_local_mul_s_v(): for sp_format in ["csr"]: # Not implemented for other format inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), vector()] - f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode=mode) + f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode="CVM") assert not any( isinstance(node.op, smath.MulSV) for node in f.maker.fgraph.toposort() @@ -103,13 +100,10 @@ def test_local_mul_s_v(): not pytensor.config.cxx, reason="G++ not available, so we need to skip this test." ) def test_local_structured_add_s_v(): - mode = get_default_mode() - mode = mode.including("specialize", "local_structured_add_s_v") - for sp_format in ["csr"]: # Not implemented for other format inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), vector()] - f = pytensor.function(inputs, smath.structured_add_s_v(*inputs), mode=mode) + f = pytensor.function(inputs, smath.structured_add_s_v(*inputs), mode="CVM") assert not any( isinstance(node.op, smath.StructuredAddSV) @@ -121,9 +115,6 @@ def test_local_structured_add_s_v(): not pytensor.config.cxx, reason="G++ not available, so we need to skip this test." ) def test_local_sampling_dot_csr(): - mode = get_default_mode() - mode = mode.including("specialize", "local_sampling_dot_csr") - for sp_format in ["csr"]: # Not implemented for other format inputs = [ matrix(), @@ -131,7 +122,7 @@ def test_local_sampling_dot_csr(): getattr(pytensor.sparse, sp_format + "_matrix")(), ] - f = pytensor.function(inputs, smath.sampling_dot(*inputs), mode=mode) + f = pytensor.function(inputs, smath.sampling_dot(*inputs), mode="CVM") if pytensor.config.blas__ldflags: assert not any( From 90c71954cbdf9b1e4ae24214e46a9338249ed51b Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Jan 2026 13:52:43 +0100 Subject: [PATCH 02/15] Seed flaky test --- tests/tensor/test_optimize.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/tensor/test_optimize.py b/tests/tensor/test_optimize.py index e42550ded3..27752c6be3 100644 --- a/tests/tensor/test_optimize.py +++ b/tests/tensor/test_optimize.py @@ -239,7 +239,7 @@ def objective(mu, cov, data): y_star = pytensor.graph_replace(neg_logp, {mu: mu_star}) _ = pt.grad(y_star, [mu, cov, data]) - rng = np.random.default_rng() + rng = np.random.default_rng(242) data_val = rng.normal(size=(d,)).astype(floatX) L = rng.normal(size=(d, d)).astype(floatX) @@ -262,7 +262,10 @@ def min_fn(mu, cov, data): return mu_star.sum() utt.verify_grad( - min_fn, [mu0_val, cov_val, data_val], eps=1e-3 if floatX == "float32" else 1e-6 + min_fn, + [mu0_val, cov_val, data_val], + eps=1e-3 if floatX == "float32" else 1e-6, + rng=rng, ) From db84708dc369db6bc4e9ef06e82db09118ed0c73 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 18 Nov 2025 15:05:37 +0100 Subject: [PATCH 03/15] Cleanup ruff per file ignores --- pyproject.toml | 20 +++++++------------- tests/link/jax/test_random.py | 2 +- tests/sparse/test_utils.py | 4 ---- 3 files changed, 8 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 89e95bcd4d..bd5532cdab 
100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -149,27 +149,21 @@ lines-after-imports = 2 # TODO: Get rid of these: "**/__init__.py" = ["F401", "E402", "F403"] "pytensor/tensor/linalg.py" = ["F403"] +# Modules that use print-statements, skip "T201" "pytensor/link/c/cmodule.py" = ["PTH", "T201"] "pytensor/misc/elemwise_time_test.py" = ["T201"] "pytensor/misc/elemwise_openmp_speedup.py" = ["T201"] "pytensor/misc/check_duplicate_key.py" = ["T201"] "pytensor/misc/check_blas.py" = ["T201"] "pytensor/bin/pytensor_cache.py" = ["T201"] -# For the tests we skip `E402` because `pytest.importorskip` is used: -"tests/link/jax/test_scalar.py" = ["E402"] -"tests/link/jax/test_tensor_basic.py" = ["E402"] -"tests/link/numba/test_basic.py" = ["E402"] -"tests/link/numba/test_cython_support.py" = ["E402"] -"tests/link/numba/test_performance.py" = ["E402"] -"tests/link/numba/test_sparse.py" = ["E402"] -"tests/link/numba/test_tensor_basic.py" = ["E402"] -"tests/tensor/test_math_scipy.py" = ["E402"] -"tests/sparse/test_basic.py" = ["E402"] -"tests/sparse/test_sp2.py" = ["E402"] -"tests/sparse/test_utils.py" = ["E402"] -"tests/sparse/sandbox/test_sp.py" = ["E402", "F401"] "tests/compile/test_monitormode.py" = ["T201"] "scripts/run_mypy.py" = ["T201"] +# Test modules of optional backends that use `pytest.importorskip`, skip "E402" +"tests/link/jax/**/test_*.py" = ["E402"] +"tests/link/numba/**/test_*.py" = ["E402"] +"tests/link/pytorch/**/test_*.py" = ["E402"] +"tests/link/mlx/**/test_*.py" = ["E402"] + [tool.mypy] diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 9cf148412e..60b1cd0bea 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -23,7 +23,7 @@ jax = pytest.importorskip("jax") -from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402 +from pytensor.link.jax.dispatch.random import numpyro_available def compile_random_function(*args, mode=jax_mode, **kwargs): diff --git a/tests/sparse/test_utils.py b/tests/sparse/test_utils.py index dd1c2bb67b..f7e52c322e 100644 --- a/tests/sparse/test_utils.py +++ b/tests/sparse/test_utils.py @@ -1,8 +1,4 @@ import numpy as np -import pytest - - -sp = pytest.importorskip("scipy", minversion="0.7.0") from pytensor.sparse.utils import hash_from_sparse from tests.sparse.test_basic import as_sparse_format From fe8b80f5fe428d696b108e50e31284940efd0b3a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 18 Oct 2025 18:21:49 +0200 Subject: [PATCH 04/15] Allow single integer as TensorType shape --- pytensor/tensor/type.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index c1a199c261..d30226ea2f 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -71,7 +71,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): def __init__( self, dtype: str | npt.DTypeLike, - shape: Iterable[bool | int | None] | None = None, + shape: Iterable[bool | int | None] | int | None = None, name: str | None = None, broadcastable: Iterable[bool] | None = None, ): @@ -99,7 +99,7 @@ def __init__( ) shape = broadcastable - if str(dtype) == "floatX": + if dtype == "floatX": self.dtype = config.floatX else: try: @@ -118,6 +118,8 @@ def parse_bcast_and_shape(s): f"TensorType broadcastable/shape must be a boolean, integer or None, got {type(s)} {s}" ) + if isinstance(shape, int): + shape = (shape,) self.shape = _shape = tuple(parse_bcast_and_shape(s) for s in shape) self.broadcastable = tuple(s == 1 for s in _shape) 
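(Illustrative aside, not part of the diff: with the promotion above, a bare integer shape now behaves exactly like a 1-tuple. A minimal sketch, assuming TensorType equality compares dtype and static shape:)

from pytensor.tensor.type import TensorType

# shape=3 is promoted to shape=(3,), so both spellings build the same type
assert TensorType("float64", shape=3) == TensorType("float64", shape=(3,))
assert TensorType("float64", shape=3).ndim == 1
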
self.ndim = _ndim = len(_shape) From 447af6442fb9ef369597b538f426af3fe454e134 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sat, 18 Oct 2025 18:21:07 +0200 Subject: [PATCH 05/15] Handle static shape in core sparse methods --- pytensor/sparse/basic.py | 28 ++++++++++++++++++++-------- pytensor/sparse/variable.py | 2 ++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index c2d33dd245..e79d827084 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -207,19 +207,19 @@ def sp_zeros_like(x): # for more dtypes, call SparseTensorType(format, dtype) -def matrix(format, name=None, dtype=None): +def matrix(format, name=None, dtype=None, shape=None): if dtype is None: dtype = config.floatX - type = SparseTensorType(format=format, dtype=dtype) + type = SparseTensorType(format=format, dtype=dtype, shape=shape) return type(name) -def csc_matrix(name=None, dtype=None): - return matrix("csc", name, dtype) +def csc_matrix(name=None, dtype=None, shape=None): + return matrix("csc", name=name, dtype=dtype, shape=shape) -def csr_matrix(name=None, dtype=None): - return matrix("csr", name, dtype) +def csr_matrix(name=None, dtype=None, shape=None): + return matrix("csr", name=name, dtype=dtype, shape=shape) def bsr_matrix(name=None, dtype=None): @@ -434,10 +434,22 @@ def make_node(self, data, indices, indptr, shape): if shape.type.ndim != 1 or shape.type.dtype not in discrete_dtypes: raise TypeError("n_rows must be integer type", shape, shape.type) + static_shape = (None, None) + if ( + shape.owner is not None + and isinstance(shape.owner.op, CSMProperties) + and shape.owner.outputs[3] is shape + ): + static_shape = shape.owner.inputs[0].type.shape + return Apply( self, [data, indices, indptr, shape], - [SparseTensorType(dtype=data.type.dtype, format=self.format)()], + [ + SparseTensorType( + dtype=data.type.dtype, format=self.format, shape=static_shape + )() + ], ) def perform(self, node, inputs, outputs): @@ -698,7 +710,7 @@ def make_node(self, x): return Apply( self, [x], - [TensorType(dtype=x.type.dtype, shape=(None, None))()], + [TensorType(dtype=x.type.dtype, shape=x.type.shape)()], ) def perform(self, node, inputs, outputs): diff --git a/pytensor/sparse/variable.py b/pytensor/sparse/variable.py index d9a97d7de1..6bf51d1a7f 100644 --- a/pytensor/sparse/variable.py +++ b/pytensor/sparse/variable.py @@ -127,6 +127,8 @@ def sum(self, axis=None, sparse_grad=False): def toarray(self): return dense_from_sparse(self) + todense = toarray + @property def shape(self): # TODO: The plan is that the ShapeFeature in ptb.opt will do shape From 63c8171f8ecbf5657ab5db2ac6497afa88f00aca Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Jan 2026 17:00:22 +0100 Subject: [PATCH 06/15] Sparse methods: Do not use deprecated names --- pytensor/sparse/variable.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytensor/sparse/variable.py b/pytensor/sparse/variable.py index 6bf51d1a7f..5f801185f3 100644 --- a/pytensor/sparse/variable.py +++ b/pytensor/sparse/variable.py @@ -22,11 +22,11 @@ gt, le, lt, - mul, + multiply, sp_sum, structured_conjugate, structured_dot, - sub, + subtract, ) from pytensor.sparse.type import SparseTensorType from pytensor.sparse.utils import hash_from_sparse @@ -88,16 +88,16 @@ def __radd__(right, left): return add(left, right) def __sub__(left, right): - return sub(left, right) + return subtract(left, right) def __rsub__(right, left): - return sub(left, right) + return 
subtract(left, right) def __mul__(left, right): - return mul(left, right) + return multiply(left, right) def __rmul__(left, right): - return mul(left, right) + return multiply(left, right) # comparison operators From 0037f846b89e4239e5c3c65a8c8e2c3ca98c8bf9 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Sun, 16 Nov 2025 17:11:55 +0100 Subject: [PATCH 07/15] SparseMultiply: Cleanup Ops * Handle static shape * Rename to more readable Op classes * Simplify perform --- pytensor/sparse/math.py | 141 ++++++++++++++++++--------------- tests/sparse/test_rewriting.py | 6 +- 2 files changed, 81 insertions(+), 66 deletions(-) diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index bdb2b52ef4..c9ba553658 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -12,6 +12,7 @@ from pytensor.gradient import grad_not_implemented from pytensor.graph import Apply, Op from pytensor.link.c.op import COp +from pytensor.sparse.type import SparseTensorType from pytensor.tensor.shape import specify_broadcastable from pytensor.tensor.type import TensorType, Variable, complex_dtypes, tensor @@ -379,7 +380,7 @@ def make_node(self, x, y): return Apply( self, [x, y], - [psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()], + [SparseTensorType(dtype=out_dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -439,7 +440,7 @@ def make_node(self, x, y): return Apply( self, [x, y], - [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], + [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -542,7 +543,7 @@ def make_node(self, x, y): return Apply( self, [x, y], - [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], + [SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -658,7 +659,7 @@ def sub(x, y): sub.__doc__ = subtract.__doc__ -class MulSS(Op): +class SparseSparseMultiply(Op): # mul(sparse, sparse) # See the doc of mul() for more detail __props__ = () @@ -671,7 +672,7 @@ def make_node(self, x, y): return Apply( self, [x, y], - [psb.SparseTensorType(dtype=out_dtype, format=x.type.format)()], + [SparseTensorType(dtype=out_dtype, format=x.type.format)()], ) def perform(self, node, inputs, outputs): @@ -693,10 +694,10 @@ def infer_shape(self, fgraph, node, shapes): return [shapes[0]] -mul_s_s = MulSS() +mul_s_s = SparseSparseMultiply() -class MulSD(Op): +class SparseDenseMultiply(Op): # mul(sparse, dense) # See the doc of mul() for more detail __props__ = () @@ -713,65 +714,63 @@ def make_node(self, x, y): # objects must be matrices (have dimension 2) # Broadcasting of the sparse matrix is not supported. # We support nd == 0 used by grad of SpSum() - assert y.type.ndim in (0, 2) - out = psb.SparseTensorType(dtype=dtype, format=x.type.format)() + if y.type.ndim not in (0, 2): + raise ValueError(f"y {y} must have 0 or 2 dimensions. 
Got {y.type.ndim}") + if y.type.ndim == 0: + out_shape = x.type.shape + if y.type.ndim == 2: + # Combine with static shape information from y + out_shape = [] + for x_st_dim_length, y_st_dim_length in zip(x.type.shape, y.type.shape): + if x_st_dim_length is None: + out_shape.append(y_st_dim_length) + else: + out_shape.append(x_st_dim_length) + # If both are known, they must match + if ( + y_st_dim_length is not None + and y_st_dim_length != x_st_dim_length + ): + raise ValueError( + f"Incompatible static shapes {x}: {x.type.shape}, {y}: {y.type.shape}" + ) + out_shape = tuple(out_shape) + out = SparseTensorType(dtype=dtype, format=x.type.format, shape=out_shape)() return Apply(self, [x, y], [out]) def perform(self, node, inputs, outputs): (x, y) = inputs (out,) = outputs + out_dtype = node.outputs[0].dtype assert psb._is_sparse(x) and psb._is_dense(y) - if len(y.shape) == 0: - out_dtype = node.outputs[0].dtype - if x.dtype == out_dtype: - z = x.copy() - else: - z = x.astype(out_dtype) - out[0] = z - out[0].data *= y - elif len(y.shape) == 1: - raise NotImplementedError() # RowScale / ColScale - elif len(y.shape) == 2: + + if x.dtype == out_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + out[0] = z + z_data = z.data + + if y.ndim == 0: + z_data *= y + else: # y_ndim == 2 # if we have enough memory to fit y, maybe we can fit x.asarray() # too? # TODO: change runtime from O(M*N) to O(nonzeros) M, N = x.shape assert x.shape == y.shape - out_dtype = node.outputs[0].dtype - + indices = x.indices + indptr = x.indptr if x.format == "csc": - indices = x.indices - indptr = x.indptr - if x.dtype == out_dtype: - z = x.copy() - else: - z = x.astype(out_dtype) - z_data = z.data - for j in range(0, N): for i_idx in range(indptr[j], indptr[j + 1]): i = indices[i_idx] z_data[i_idx] *= y[i, j] - out[0] = z elif x.format == "csr": - indices = x.indices - indptr = x.indptr - if x.dtype == out_dtype: - z = x.copy() - else: - z = x.astype(out_dtype) - z_data = z.data - for i in range(0, M): for j_idx in range(indptr[i], indptr[i + 1]): j = indices[j_idx] z_data[j_idx] *= y[i, j] - out[0] = z - else: - warn( - "This implementation of MulSD is deficient: {x.format}", - ) - out[0] = type(x)(x.toarray() * y) def grad(self, inputs, gout): (x, y) = inputs @@ -784,10 +783,10 @@ def infer_shape(self, fgraph, node, shapes): return [shapes[0]] -mul_s_d = MulSD() +mul_s_d = SparseDenseMultiply() -class MulSV(Op): +class SparseDenseVectorMultiply(Op): """Element-wise multiplication of sparse matrix by a broadcasted dense vector element wise. Notes @@ -796,6 +795,8 @@ class MulSV(Op): """ + # TODO: Merge with the SparseDenseMultiply Op + __props__ = () def make_node(self, x, y): @@ -812,17 +813,30 @@ def make_node(self, x, y): assert x.format in ("csr", "csc") y = ptb.as_tensor_variable(y) - assert y.type.ndim == 1 + if y.type.ndim != 1: + raise ValueError(f"y {y} must have 1 dimension. Got {y.type.ndim}") if x.type.dtype != y.type.dtype: raise NotImplementedError( - "MulSV not implemented for differing dtypes." - f"Got {x.type.dtype} and {y.type.dtype}." + f"Differing dtypes not supported. Got {x.type.dtype} and {y.type.dtype}." 
) + out_shape = [x.type.shape[0]] + if x.type.shape[-1] is None: + out_shape.append(y.type.shape[0]) + else: + out_shape.append(x.type.shape[-1]) + if y.type.shape[-1] is not None and x.type.shape[-1] != y.type.shape[-1]: + raise ValueError( + f"Incompatible static shapes for multiplication {x}: {x.type.shape}, {y}: {y.type.shape}" + ) return Apply( self, [x, y], - [psb.SparseTensorType(dtype=x.type.dtype, format=x.type.format)()], + [ + SparseTensorType( + dtype=x.type.dtype, format=x.type.format, shape=tuple(out_shape) + )() + ], ) def perform(self, node, inputs, outputs): @@ -852,7 +866,7 @@ def infer_shape(self, fgraph, node, ins_shapes): return [ins_shapes[0]] -mul_s_v = MulSV() +mul_s_v = SparseDenseVectorMultiply() def multiply(x, y): @@ -891,16 +905,17 @@ def multiply(x, y): # mul_s_s is not implemented if the types differ if y.dtype == "float64" and x.dtype == "float32": x = x.astype("float64") - return mul_s_s(x, y) - elif x_is_sparse_variable and not y_is_sparse_variable: + elif x_is_sparse_variable or y_is_sparse_variable: + if y_is_sparse_variable: + x, y = y, x # mul is unimplemented if the dtypes differ if y.dtype == "float64" and x.dtype == "float32": x = x.astype("float64") - - return mul_s_d(x, y) - elif y_is_sparse_variable and not x_is_sparse_variable: - return mul_s_d(y, x) + if y.ndim == 1: + return mul_s_v(x, y) + else: + return mul_s_d(x, y) else: raise NotImplementedError() @@ -950,7 +965,7 @@ def make_node(self, x, y): if x.type.format != y.type.format: raise NotImplementedError() return Apply( - self, [x, y], [psb.SparseTensorType(dtype="uint8", format=x.type.format)()] + self, [x, y], [SparseTensorType(dtype="uint8", format=x.type.format)()] ) def perform(self, node, inputs, outputs): @@ -1203,7 +1218,7 @@ def make_node(self, x, y): raise NotImplementedError() inputs = [x, y] # Need to convert? e.g. 
assparse - outputs = [psb.SparseTensorType(dtype=x.type.dtype, format=myformat)()] + outputs = [SparseTensorType(dtype=x.type.dtype, format=myformat)()] return Apply(self, inputs, outputs) def perform(self, node, inp, out_): @@ -1324,9 +1339,7 @@ def make_node(self, a, b): raise NotImplementedError("non-matrix b") if psb._is_sparse_variable(b): - return Apply( - self, [a, b], [psb.SparseTensorType(a.type.format, dtype_out)()] - ) + return Apply(self, [a, b], [SparseTensorType(a.type.format, dtype_out)()]) else: return Apply( self, @@ -1348,7 +1361,7 @@ def perform(self, node, inputs, outputs): ) variable = a * b - if isinstance(node.outputs[0].type, psb.SparseTensorType): + if isinstance(node.outputs[0].type, SparseTensorType): assert psb._is_sparse(variable) out[0] = variable return diff --git a/tests/sparse/test_rewriting.py b/tests/sparse/test_rewriting.py index fe9befd160..ad782ce2ae 100644 --- a/tests/sparse/test_rewriting.py +++ b/tests/sparse/test_rewriting.py @@ -75,7 +75,8 @@ def test_local_mul_s_d(): f = pytensor.function(inputs, smath.mul_s_d(*inputs), mode="CVM") assert not any( - isinstance(node.op, smath.MulSD) for node in f.maker.fgraph.toposort() + isinstance(node.op, smath.SparseDenseMultiply) + for node in f.maker.fgraph.toposort() ) @@ -92,7 +93,8 @@ def test_local_mul_s_v(): f = pytensor.function(inputs, smath.mul_s_v(*inputs), mode="CVM") assert not any( - isinstance(node.op, smath.MulSV) for node in f.maker.fgraph.toposort() + isinstance(node.op, smath.SparseDenseVectorMultiply) + for node in f.maker.fgraph.toposort() ) From b6a01cc9546b016a28ee960dc26da67e35ebe96c Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Jan 2026 14:50:15 +0100 Subject: [PATCH 08/15] CSMProperties Op: simplify logic --- pytensor/sparse/basic.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytensor/sparse/basic.py b/pytensor/sparse/basic.py index e79d827084..3250fa7ca0 100644 --- a/pytensor/sparse/basic.py +++ b/pytensor/sparse/basic.py @@ -297,15 +297,12 @@ def make_node(self, csm): csm = as_sparse_variable(csm) assert csm.format in ("csr", "csc") - data = TensorType(dtype=csm.type.dtype, shape=(None,))() + data = vector(dtype=csm.type.dtype) return Apply(self, [csm], [data, ivector(), ivector(), ivector()]) def perform(self, node, inputs, out): (csm,) = inputs - out[0][0] = csm.data - if str(csm.data.dtype) == "int32": - out[0][0] = np.asarray(out[0][0], dtype="int32") - # backport + out[0][0] = np.asarray(csm.data) out[1][0] = np.asarray(csm.indices, dtype="int32") out[2][0] = np.asarray(csm.indptr, dtype="int32") out[3][0] = np.asarray(csm.shape, dtype="int32") From feecbfd7c2b81be102b7cb74384aefc441ff880d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Jan 2026 15:22:34 +0100 Subject: [PATCH 09/15] Sparse comparison: Fix type violation in perform method --- pytensor/sparse/math.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytensor/sparse/math.py b/pytensor/sparse/math.py index c9ba553658..fe6a087151 100644 --- a/pytensor/sparse/math.py +++ b/pytensor/sparse/math.py @@ -973,7 +973,12 @@ def perform(self, node, inputs, outputs): (out,) = outputs assert psb._is_sparse(x) and psb._is_sparse(y) assert x.shape == y.shape - out[0] = self.comparison(x, y).astype("uint8") + # FIXME: Scipy csc > csc outputs csr format, but make_node assumes it will be the same as inputs + # Casting to respect make_node, but this is very inefficient + # TODO: Why not go with default bool? 
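(Illustrative aside, not part of the patch: the fix below funnels SciPy's comparison result through asformat, which is a no-op when the storage format already matches and only converts when SciPy came back with a different format, as the FIXME describes.)

import numpy as np
import scipy.sparse as sp

a = sp.csc_matrix(np.eye(3))
b = sp.csc_matrix(2 * np.eye(3))
# Force the format promised by make_node; asformat returns self when it already matches
res = (a < b).astype("uint8").asformat(a.format)
assert res.format == "csc" and res.dtype == np.uint8
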
+ out[0] = ( + self.comparison(x, y).astype("uint8").asformat(node.outputs[0].type.format) + ) def infer_shape(self, fgraph, node, ins_shapes): return [ins_shapes[0]] From 38e27e408fee45b70f250a481f8deac3e3d22a26 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Jan 2026 14:23:13 +0100 Subject: [PATCH 10/15] Test: Reduce number of parametrizations 2304 -> 288 --- tests/sparse/test_math.py | 33 +++++++++++---------------------- 1 file changed, 11 insertions(+), 22 deletions(-) diff --git a/tests/sparse/test_math.py b/tests/sparse/test_math.py index f8a3a62bca..5faee2b3b4 100644 --- a/tests/sparse/test_math.py +++ b/tests/sparse/test_math.py @@ -9,6 +9,7 @@ import pytensor.sparse.math as psm import pytensor.tensor as pt from pytensor.configdefaults import config +from pytensor.scalar import upcast from pytensor.sparse.basic import ( CSR, CSMProperties, @@ -367,7 +368,7 @@ def test_upcast(self): ) for dense_dtype in typenames: for sparse_dtype in typenames: - correct_dtype = pytensor.scalar.upcast(sparse_dtype, dense_dtype) + correct_dtype = upcast(sparse_dtype, dense_dtype) a = SparseTensorType("csc", dtype=sparse_dtype)() b = matrix(dtype=dense_dtype) d = structured_dot(a, b) @@ -736,11 +737,10 @@ def setup_method(self): @pytest.mark.slow @pytest.mark.parametrize("dtype1", ["float32", "float64", "int16", "complex64"]) @pytest.mark.parametrize("dtype2", ["float32", "float64", "int16", "complex64"]) - @pytest.mark.parametrize("dtype3", ["float32", "float64", "int16", "complex64"]) - @pytest.mark.parametrize("dtype4", ["float32", "float64", "int16", "complex64"]) + @pytest.mark.parametrize("can_inplace", [False, True]) @pytest.mark.parametrize("format1", ["dense", "csc", "csr"]) @pytest.mark.parametrize("format2", ["dense", "csc", "csr"]) - def test_basic(self, dtype1, dtype2, dtype3, dtype4, format1, format2): + def test_basic(self, dtype1, dtype2, can_inplace, format1, format2): def mat(format, name, dtype): if format == "dense": return matrix(name, dtype=dtype) @@ -750,8 +750,13 @@ def mat(format, name, dtype): if format1 == "dense" and format2 == "dense": pytest.skip("Skipping dense-dense case") + dtype3 = upcast(dtype1, dtype2) + dtype4 = dtype3 if can_inplace else "int32" + inplace = can_inplace + x = mat(format1, "x", dtype1) y = mat(format2, "y", dtype2) + a = scalar("a", dtype=dtype3) z = pytensor.shared(np.asarray(self.z, dtype=dtype4).copy()) @@ -769,9 +774,6 @@ def f_b(z, a, x, y): f_b_out = f_b(z_data, a_data, x_data, y_data) - # Can it work inplace? - inplace = dtype4 == pytensor.scalar.upcast(dtype1, dtype2, dtype3) - # To make it easier to check the toposort mode = pytensor.compile.mode.get_default_mode().excluding("fusion") @@ -782,17 +784,7 @@ def f_b(z, a, x, y): f_a_out = z.get_value(borrow=True) else: f_a = pytensor.function([a, x, y], z - a * psm.dot(x, y), mode=mode) - # In DebugMode there is a strange difference with complex - # So we raise a little the threshold a little. - try: - orig_atol = pytensor.tensor.math.float64_atol - orig_rtol = pytensor.tensor.math.float64_rtol - pytensor.tensor.math.float64_atol = 1e-7 - pytensor.tensor.math.float64_rtol = 1e-6 - f_a_out = f_a(a_data, x_data, y_data) - finally: - pytensor.tensor.math.float64_atol = orig_atol - pytensor.tensor.math.float64_rtol = orig_rtol + f_a_out = f_a(a_data, x_data, y_data) # As we do a dot product of 2 vector of 100 element, # This mean we can have 2*100*eps abs error. 
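(Illustrative aside, not part of the patch: dtype3 above is now derived with upcast instead of being independently parametrized. A small sketch of the NumPy-style promotion the helper is expected to follow:)

from pytensor.scalar import upcast

assert upcast("float32", "int16") == "float32"
assert upcast("float32", "float64") == "float64"
assert upcast("int16", "complex64") == "complex64"
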
@@ -804,7 +796,7 @@ def f_b(z, a, x, y): rtol = None utt.assert_allclose(f_a_out, f_b_out, rtol=rtol, atol=atol) topo = f_a.maker.fgraph.toposort() - up = pytensor.scalar.upcast(dtype1, dtype2, dtype3, dtype4) + up = upcast(dtype1, dtype2, dtype3, dtype4) fast_compile = pytensor.config.mode == "FAST_COMPILE" @@ -906,9 +898,6 @@ def f_b(z, a, x, y): f_b_out = f_b(z_data, a_data, x_data, y_data) - # Can it work inplace? - # inplace = dtype4 == pytensor.scalar.upcast(dtype1, dtype2, dtype3) - # To make it easier to check the toposort mode = pytensor.compile.mode.get_default_mode().excluding("fusion") From 89b4f7a2b757b1a56f37d7edb4c952caee38b306 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Jan 2026 16:10:37 +0100 Subject: [PATCH 11/15] Test: Reduce number of function compilation in sparse cast type --- tests/sparse/test_basic.py | 27 +++++++++------------------ 1 file changed, 9 insertions(+), 18 deletions(-) diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index 3117932fc1..0deba8a23d 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -1430,25 +1430,16 @@ def test_cast(self): format, shape=(4, 7), out_dtype=i_dtype ) - func = pytensor.function([variable], cast(variable, o_dtype)) - cls = pytensor.function([variable], Cast(o_dtype)(variable)) - prop = pytensor.function([variable], variable.astype(o_dtype)) - - t_func, t_cls, t_prop = func(data), cls(data), prop(data) - + func_cast = cast(variable, o_dtype) + cls_cast = Cast(o_dtype)(variable) + prop_cast = variable.astype(o_dtype) + utt.assert_equal_computations([func_cast], [cls_cast]) + utt.assert_equal_computations([cls_cast], [prop_cast]) + + func_cast_res = func_cast.eval({variable: data}) + assert func_cast_res.format == format expected = data.toarray().astype(o_dtype) - - assert t_func.format == format - assert t_cls.format == format - assert t_prop.format == format - - t_func = t_func.toarray() - t_cls = t_cls.toarray() - t_prop = t_prop.toarray() - - utt.assert_allclose(expected, t_func) - utt.assert_allclose(expected, t_cls) - utt.assert_allclose(expected, t_prop) + utt.assert_allclose(expected, func_cast_res.toarray()) @pytest.mark.slow def test_infer_shape(self): From d1a52a1b1a8f74ffe30a301ec5f70d71a75d5618 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Jan 2026 16:37:02 +0100 Subject: [PATCH 12/15] Numba sparse: Implement basic functionality Co-authored-by: Adrian Seyboldt Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- pytensor/link/numba/dispatch/basic.py | 25 +- pytensor/link/numba/dispatch/sparse.py | 206 -------------- .../link/numba/dispatch/sparse/__init__.py | 1 + .../link/numba/dispatch/sparse/variable.py | 267 ++++++++++++++++++ tests/link/numba/sparse/__init__.py | 0 tests/link/numba/sparse/test_basic.py | 264 +++++++++++++++++ tests/link/numba/test_sparse.py | 126 --------- 7 files changed, 548 insertions(+), 341 deletions(-) delete mode 100644 pytensor/link/numba/dispatch/sparse.py create mode 100644 pytensor/link/numba/dispatch/sparse/__init__.py create mode 100644 pytensor/link/numba/dispatch/sparse/variable.py create mode 100644 tests/link/numba/sparse/__init__.py create mode 100644 tests/link/numba/sparse/test_basic.py delete mode 100644 tests/link/numba/test_sparse.py diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index e92394667a..0f1a148110 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -14,14 +14,13 @@ 
from pytensor.graph.fg import FunctionGraph from pytensor.graph.type import Type from pytensor.link.numba.cache import compile_numba_function_src, hash_from_pickle_dump -from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType from pytensor.link.utils import ( fgraph_to_python, ) from pytensor.scalar.basic import ScalarType from pytensor.sparse import SparseTensorType from pytensor.tensor.random.type import RandomGeneratorType -from pytensor.tensor.type import TensorType +from pytensor.tensor.type import DenseTensorType from pytensor.tensor.utils import hash_from_ndarray from pytensor.typed_list import TypedListType @@ -112,7 +111,7 @@ def get_numba_type( Return Numba scalars for zero dimensional :class:`TensorType`\s. """ - if isinstance(pytensor_type, TensorType): + if isinstance(pytensor_type, DenseTensorType): dtype = pytensor_type.numpy_dtype numba_dtype = numba.from_dtype(dtype) if force_scalar or ( @@ -125,18 +124,26 @@ def get_numba_type( numba_dtype = numba.from_dtype(dtype) return numba_dtype elif isinstance(pytensor_type, SparseTensorType): - dtype = pytensor_type.numpy_dtype - numba_dtype = numba.from_dtype(dtype) + from pytensor.link.numba.dispatch.sparse.variable import ( + CSCMatrixType, + CSRMatrixType, + ) + + data_array = numba.types.Array( + numba.from_dtype(pytensor_type.numpy_dtype), 1, layout + ) + indices_array = numba.types.Array(numba.from_dtype(np.int32), 1, layout) + indptr_array = numba.types.Array(numba.from_dtype(np.int32), 1, layout) if pytensor_type.format == "csr": - return CSRMatrixType(numba_dtype) + return CSRMatrixType(data_array, indices_array, indptr_array) if pytensor_type.format == "csc": - return CSCMatrixType(numba_dtype) + return CSCMatrixType(data_array, indices_array, indptr_array) elif isinstance(pytensor_type, RandomGeneratorType): return numba.types.NumPyRandomGeneratorType("NumPyRandomGeneratorType") elif isinstance(pytensor_type, TypedListType): return numba.types.List(get_numba_type(pytensor_type.ttype)) - else: - raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") + + raise NotImplementedError(f"Numba type not implemented for {pytensor_type}") def create_numba_signature( diff --git a/pytensor/link/numba/dispatch/sparse.py b/pytensor/link/numba/dispatch/sparse.py deleted file mode 100644 index e25083e92d..0000000000 --- a/pytensor/link/numba/dispatch/sparse.py +++ /dev/null @@ -1,206 +0,0 @@ -import numpy as np -import scipy as sp -import scipy.sparse -from numba.core import cgutils, types -from numba.core.imputils import impl_ret_borrowed -from numba.extending import ( - NativeValue, - box, - intrinsic, - make_attribute_wrapper, - models, - overload, - overload_attribute, - overload_method, - register_model, - typeof_impl, - unbox, -) - - -class CSMatrixType(types.Type): - """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.""" - - name: str - - @staticmethod - def instance_class(data, indices, indptr, shape): - raise NotImplementedError() - - def __init__(self, dtype): - self.dtype = dtype - self.data = types.Array(dtype, 1, "A") - self.indices = types.Array(types.int32, 1, "A") - self.indptr = types.Array(types.int32, 1, "A") - self.shape = types.UniTuple(types.int64, 2) - super().__init__(self.name) - - @property - def key(self): - return (self.name, self.dtype) - - -make_attribute_wrapper(CSMatrixType, "data", "data") -make_attribute_wrapper(CSMatrixType, "indices", "indices") -make_attribute_wrapper(CSMatrixType, "indptr", "indptr") 
-make_attribute_wrapper(CSMatrixType, "shape", "shape") - - -class CSRMatrixType(CSMatrixType): - name = "csr_matrix" - - @staticmethod - def instance_class(data, indices, indptr, shape): - return sp.sparse.csr_matrix((data, indices, indptr), shape, copy=False) - - -class CSCMatrixType(CSMatrixType): - name = "csc_matrix" - - @staticmethod - def instance_class(data, indices, indptr, shape): - return sp.sparse.csc_matrix((data, indices, indptr), shape, copy=False) - - -@typeof_impl.register(sp.sparse.csc_matrix) -def typeof_csc_matrix(val, c): - data = typeof_impl(val.data, c) - return CSCMatrixType(data.dtype) - - -@typeof_impl.register(sp.sparse.csr_matrix) -def typeof_csr_matrix(val, c): - data = typeof_impl(val.data, c) - return CSRMatrixType(data.dtype) - - -@register_model(CSRMatrixType) -class CSRMatrixModel(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("data", fe_type.data), - ("indices", fe_type.indices), - ("indptr", fe_type.indptr), - ("shape", fe_type.shape), - ] - super().__init__(dmm, fe_type, members) - - -@register_model(CSCMatrixType) -class CSCMatrixModel(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("data", fe_type.data), - ("indices", fe_type.indices), - ("indptr", fe_type.indptr), - ("shape", fe_type.shape), - ] - super().__init__(dmm, fe_type, members) - - -@unbox(CSCMatrixType) -@unbox(CSRMatrixType) -def unbox_matrix(typ, obj, c): - struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder) - - data = c.pyapi.object_getattr_string(obj, "data") - indices = c.pyapi.object_getattr_string(obj, "indices") - indptr = c.pyapi.object_getattr_string(obj, "indptr") - shape = c.pyapi.object_getattr_string(obj, "shape") - - struct_ptr.data = c.unbox(typ.data, data).value - struct_ptr.indices = c.unbox(typ.indices, indices).value - struct_ptr.indptr = c.unbox(typ.indptr, indptr).value - struct_ptr.shape = c.unbox(typ.shape, shape).value - - c.pyapi.decref(data) - c.pyapi.decref(indices) - c.pyapi.decref(indptr) - c.pyapi.decref(shape) - - is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit) - is_error = c.builder.load(is_error_ptr) - - res = NativeValue(struct_ptr._getvalue(), is_error=is_error) - - return res - - -@box(CSCMatrixType) -@box(CSRMatrixType) -def box_matrix(typ, val, c): - struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val) - - data_obj = c.box(typ.data, struct_ptr.data) - indices_obj = c.box(typ.indices, struct_ptr.indices) - indptr_obj = c.box(typ.indptr, struct_ptr.indptr) - shape_obj = c.box(typ.shape, struct_ptr.shape) - - c.pyapi.incref(data_obj) - c.pyapi.incref(indices_obj) - c.pyapi.incref(indptr_obj) - c.pyapi.incref(shape_obj) - - cls_obj = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class)) - obj = c.pyapi.call_function_objargs( - cls_obj, (data_obj, indices_obj, indptr_obj, shape_obj) - ) - - c.pyapi.decref(data_obj) - c.pyapi.decref(indices_obj) - c.pyapi.decref(indptr_obj) - c.pyapi.decref(shape_obj) - - return obj - - -@overload(np.shape) -def overload_sparse_shape(x): - if isinstance(x, CSMatrixType): - return lambda x: x.shape - - -@overload_attribute(CSMatrixType, "ndim") -def overload_sparse_ndim(inst): - if not isinstance(inst, CSMatrixType): - return - - def ndim(inst): - return 2 - - return ndim - - -@intrinsic -def _sparse_copy(typingctx, inst, data, indices, indptr, shape): - def _construct(context, builder, sig, args): - typ = sig.return_type - struct = cgutils.create_struct_proxy(typ)(context, builder) - _, data, 
indices, indptr, shape = args - struct.data = data - struct.indices = indices - struct.indptr = indptr - struct.shape = shape - return impl_ret_borrowed( - context, - builder, - sig.return_type, - struct._getvalue(), - ) - - sig = inst(inst, inst.data, inst.indices, inst.indptr, inst.shape) - - return sig, _construct - - -@overload_method(CSMatrixType, "copy") -def overload_sparse_copy(inst): - if not isinstance(inst, CSMatrixType): - return - - def copy(inst): - return _sparse_copy( - inst, inst.data.copy(), inst.indices.copy(), inst.indptr.copy(), inst.shape - ) - - return copy diff --git a/pytensor/link/numba/dispatch/sparse/__init__.py b/pytensor/link/numba/dispatch/sparse/__init__.py new file mode 100644 index 0000000000..737d48a5e2 --- /dev/null +++ b/pytensor/link/numba/dispatch/sparse/__init__.py @@ -0,0 +1 @@ +from pytensor.link.numba.dispatch.sparse import variable diff --git a/pytensor/link/numba/dispatch/sparse/variable.py b/pytensor/link/numba/dispatch/sparse/variable.py new file mode 100644 index 0000000000..209052bc8b --- /dev/null +++ b/pytensor/link/numba/dispatch/sparse/variable.py @@ -0,0 +1,267 @@ +import numpy as np +import scipy as sp +from numba.core import cgutils, types +from numba.core.imputils import impl_ret_borrowed, lower_constant +from numba.extending import ( + NativeValue, + box, + intrinsic, + make_attribute_wrapper, + models, + overload, + overload_attribute, + overload_method, + register_model, + typeof_impl, + unbox, +) + + +class CSMatrixType(types.Type): + """A Numba `Type` modeled after the base class `scipy.sparse.compressed._cs_matrix`.""" + + name: str + + @staticmethod + def instance_class(data, indices, indptr, shape): + raise NotImplementedError() + + def __init__(self, data_type, indices_type, indptr_type): + self._key = (data_type, indices_type, indptr_type) + self.data = data_type + self.indices = indices_type + self.indptr = indptr_type + self.shape = types.UniTuple(types.int32, 2) + super().__init__(self.name) + + @property + def key(self): + return self._key + + +make_attribute_wrapper(CSMatrixType, "data", "data") +make_attribute_wrapper(CSMatrixType, "indices", "indices") +make_attribute_wrapper(CSMatrixType, "indptr", "indptr") +make_attribute_wrapper(CSMatrixType, "shape", "shape") + + +class CSRMatrixType(CSMatrixType): + name = "csr_matrix" + + @staticmethod + def instance_class(data, indices, indptr, shape): + return sp.sparse.csr_matrix((data, indices, indptr), shape, copy=False) + + +class CSCMatrixType(CSMatrixType): + name = "csc_matrix" + + @staticmethod + def instance_class(data, indices, indptr, shape): + return sp.sparse.csc_matrix((data, indices, indptr), shape, copy=False) + + +@typeof_impl.register(sp.sparse.csc_matrix) +@typeof_impl.register(sp.sparse.csr_matrix) +def typeof_cs_matrix(val, ctx): + match val: + case sp.sparse.csc_matrix(): + numba_type = CSCMatrixType + case sp.sparse.csr_matrix(): + numba_type = CSRMatrixType + case _: + raise ValueError(f"val of type {type(val)} not recognized") + return numba_type( + typeof_impl(val.data, ctx), + typeof_impl(val.indices, ctx), + typeof_impl(val.indptr, ctx), + ) + + +@register_model(CSCMatrixType) +@register_model(CSRMatrixType) +class CSMatrixModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("data", fe_type.data), + ("indices", fe_type.indices), + ("indptr", fe_type.indptr), + ("shape", fe_type.shape), + ] + super().__init__(dmm, fe_type, members) + + +@unbox(CSMatrixType) +def unbox_cs_matrix(typ, obj, c): + struct_ptr = 
cgutils.create_struct_proxy(typ)(c.context, c.builder) + + # Get attributes from python object + data = c.pyapi.object_getattr_string(obj, "data") + indices = c.pyapi.object_getattr_string(obj, "indices") + indptr = c.pyapi.object_getattr_string(obj, "indptr") + shape = c.pyapi.object_getattr_string(obj, "shape") + + # Unbox them into llvm struct + struct_ptr.data = c.unbox(typ.data, data).value + struct_ptr.indices = c.unbox(typ.indices, indices).value + struct_ptr.indptr = c.unbox(typ.indptr, indptr).value + struct_ptr.shape = c.unbox(typ.shape, shape).value + + # Decref created attributes + c.pyapi.decref(data) + c.pyapi.decref(indices) + c.pyapi.decref(indptr) + c.pyapi.decref(shape) + + is_error_ptr = cgutils.alloca_once_value(c.builder, cgutils.false_bit) + is_error = c.builder.load(is_error_ptr) + res = NativeValue(struct_ptr._getvalue(), is_error=is_error) + + return res + + +@box(CSMatrixType) +def box_cs_matrix(typ, val, c): + struct_ptr = cgutils.create_struct_proxy(typ)(c.context, c.builder, value=val) + + data_obj = c.box(typ.data, struct_ptr.data) + indices_obj = c.box(typ.indices, struct_ptr.indices) + indptr_obj = c.box(typ.indptr, struct_ptr.indptr) + shape_obj = c.box(typ.shape, struct_ptr.shape) + + # Call scipy.sparse.cs[c|r]_matrix + cls_obj = c.pyapi.unserialize(c.pyapi.serialize_object(typ.instance_class)) + obj = c.pyapi.call_function_objargs( + cls_obj, (data_obj, indices_obj, indptr_obj, shape_obj) + ) + + c.pyapi.decref(data_obj) + c.pyapi.decref(indices_obj) + c.pyapi.decref(indptr_obj) + c.pyapi.decref(shape_obj) + + return obj + + +def _intrinsic_cs_codegen(context, builder, sig, args): + matrix_type = sig.return_type + struct = cgutils.create_struct_proxy(matrix_type)(context, builder) + data, indices, indptr, shape = args + struct.data = data + struct.indices = indices + struct.indptr = indptr + struct.shape = shape + return impl_ret_borrowed( + context, + builder, + matrix_type, + struct._getvalue(), + ) + + +@intrinsic +def csr_matrix_from_components(typingctx, data, indices, indptr, shape): + sig = CSRMatrixType(data, indices, indptr)(data, indices, indptr, shape) + return sig, _intrinsic_cs_codegen + + +@intrinsic +def csc_matrix_from_components(typingctx, data, indices, indptr, shape): + sig = CSCMatrixType(data, indices, indptr)(data, indices, indptr, shape) + return sig, _intrinsic_cs_codegen + + +@lower_constant(CSRMatrixType) +@lower_constant(CSCMatrixType) +def cs_matrix_constant(context, builder, ty, pyval): + data_const = context.make_constant_array(builder, ty.data, pyval.data) + indices_const = context.make_constant_array(builder, ty.indices, pyval.indices) + indptr_const = context.make_constant_array(builder, ty.indptr, pyval.indptr) + shape = context.get_constant_generic(builder, ty.shape, pyval.shape) + args = (data_const, indices_const, indptr_const, shape) + + sig = ty(*args) + return _intrinsic_cs_codegen(context, builder, sig, args) + + +@overload(sp.sparse.csr_matrix) +def overload_csr_matrix(arg1, shape, dtype=None): + if not isinstance(arg1, types.BaseAnonymousTuple) or len(arg1) != 3: + return None + if isinstance(shape, types.NoneType): + return None + + def impl(arg1, shape, dtype=None): + data, indices, indptr = arg1 + int32_shape = (types.int32(shape[0]), types.int32(shape[1])) + return csr_matrix_from_components(data, indices, indptr, int32_shape) + + return impl + + +@overload(sp.sparse.csc_matrix) +def overload_csc_matrix(arg1, shape, dtype=None): + if not isinstance(arg1, types.BaseAnonymousTuple) or len(arg1) != 3: + return 
None + if isinstance(shape, types.NoneType): + return None + + def impl(arg1, shape, dtype=None): + data, indices, indptr = arg1 + int32_shape = (types.int32(shape[0]), types.int32(shape[1])) + return csc_matrix_from_components(data, indices, indptr, int32_shape) + + return impl + + +@overload(np.shape) +def overload_sparse_shape(matrix): + if isinstance(matrix, CSMatrixType): + return lambda matrix: matrix.shape + + +@overload_attribute(CSMatrixType, "ndim") +def overload_sparse_ndim(matrix): + return lambda matrix: 2 + + +@overload_method(CSMatrixType, "copy") +def overload_sparse_copy(matrix): + match matrix: + case CSRMatrixType(): + builder = csr_matrix_from_components + case CSCMatrixType(): + builder = csc_matrix_from_components + case _: + return + + def copy(matrix): + return builder( + matrix.data.copy(), + matrix.indices.copy(), + matrix.indptr.copy(), + matrix.shape, + ) + + return copy + + +@overload_method(CSMatrixType, "astype") +def overload_sparse_astype(matrix, dtype): + match matrix: + case CSRMatrixType(): + builder = csr_matrix_from_components + case CSCMatrixType(): + builder = csc_matrix_from_components + case _: + return + + def astype(matrix, dtype): + return builder( + matrix.data.astype(dtype), + matrix.indices.copy(), + matrix.indptr.copy(), + matrix.shape, + ) + + return astype diff --git a/tests/link/numba/sparse/__init__.py b/tests/link/numba/sparse/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/link/numba/sparse/test_basic.py b/tests/link/numba/sparse/test_basic.py new file mode 100644 index 0000000000..6716c9312e --- /dev/null +++ b/tests/link/numba/sparse/test_basic.py @@ -0,0 +1,264 @@ +from functools import partial +from sys import getrefcount + +import numpy as np +import pytest +import scipy +import scipy as sp + +import pytensor.sparse as ps +import pytensor.tensor as pt +from pytensor.graph import Apply, Op +from pytensor.sparse.variable import SparseConstant +from pytensor.tensor.type import DenseTensorType + + +numba = pytest.importorskip("numba") + + +# Make sure the Numba customizations are loaded +import pytensor.link.numba.dispatch.sparse # noqa: F401 +from pytensor import config +from pytensor.sparse import SparseTensorType +from tests.link.numba.test_basic import compare_numba_and_py + + +pytestmark = pytest.mark.filterwarnings("error") + + +def sparse_assert_fn(a, b): + a_is_sparse = sp.sparse.issparse(a) + assert a_is_sparse == sp.sparse.issparse(b) + if a_is_sparse: + assert a.format == b.format + assert a.dtype == b.dtype + assert a.shape == b.shape + np.testing.assert_allclose(a.data, b.data, strict=True) + np.testing.assert_allclose(a.indices, b.indices, strict=True) + np.testing.assert_allclose(a.indptr, b.indptr, strict=True) + else: + np.testing.assert_allclose(a, b, strict=True) + + +compare_numba_and_py_sparse = partial(compare_numba_and_py, assert_fn=sparse_assert_fn) + + +def test_sparse_boxing(): + @numba.njit + def boxing_fn(x, y): + return x, y, y.data.sum() + + x_val = sp.sparse.csr_matrix(np.eye(100)) + y_val = sp.sparse.csc_matrix(np.eye(101)) + + res_x_val, res_y_val, res_y_sum = boxing_fn(x_val, y_val) + + assert np.array_equal(res_x_val.data, x_val.data) + assert np.array_equal(res_x_val.indices, x_val.indices) + assert np.array_equal(res_x_val.indptr, x_val.indptr) + assert res_x_val.shape == x_val.shape + + assert np.array_equal(res_y_val.data, y_val.data) + assert np.array_equal(res_y_val.indices, y_val.indices) + assert np.array_equal(res_y_val.indptr, y_val.indptr) + assert 
res_y_val.shape == y_val.shape + + np.testing.assert_allclose(res_y_sum, y_val.sum()) + + +def test_sparse_creation_refcount(): + @numba.njit + def create_csr_matrix(data, indices, ind_ptr): + return scipy.sparse.csr_matrix((data, indices, ind_ptr), shape=(5, 5)) + + x = scipy.sparse.random(5, 5, density=0.5, format="csr") + + x_data = x.data + x_indptr = x.indptr + assert getrefcount(x_data) == 3 + assert getrefcount(x_indptr) == 3 + + for i in range(5): + a = create_csr_matrix(x.data, x.indices, x.indptr) + + # a.data is a view of the underlying data under x.data, but doesn't reference it directly + assert getrefcount(x_data) == 3 + # x_indptr is reused directly + assert getrefcount(x_indptr) == 4 + + del a + assert getrefcount(x_data) == 3 + assert getrefcount(x_indptr) == 3 + + +def test_sparse_passthrough_refcount(): + @numba.njit + def identity(a): + return a + + x = scipy.sparse.random(5, 5, density=0.5, format="csr") + + x_data = x.data + assert getrefcount(x_data) == 3 + + for i in range(5): + identity(x) + + assert getrefcount(x_data) == 3 + + +def test_sparse_shape(): + @numba.njit + def test_fn(x): + return np.shape(x) + + x_val = sp.sparse.csr_matrix(np.eye(100)) + + res = test_fn(x_val) + + assert res == (100, 100) + + +def test_sparse_ndim(): + @numba.njit + def test_fn(x): + return x.ndim + + x_val = sp.sparse.csr_matrix(np.eye(100)) + + res = test_fn(x_val) + + assert res == 2 + + +def test_sparse_copy(): + @numba.njit + def test_fn(x): + return x.copy() + + x = sp.sparse.csr_matrix(np.eye(100)) + + y = test_fn(x) + assert y is not x + for attr in ("data", "indices", "indptr"): + y_data = getattr(y, attr) + x_data = getattr(x, attr) + assert y_data is not x_data + assert not np.shares_memory(y_data, x_data) + assert (y_data == x_data).all() + + +@pytest.mark.parametrize( + "func", [sp.sparse.csr_matrix, sp.sparse.csc_matrix], ids=["csr", "csc"] +) +def test_sparse_constructor(func): + @numba.njit + def csr_matrix_constructor(data, indices, indptr): + return func((data, indices, indptr), shape=(3, 3)) + + inp = sp.sparse.random(3, 3, density=0.5, format="csr") + + # Test with pure scipy constructor + out = func((inp.data, inp.indices, inp.indptr), copy=False) + # Scipy does a useless slice on data and indices to trim away useless zeros + # which means these attributes are views of the original arrays. 
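(Illustrative aside, not part of the patch: the trimming slice mentioned above is why data comes back as a non-owning view. A minimal NumPy-only sketch of that mechanism:)

import numpy as np

arr = np.arange(5)
trimmed = arr[:5]  # even a full-length slice yields a new, non-owning view
assert trimmed is not arr
assert not trimmed.flags.owndata and np.shares_memory(trimmed, arr)
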
+ assert out.data is not inp.data + assert not out.data.flags.owndata + + assert out.indices is not inp.indices + assert not out.indices.flags.owndata + + assert out.indptr is inp.indptr + + # Test numba impl + out_pt = csr_matrix_constructor(inp.data, inp.indices, inp.indptr) + # Should work the same as Scipy's constructor, because it's ultimately used + assert type(out_pt) is type(out) + + assert out_pt.data is not inp.data + assert not out_pt.data.flags.owndata + assert np.shares_memory(out_pt.data, inp.data) + assert (out_pt.data == inp.data).all() + + assert out_pt.indices is not inp.indices + assert not out_pt.indices.flags.owndata + assert np.shares_memory(out_pt.indices, inp.indices) + assert (out_pt.indices == inp.indices).all() + + assert out_pt.indptr is inp.indptr + + +@pytest.mark.parametrize("cache", [True, False]) +@pytest.mark.parametrize("format", ["csr", "csc"]) +def test_sparse_constant(format, cache): + x = sp.sparse.random(3, 3, density=0.5, format=format, random_state=166) + x = ps.as_sparse(x) + assert isinstance(x, SparseConstant) + assert x.type.format == format + y = pt.vector("y", shape=(3,)) + out = x * y + + y_test = np.array([np.pi, np.e, np.euler_gamma]) + with config.change_flags(numba__cache=cache): + with pytest.warns( + UserWarning, + match=r"Numba will use object mode to run SparseDenseVectorMultiply's perform method", + ): + compare_numba_and_py_sparse( + [y], + [out], + [y_test], + eval_obj_mode=False, + ) + + +@pytest.mark.parametrize("format", ["csc", "csr"]) +@pytest.mark.parametrize("dense_out", [True, False]) +def test_sparse_objmode(format, dense_out): + class SparseTestOp(Op): + def make_node(self, x): + out = x.type.clone(shape=(1, x.type.shape[-1]))() + if dense_out: + out = out.todense().type() + return Apply(self, [x], [out]) + + def perform(self, node, inputs, output_storage): + [x] = inputs + [out] = output_storage + out[0] = x[0] + if dense_out: + out[0] = out[0].todense() + + x = ps.matrix(format, dtype=config.floatX, shape=(5, 5), name="x") + + out = SparseTestOp()(x) + assert out.type.shape == (1, 5) + assert isinstance(out.type, DenseTensorType if dense_out else SparseTensorType) + + x_val = sp.sparse.random(5, 5, density=0.25, dtype=config.floatX, format=format) + + with pytest.warns( + UserWarning, + match="Numba will use object mode to run SparseTestOp's perform method", + ): + compare_numba_and_py_sparse([x], out, [x_val]) + + +@pytest.mark.parametrize("format", ["csr", "csc"]) +def test_simple_graph(format): + x = ps.matrix(format, name="x", shape=(3, 3)) + y = pt.tensor("y", shape=(3,)) + z = ps.math.sin(x * y) + + rng = np.random.default_rng((155, format == "csr")) + x_test = sp.sparse.random(3, 3, density=0.5, format=format, random_state=rng) + y_test = rng.normal(size=(3,)) + + with pytest.warns( + UserWarning, match=r"Numba will use object mode to run .* perform method" + ): + compare_numba_and_py_sparse( + [x, y], + z, + [x_test, y_test], + ) diff --git a/tests/link/numba/test_sparse.py b/tests/link/numba/test_sparse.py deleted file mode 100644 index 1e0a1cd5c0..0000000000 --- a/tests/link/numba/test_sparse.py +++ /dev/null @@ -1,126 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import scipy as sp - - -numba = pytest.importorskip("numba") - - -# Make sure the Numba customizations are loaded -import pytensor.link.numba.dispatch.sparse # noqa: F401 -from pytensor import config -from pytensor.sparse import Dot, SparseTensorType -from tests.link.numba.test_basic import compare_numba_and_py - - 
-pytestmark = pytest.mark.filterwarnings("error") - - -def sparse_assert_fn(a, b): - a_is_sparse = sp.sparse.issparse(a) - assert a_is_sparse == sp.sparse.issparse(b) - if a_is_sparse: - assert a.format == b.format - assert a.dtype == b.dtype - assert a.shape == b.shape - np.testing.assert_allclose(a.data, b.data, strict=True) - np.testing.assert_allclose(a.indices, b.indices, strict=True) - np.testing.assert_allclose(a.indptr, b.indptr, strict=True) - else: - np.testing.assert_allclose(a, b, strict=True) - - -compare_numba_and_py_sparse = partial(compare_numba_and_py, assert_fn=sparse_assert_fn) - - -def test_sparse_unboxing(): - @numba.njit - def test_unboxing(x, y): - return x.shape, y.shape - - x_val = sp.sparse.csr_matrix(np.eye(100)) - y_val = sp.sparse.csc_matrix(np.eye(101)) - - res = test_unboxing(x_val, y_val) - - assert res == (x_val.shape, y_val.shape) - - -def test_sparse_boxing(): - @numba.njit - def test_boxing(x, y): - return x, y - - x_val = sp.sparse.csr_matrix(np.eye(100)) - y_val = sp.sparse.csc_matrix(np.eye(101)) - - res_x_val, res_y_val = test_boxing(x_val, y_val) - - assert np.array_equal(res_x_val.data, x_val.data) - assert np.array_equal(res_x_val.indices, x_val.indices) - assert np.array_equal(res_x_val.indptr, x_val.indptr) - assert res_x_val.shape == x_val.shape - - assert np.array_equal(res_y_val.data, y_val.data) - assert np.array_equal(res_y_val.indices, y_val.indices) - assert np.array_equal(res_y_val.indptr, y_val.indptr) - assert res_y_val.shape == y_val.shape - - -def test_sparse_shape(): - @numba.njit - def test_fn(x): - return np.shape(x) - - x_val = sp.sparse.csr_matrix(np.eye(100)) - - res = test_fn(x_val) - - assert res == (100, 100) - - -def test_sparse_ndim(): - @numba.njit - def test_fn(x): - return x.ndim - - x_val = sp.sparse.csr_matrix(np.eye(100)) - - res = test_fn(x_val) - - assert res == 2 - - -def test_sparse_copy(): - @numba.njit - def test_fn(x): - y = x.copy() - return ( - y is not x and np.all(x.data == y.data) and np.all(x.indices == y.indices) - ) - - x_val = sp.sparse.csr_matrix(np.eye(100)) - - assert test_fn(x_val) - - -def test_sparse_objmode(): - x = SparseTensorType("csc", dtype=config.floatX)() - y = SparseTensorType("csc", dtype=config.floatX)() - - out = Dot()(x, y) - - x_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX, format="csc") - y_val = sp.sparse.random(2, 2, density=0.25, dtype=config.floatX, format="csc") - - with pytest.warns( - UserWarning, - match="Numba will use object mode to run SparseDot's perform method", - ): - compare_numba_and_py_sparse( - [x, y], - out, - [x_val, y_val], - ) From 70c659b2f42ef9a285c5625a856f14b06c7803a6 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 16 Jan 2026 16:47:08 +0100 Subject: [PATCH 13/15] Numba sparse: Implement basic Ops Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- .../link/numba/dispatch/sparse/__init__.py | 2 +- pytensor/link/numba/dispatch/sparse/basic.py | 68 +++++++++++++++++++ pytensor/link/numba/dispatch/typed_list.py | 13 ++++ tests/link/numba/sparse/test_basic.py | 10 ++- 4 files changed, 91 insertions(+), 2 deletions(-) create mode 100644 pytensor/link/numba/dispatch/sparse/basic.py diff --git a/pytensor/link/numba/dispatch/sparse/__init__.py b/pytensor/link/numba/dispatch/sparse/__init__.py index 737d48a5e2..1f47f17ead 100644 --- a/pytensor/link/numba/dispatch/sparse/__init__.py +++ b/pytensor/link/numba/dispatch/sparse/__init__.py @@ -1 +1 @@ -from pytensor.link.numba.dispatch.sparse import variable 
+from pytensor.link.numba.dispatch.sparse import basic, variable
diff --git a/pytensor/link/numba/dispatch/sparse/basic.py b/pytensor/link/numba/dispatch/sparse/basic.py
new file mode 100644
index 0000000000..d09aa9d943
--- /dev/null
+++ b/pytensor/link/numba/dispatch/sparse/basic.py
@@ -0,0 +1,68 @@
+import numpy as np
+import scipy as sp
+from numba.extending import overload
+
+from pytensor import config
+from pytensor.link.numba.dispatch import basic as numba_basic
+from pytensor.link.numba.dispatch.basic import (
+    generate_fallback_impl,
+    register_funcify_default_op_cache_key,
+)
+from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
+from pytensor.link.numba.dispatch.sparse.variable import CSMatrixType
+from pytensor.sparse import CSM, Cast, CSMProperties
+
+
+@overload(numba_deepcopy)
+def numba_deepcopy_sparse(x):
+    if isinstance(x, CSMatrixType):
+
+        def sparse_deepcopy(x):
+            return x.copy()
+
+        return sparse_deepcopy
+
+
+@register_funcify_default_op_cache_key(CSMProperties)
+def numba_funcify_CSMProperties(op, node, **kwargs):
+    @numba_basic.numba_njit
+    def csm_properties(x):
+        # Reconsider this int32/int64. SciPy and base PyTensor use int32 for indices/indptr,
+        # but this seems to be a legacy mistake; devs would choose int64 nowadays and may migrate to it.
+        return x.data, x.indices, x.indptr, np.asarray(x.shape, dtype="int32")
+
+    return csm_properties
+
+
+@register_funcify_default_op_cache_key(CSM)
+def numba_funcify_CSM(op, node, **kwargs):
+    format = op.format
+
+    @numba_basic.numba_njit
+    def csm_constructor(data, indices, indptr, shape):
+        constructor_arg = (data, indices, indptr)
+        shape_arg = (shape[0], shape[1])
+        if format == "csr":
+            return sp.sparse.csr_matrix(constructor_arg, shape=shape_arg)
+        else:
+            return sp.sparse.csc_matrix(constructor_arg, shape=shape_arg)
+
+    return csm_constructor
+
+
+@register_funcify_default_op_cache_key(Cast)
+def numba_funcify_Cast(op, node, **kwargs):
+    inp_dtype = node.inputs[0].type.dtype
+    out_dtype = np.dtype(op.out_type)
+    if not np.can_cast(inp_dtype, out_dtype):
+        if config.compiler_verbose:
+            print(  # noqa: T201
+                f"Sparse Cast fallback to obj mode due to unsafe casting from {inp_dtype} to {out_dtype}"
+            )
+        return generate_fallback_impl(op, node, **kwargs)
+
+    @numba_basic.numba_njit
+    def cast(x):
+        return x.astype(out_dtype)
+
+    return cast
diff --git a/pytensor/link/numba/dispatch/typed_list.py b/pytensor/link/numba/dispatch/typed_list.py
index 90c0e2ac1e..696e9239a5 100644
--- a/pytensor/link/numba/dispatch/typed_list.py
+++ b/pytensor/link/numba/dispatch/typed_list.py
@@ -9,6 +9,7 @@
     register_funcify_default_op_cache_key,
 )
 from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
+from pytensor.link.numba.dispatch.sparse.variable import CSCMatrixType, CSRMatrixType
 from pytensor.tensor.type_other import SliceType
 from pytensor.typed_list import (
     Append,
@@ -64,6 +65,18 @@ def all_equal(x, y):
         def all_equal(x, y):
             return x == y
 
+    if (isinstance(x, CSRMatrixType) and isinstance(y, CSRMatrixType)) or (
+        isinstance(x, CSCMatrixType) and isinstance(y, CSCMatrixType)
+    ):
+
+        def all_equal(x, y):
+            return (
+                x.shape == y.shape
+                and (x.data == y.data).all()
+                and (x.indptr == y.indptr).all()
+                and (x.indices == y.indices).all()
+            )
+
     return all_equal
 
 
diff --git a/tests/link/numba/sparse/test_basic.py b/tests/link/numba/sparse/test_basic.py
index 6716c9312e..407b36a0f2 100644
--- a/tests/link/numba/sparse/test_basic.py
+++ b/tests/link/numba/sparse/test_basic.py
@@ -255,10 +255,18 @@ def 
test_simple_graph(format): y_test = rng.normal(size=(3,)) with pytest.warns( - UserWarning, match=r"Numba will use object mode to run .* perform method" + UserWarning, + match=r"Numba will use object mode to run SparseDenseVectorMultiply's perform method", ): compare_numba_and_py_sparse( [x, y], z, [x_test, y_test], ) + + +@pytest.mark.parametrize("format", ("csr", "csc")) +def test_sparse_deepcopy(format): + x = ps.matrix(shape=(3, 3), format=format) + x_test = sp.sparse.random(3, 3, density=0.5, format=format) + compare_numba_and_py_sparse([x], [x], [x_test]) From 2feac06724c9b11cb0db71e2535f291d6f36d298 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 17 Nov 2025 15:13:11 +0100 Subject: [PATCH 14/15] Numba sparse: Implement SparseDenseMultiply Co-authored-by: Jesse Grabowski <48652735+jessegrabowski@users.noreply.github.com> --- .../link/numba/dispatch/sparse/__init__.py | 2 +- pytensor/link/numba/dispatch/sparse/math.py | 90 +++++++++++++++++++ tests/link/numba/sparse/test_basic.py | 30 +++---- tests/link/numba/sparse/test_math.py | 28 ++++++ 4 files changed, 130 insertions(+), 20 deletions(-) create mode 100644 pytensor/link/numba/dispatch/sparse/math.py create mode 100644 tests/link/numba/sparse/test_math.py diff --git a/pytensor/link/numba/dispatch/sparse/__init__.py b/pytensor/link/numba/dispatch/sparse/__init__.py index 1f47f17ead..f54ecc1363 100644 --- a/pytensor/link/numba/dispatch/sparse/__init__.py +++ b/pytensor/link/numba/dispatch/sparse/__init__.py @@ -1 +1 @@ -from pytensor.link.numba.dispatch.sparse import basic, variable +from pytensor.link.numba.dispatch.sparse import basic, math, variable diff --git a/pytensor/link/numba/dispatch/sparse/math.py b/pytensor/link/numba/dispatch/sparse/math.py new file mode 100644 index 0000000000..c1d3f1baba --- /dev/null +++ b/pytensor/link/numba/dispatch/sparse/math.py @@ -0,0 +1,90 @@ +from pytensor.link.numba.dispatch import basic as numba_basic +from pytensor.link.numba.dispatch.basic import register_funcify_default_op_cache_key +from pytensor.sparse import SparseDenseMultiply, SparseDenseVectorMultiply + + +@register_funcify_default_op_cache_key(SparseDenseMultiply) +@register_funcify_default_op_cache_key(SparseDenseVectorMultiply) +def numba_funcify_SparseDenseMultiply(op, node, **kwargs): + x, y = node.inputs + [z] = node.outputs + out_dtype = z.type.dtype + format = z.type.format + same_dtype = x.type.dtype == out_dtype + + if y.ndim == 0: + + @numba_basic.numba_njit + def sparse_multiply_scalar(x, y): + if same_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + # Numba doesn't know how to handle in-place mutation / assignment of fields + # z.data *= y + z_data = z.data + z_data *= y + return z + + return sparse_multiply_scalar + + elif y.ndim == 1: + + @numba_basic.numba_njit + def sparse_dense_multiply(x, y): + assert x.shape[1] == y.shape[0] + if same_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + + M, N = x.shape + indices = x.indices + indptr = x.indptr + z_data = z.data + if format == "csc": + for j in range(0, N): + for i_idx in range(indptr[j], indptr[j + 1]): + z_data[i_idx] *= y[j] + return z + + else: + for i in range(0, M): + for j_idx in range(indptr[i], indptr[i + 1]): + j = indices[j_idx] + z_data[j_idx] *= y[j] + + return z + + return sparse_dense_multiply + + else: # y.ndim == 2 + + @numba_basic.numba_njit + def sparse_dense_multiply(x, y): + assert x.shape == y.shape + if same_dtype: + z = x.copy() + else: + z = x.astype(out_dtype) + + M, N = x.shape + indices = x.indices + indptr = 
x.indptr + z_data = z.data + if format == "csc": + for j in range(0, N): + for i_idx in range(indptr[j], indptr[j + 1]): + i = indices[i_idx] + z_data[i_idx] *= y[i, j] + return z + + else: + for i in range(0, M): + for j_idx in range(indptr[i], indptr[i + 1]): + j = indices[j_idx] + z_data[j_idx] *= y[i, j] + + return z + + return sparse_dense_multiply diff --git a/tests/link/numba/sparse/test_basic.py b/tests/link/numba/sparse/test_basic.py index 407b36a0f2..b43bfe7f32 100644 --- a/tests/link/numba/sparse/test_basic.py +++ b/tests/link/numba/sparse/test_basic.py @@ -200,16 +200,12 @@ def test_sparse_constant(format, cache): y_test = np.array([np.pi, np.e, np.euler_gamma]) with config.change_flags(numba__cache=cache): - with pytest.warns( - UserWarning, - match=r"Numba will use object mode to run SparseDenseVectorMultiply's perform method", - ): - compare_numba_and_py_sparse( - [y], - [out], - [y_test], - eval_obj_mode=False, - ) + compare_numba_and_py_sparse( + [y], + [out], + [y_test], + eval_obj_mode=False, + ) @pytest.mark.parametrize("format", ["csc", "csr"]) @@ -254,15 +250,11 @@ def test_simple_graph(format): x_test = sp.sparse.random(3, 3, density=0.5, format=format, random_state=rng) y_test = rng.normal(size=(3,)) - with pytest.warns( - UserWarning, - match=r"Numba will use object mode to run SparseDenseVectorMultiply's perform method", - ): - compare_numba_and_py_sparse( - [x, y], - z, - [x_test, y_test], - ) + compare_numba_and_py_sparse( + [x, y], + z, + [x_test, y_test], + ) @pytest.mark.parametrize("format", ("csr", "csc")) diff --git a/tests/link/numba/sparse/test_math.py b/tests/link/numba/sparse/test_math.py new file mode 100644 index 0000000000..ec3c744e4e --- /dev/null +++ b/tests/link/numba/sparse/test_math.py @@ -0,0 +1,28 @@ +import numpy as np +import pytest +import scipy + +import pytensor.sparse as ps +import pytensor.tensor as pt +from tests.link.numba.sparse.test_basic import compare_numba_and_py_sparse + + +pytestmark = pytest.mark.filterwarnings("error") + + +@pytest.mark.parametrize("format", ["csr", "csc"]) +@pytest.mark.parametrize("y_ndim", [0, 1, 2]) +def test_sparse_dense_multiply(y_ndim, format): + x = ps.matrix(format, name="x", shape=(3, 3)) + y = pt.tensor("y", shape=(3,) * y_ndim) + z = x * y + + rng = np.random.default_rng((155, y_ndim, format == "csr")) + x_test = scipy.sparse.random(3, 3, density=0.5, format=format, random_state=rng) + y_test = rng.normal(size=(3,) * y_ndim) + + compare_numba_and_py_sparse( + [x, y], + z, + [x_test, y_test], + ) From d8702776860d508c2d0effd6953a92640e6fea02 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 15 Jan 2026 18:23:02 +0100 Subject: [PATCH 15/15] Numba sparse: Remove codebase xfails --- tests/compile/function/test_pfunc.py | 7 +-- tests/sparse/__init__.py | 11 ----- tests/sparse/test_math.py | 70 ++++++++++++++++++---------- tests/sparse/test_rewriting.py | 17 +++++++ tests/test_raise_op.py | 7 +-- tests/typed_list/test_basic.py | 10 ---- 6 files changed, 65 insertions(+), 57 deletions(-) diff --git a/tests/compile/function/test_pfunc.py b/tests/compile/function/test_pfunc.py index bd4269a5b9..5de5383466 100644 --- a/tests/compile/function/test_pfunc.py +++ b/tests/compile/function/test_pfunc.py @@ -3,14 +3,13 @@ import scipy as sp import pytensor.tensor as pt -from pytensor.compile import UnusedInputError, get_default_mode, get_mode +from pytensor.compile import UnusedInputError, get_mode from pytensor.compile.function import function, pfunc from pytensor.compile.function.pfunc import 
rebuild_collect_shared from pytensor.compile.io import In from pytensor.compile.sharedvalue import shared from pytensor.configdefaults import config from pytensor.graph.utils import MissingInputError -from pytensor.link.numba import NumbaLinker from pytensor.sparse import SparseTensorType from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.type import ( @@ -766,10 +765,6 @@ def test_shared_constructor_copies(self): # rule #2 reading back from pytensor-managed memory assert not np.may_share_memory(A.get_value(borrow=False), data_of(A)) - @pytest.mark.xfail( - condition=isinstance(get_default_mode().linker, NumbaLinker), - reason="Numba does not support Sparse Ops yet", - ) def test_sparse_input_aliasing_affecting_inplace_operations(self): # Note: to trigger this bug with pytensor rev 4586:2bc6fc7f218b, # you need to make in inputs mutable (so that inplace diff --git a/tests/sparse/__init__.py b/tests/sparse/__init__.py index 3672a7fca6..e69de29bb2 100644 --- a/tests/sparse/__init__.py +++ b/tests/sparse/__init__.py @@ -1,11 +0,0 @@ -import pytest - -from pytensor.compile import get_default_mode -from pytensor.link.numba import NumbaLinker - - -if isinstance(get_default_mode().linker, NumbaLinker): - pytest.skip( - reason="Numba does not support Sparse Ops yet", - allow_module_level=True, - ) diff --git a/tests/sparse/test_math.py b/tests/sparse/test_math.py index 5faee2b3b4..eff3d51958 100644 --- a/tests/sparse/test_math.py +++ b/tests/sparse/test_math.py @@ -8,7 +8,9 @@ import pytensor import pytensor.sparse.math as psm import pytensor.tensor as pt +from pytensor.compile import get_default_mode from pytensor.configdefaults import config +from pytensor.link.numba import NumbaLinker from pytensor.scalar import upcast from pytensor.sparse.basic import ( CSR, @@ -427,33 +429,54 @@ def test_opt_unpack(self): ) f(kernvals, imvals) - def test_dot_sparse_sparse(self): + @pytest.mark.parametrize( + "sparse_format_a", + ( + "csc", + "csr", + pytest.param( + "bsr", + marks=pytest.mark.xfail( + isinstance(get_default_mode().linker, NumbaLinker), + reason="Numba does not support bsr", + ), + ), + ), + ) + @pytest.mark.parametrize( + "sparse_format_b", + ( + "csc", + "csr", + pytest.param( + "bsr", + marks=pytest.mark.xfail( + isinstance(get_default_mode().linker, NumbaLinker), + reason="Numba does not support bsr", + ), + ), + ), + ) + def test_dot_sparse_sparse(self, sparse_format_a, sparse_format_b): sparse_dtype = "float64" sp_mat = { "csc": scipy_sparse.csc_matrix, "csr": scipy_sparse.csr_matrix, "bsr": scipy_sparse.csr_matrix, } - - for sparse_format_a in ["csc", "csr", "bsr"]: - for sparse_format_b in ["csc", "csr", "bsr"]: - a = SparseTensorType(sparse_format_a, dtype=sparse_dtype)() - b = SparseTensorType(sparse_format_b, dtype=sparse_dtype)() - d = pt.dot(a, b) - f = pytensor.function([a, b], d) - for M, N, K, nnz in [ - (4, 3, 2, 3), - (40, 30, 20, 3), - (40, 30, 20, 30), - (400, 3000, 200, 6000), - ]: - a_val = sp_mat[sparse_format_a]( - random_lil((M, N), sparse_dtype, nnz) - ) - b_val = sp_mat[sparse_format_b]( - random_lil((N, K), sparse_dtype, nnz) - ) - f(a_val, b_val) + a = SparseTensorType(sparse_format_a, dtype=sparse_dtype)() + b = SparseTensorType(sparse_format_b, dtype=sparse_dtype)() + d = pt.dot(a, b) + f = pytensor.function([a, b], d) + for M, N, K, nnz in [ + (4, 3, 2, 3), + (40, 30, 20, 3), + (40, 30, 20, 30), + (400, 3000, 200, 6000), + ]: + a_val = sp_mat[sparse_format_a](random_lil((M, N), sparse_dtype, nnz)) + b_val = sp_mat[sparse_format_b](random_lil((N, 
K), sparse_dtype, nnz)) + f(a_val, b_val) # TODO: Test something def test_tensor_dot_types(self): x = csc_matrix("x") @@ -775,7 +798,7 @@ def f_b(z, a, x, y): f_b_out = f_b(z_data, a_data, x_data, y_data) # To make it easier to check the toposort - mode = pytensor.compile.mode.get_default_mode().excluding("fusion") + mode = get_default_mode().excluding("fusion") if inplace: updates = [(z, z - a * psm.dot(x, y))] @@ -815,8 +838,7 @@ def f_b(z, a, x, y): y.type.dtype == up and format1 == "csc" and format2 == "dense" - and not fast_compile - and pytensor.config.cxx + and "cxx_only" not in f_a.maker.linker.incompatible_rewrites and up in ("float32", "float64") ): # The op UsmmCscDense should be inserted diff --git a/tests/sparse/test_rewriting.py b/tests/sparse/test_rewriting.py index ad782ce2ae..3d6d4dbeee 100644 --- a/tests/sparse/test_rewriting.py +++ b/tests/sparse/test_rewriting.py @@ -7,6 +7,7 @@ from pytensor import sparse from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config +from pytensor.link.numba import NumbaLinker from pytensor.sparse.rewriting import SamplingDotCSR, sd_csc from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.math import sum as pt_sum @@ -68,6 +69,10 @@ def test_local_csm_grad_c(): @pytest.mark.skipif( not pytensor.config.cxx, reason="G++ not available, so we need to skip this test." ) +@pytest.mark.skipif( + isinstance(get_default_mode().linker, NumbaLinker), + reason="This is a C-specific test", +) def test_local_mul_s_d(): for sp_format in sparse.sparse_formats: inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), matrix()] @@ -83,6 +88,10 @@ def test_local_mul_s_d(): @pytest.mark.skipif( not pytensor.config.cxx, reason="G++ not available, so we need to skip this test." ) +@pytest.mark.skipif( + isinstance(get_default_mode().linker, NumbaLinker), + reason="This is a C-specific test", +) def test_local_mul_s_v(): mode = get_default_mode() mode = mode.including("specialize", "local_mul_s_v") @@ -101,6 +110,10 @@ def test_local_mul_s_v(): @pytest.mark.skipif( not pytensor.config.cxx, reason="G++ not available, so we need to skip this test." ) +@pytest.mark.skipif( + isinstance(get_default_mode().linker, NumbaLinker), + reason="This is a C-specific test", +) def test_local_structured_add_s_v(): for sp_format in ["csr"]: # Not implemented for other format inputs = [getattr(pytensor.sparse, sp_format + "_matrix")(), vector()] @@ -116,6 +129,10 @@ def test_local_structured_add_s_v(): @pytest.mark.skipif( not pytensor.config.cxx, reason="G++ not available, so we need to skip this test." 
) +@pytest.mark.skipif( + isinstance(get_default_mode().linker, NumbaLinker), + reason="This is a C-specific test", +) def test_local_sampling_dot_csr(): for sp_format in ["csr"]: # Not implemented for other format inputs = [ diff --git a/tests/test_raise_op.py b/tests/test_raise_op.py index af85ce50f6..9ba6040418 100644 --- a/tests/test_raise_op.py +++ b/tests/test_raise_op.py @@ -4,10 +4,9 @@ import pytensor import pytensor.tensor as pt -from pytensor.compile.mode import OPT_FAST_RUN, Mode, get_default_mode +from pytensor.compile.mode import OPT_FAST_RUN, Mode from pytensor.graph import vectorize_graph from pytensor.graph.basic import Constant, equal_computations -from pytensor.link.numba import NumbaLinker from pytensor.raise_op import Assert, CheckAndRaise, assert_op from pytensor.scalar.basic import ScalarType, float64 from pytensor.sparse import as_sparse_variable @@ -182,10 +181,6 @@ def test_infer_shape_scalar(self): ) -@pytest.mark.xfail( - condition=isinstance(get_default_mode().linker, NumbaLinker), - reason="Numba does not support Sparse Ops yet", -) def test_CheckAndRaise_sparse_variable(): check_and_raise = CheckAndRaise(ValueError, "sparse_check") diff --git a/tests/typed_list/test_basic.py b/tests/typed_list/test_basic.py index d7de404f6c..feae2361a9 100644 --- a/tests/typed_list/test_basic.py +++ b/tests/typed_list/test_basic.py @@ -7,8 +7,6 @@ import pytensor import pytensor.typed_list from pytensor import sparse -from pytensor.compile import get_default_mode -from pytensor.link.numba import NumbaLinker from pytensor.tensor.type import ( TensorType, integer_dtypes, @@ -454,10 +452,6 @@ def test_non_tensor_type(self): assert f([[x, y], [x, y, y]], [x, y]) == 0 - @pytest.mark.xfail( - condition=isinstance(get_default_mode().linker, NumbaLinker), - reason="Numba does not support Sparse Ops yet", - ) def test_sparse(self): mySymbolicSparseList = TypedListType( sparse.SparseTensorType("csr", pytensor.config.floatX) @@ -525,10 +519,6 @@ def test_non_tensor_type(self): assert f([[x, y], [x, y, y]], [x, y]) == 1 - @pytest.mark.xfail( - condition=isinstance(get_default_mode().linker, NumbaLinker), - reason="Numba does not support Sparse Ops yet", - ) def test_sparse(self): mySymbolicSparseList = TypedListType( sparse.SparseTensorType("csr", pytensor.config.floatX)
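
For reference, the sketch below is a minimal, hypothetical usage example of the code
path this series enables; it is not part of any patch above. It assumes the API
exercised by the new tests (pytensor.sparse.matrix, the "NUMBA" mode string, and
x * y lowering to SparseDenseVectorMultiply); exact signatures may vary across
PyTensor versions.

    # Minimal sketch, assuming the test-suite API shown in the diffs above.
    # With these patches the sparse-by-dense multiply compiles natively under
    # Numba instead of falling back to object mode around the Op's perform method.
    import numpy as np
    import scipy.sparse

    import pytensor
    import pytensor.sparse as ps
    import pytensor.tensor as pt

    x = ps.matrix("csr", name="x", shape=(3, 3))  # sparse CSR input
    y = pt.vector("y", shape=(3,))  # dense vector input
    z = x * y  # element-wise multiply; the result keeps the sparse format

    f = pytensor.function([x, y], z, mode="NUMBA")

    rng = np.random.default_rng(0)
    x_val = scipy.sparse.random(3, 3, density=0.5, format="csr", random_state=rng)
    y_val = rng.normal(size=(3,))  # assumes floatX="float64"; cast otherwise
    print(f(x_val, y_val).toarray())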