From 95f9490f164c61a8bd8a3c1f243b357694198383 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 20 Oct 2025 12:59:36 +0200 Subject: [PATCH 01/42] Add StructuredSparseTensor --- .gitignore | 3 + src/torchjd/autogram/_engine.py | 18 +- src/torchjd/sparse/__init__.py | 3 + .../_aten_function_overrides/__init__.py | 1 + .../_aten_function_overrides/backward.py | 36 ++ .../sparse/_aten_function_overrides/einsum.py | 253 +++++++++++ .../_aten_function_overrides/pointwise.py | 125 ++++++ .../sparse/_aten_function_overrides/shape.py | 289 ++++++++++++ src/torchjd/sparse/_coalesce.py | 19 + .../sparse/_structured_sparse_tensor.py | 278 ++++++++++++ tests/unit/sparse/__init__.py | 0 .../sparse/test_structured_sparse_tensor.py | 422 ++++++++++++++++++ 12 files changed, 1434 insertions(+), 13 deletions(-) create mode 100644 src/torchjd/sparse/__init__.py create mode 100644 src/torchjd/sparse/_aten_function_overrides/__init__.py create mode 100644 src/torchjd/sparse/_aten_function_overrides/backward.py create mode 100644 src/torchjd/sparse/_aten_function_overrides/einsum.py create mode 100644 src/torchjd/sparse/_aten_function_overrides/pointwise.py create mode 100644 src/torchjd/sparse/_aten_function_overrides/shape.py create mode 100644 src/torchjd/sparse/_coalesce.py create mode 100644 src/torchjd/sparse/_structured_sparse_tensor.py create mode 100644 tests/unit/sparse/__init__.py create mode 100644 tests/unit/sparse/test_structured_sparse_tensor.py diff --git a/.gitignore b/.gitignore index 902e607c9..01f539d22 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +# Jupyter notebooks +*.ipynb + # uv uv.lock diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 361743a40..964b94a67 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -4,6 +4,8 @@ from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge +from torchjd.sparse import make_sst + from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator from ._gramian_computer import GramianComputer, JacobianBasedGramianComputerWithCrossTerms @@ -173,7 +175,9 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: ) output_dims = list(range(output.ndim)) - jac_output = _make_initial_jac_output(output) + identity = torch.eye(output.ndim, dtype=torch.int64) + strides = torch.concatenate([identity, identity], dim=0) + jac_output = make_sst(torch.ones_like(output), strides) vmapped_diff = differentiation for _ in output_dims: @@ -193,15 +197,3 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: gramian_computer.reset() return gramian - - -def _make_initial_jac_output(output: Tensor) -> Tensor: - if output.ndim == 0: - return torch.ones_like(output) - p_index_ranges = [torch.arange(s, device=output.device) for s in output.shape] - p_indices_grid = torch.meshgrid(*p_index_ranges, indexing="ij") - v_indices_grid = p_indices_grid + p_indices_grid - - res = torch.zeros(list(output.shape) * 2, device=output.device, dtype=output.dtype) - res[v_indices_grid] = 1.0 - return res diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py new file mode 100644 index 000000000..7a161b6ad --- /dev/null +++ b/src/torchjd/sparse/__init__.py @@ -0,0 +1,3 @@ +# Need to import this to execute the code inside and thus to override the functions +from . 
import _aten_function_overrides +from ._structured_sparse_tensor import StructuredSparseTensor, make_sst diff --git a/src/torchjd/sparse/_aten_function_overrides/__init__.py b/src/torchjd/sparse/_aten_function_overrides/__init__.py new file mode 100644 index 000000000..b33cf8d62 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/__init__.py @@ -0,0 +1 @@ +from . import backward, einsum, pointwise, shape diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py new file mode 100644 index 000000000..9168c7653 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/backward.py @@ -0,0 +1,36 @@ +from torch import Tensor +from torch.ops import aten # type: ignore + +from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl + + +@impl(aten.threshold_backward.default) +def threshold_backward_default( + grad_output: StructuredSparseTensor, self: Tensor, threshold +) -> StructuredSparseTensor: + new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) + + return StructuredSparseTensor(new_physical, grad_output.strides) + + +@impl(aten.hardtanh_backward.default) +def hardtanh_backward_default( + grad_output: StructuredSparseTensor, + self: Tensor, + min_val: Tensor | int | float, + max_val: Tensor | int | float, +) -> StructuredSparseTensor: + if isinstance(self, StructuredSparseTensor): + raise NotImplementedError() + + new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) + return StructuredSparseTensor(new_physical, grad_output.strides) + + +@impl(aten.hardswish_backward.default) +def hardswish_backward_default(grad_output: StructuredSparseTensor, self: Tensor): + if isinstance(self, StructuredSparseTensor): + raise NotImplementedError() + + new_physical = aten.hardswish_backward.default(grad_output.physical, self) + return StructuredSparseTensor(new_physical, grad_output.strides) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py new file mode 100644 index 000000000..9775cccc3 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -0,0 +1,253 @@ +import torch +from torch import Tensor, tensor +from torch.ops import aten # type: ignore + +from torchjd.sparse._structured_sparse_tensor import ( + StructuredSparseTensor, + impl, + to_most_efficient_tensor, + to_structured_sparse_tensor, +) + + +def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -> Tensor: + raise NotImplementedError() + + # First part of the algorithm, determine how to cluster physical indices as well as the common + # p_shapes corresponding to matching v_dims. Second part translates to physical einsum. 
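+    #
+    # Intended semantics, illustrated with the first case of test_einsum in the unit tests added
+    # by this patch: for SSTs a (physical shape [4, 5], strides [[1, 0], [1, 0], [0, 1]]) and
+    # b (physical shape [4, 5], strides [[1, 0], [0, 1], [0, 1]]),
+    #   einsum((a, [0, 1, 2]), (b, [0, 2, 3]), output=[0, 1, 3])
+    # should produce an SST whose to_dense() equals
+    #   torch.einsum(a.to_dense(), [0, 1, 2], b.to_dense(), [0, 2, 3], [0, 1, 3]).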
+ + # get a map from einsum index to (tensor_idx, v_dims) + # get a map from einsum index to merge of strides corresponding to v_dims with that index + # use to_target_physical_strides on each physical and v_to_ps + # cluster pairs of (einsum_index, new_stride) using new_v_to_ps and possibly its corresponding + # p_to_vs + # get unique indices + # map output indices (there can be splits) + # call physical einsum + # build resulting sst + + # OVER + + # an index in the physical einsum is uniquely characterized by a virtual einsum index and a + # stride corresponding to the physical stride in the virtual one (note that as the virtual shape + # for two virtual index that match should match, then we want to match the strides and reshape + # accordingly). + # We want to cluster such indices whenever several appear in the same p_to_vs + + # TODO: Handle ellipsis + # If we have an index v for some virtual dim whose corresponding v_to_ps is a non-trivial list + # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. + # For this reason, an index is decomposed into sub-indices that are then independently + # clustered. + # So if an index i in args for some StructuredSparseTensor corresponds to a v_to_ps [j, k, l], + # We will consider three indices (i, 0), (i, 1) and (i, 2). + # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then + # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in + # the resulting einsum). + # Note that this is a problem if two virtual dimensions (from possibly different + # StructuredSparseTensors) have the same size but not the same decomposition into physical + # dimension sizes. For now lets leave the responsibility to care about that in the calling + # functions, if we can factor code later on we will. + + index_parents = dict[tuple[int, int], tuple[int, int]]() + + def get_representative(index: tuple[int, int]) -> tuple[int, int]: + if index not in index_parents: + # If an index is not yet in a cluster, put it in its own. + index_parents[index] = index + current = index_parents[index] + if current != index: + # Compress path to representative + index_parents[index] = get_representative(current) + return index_parents[index] + + def group_indices(indices: list[tuple[int, int]]) -> None: + first_representative = get_representative(indices[0]) + for i in indices[1:]: + curr_representative = get_representative(i) + index_parents[curr_representative] = first_representative + + new_indices_pair = list[list[tuple[int, int]]]() + physicals = list[Tensor]() + indices_to_n_pdims = dict[int, int]() + for t, indices in args: + assert isinstance(t, StructuredSparseTensor) + physicals.append(t.physical) + for pdims, index in zip(t.v_to_ps, indices): + if index in indices_to_n_pdims: + if indices_to_n_pdims[index] != len(pdims): + raise NotImplementedError( + "einsum currently does not support having a different number of physical " + "dimensions corresponding to matching virtual dimensions of different " + f"tensors. Found {[(t.debug_info(), indices) for t, indices in args]}, " + f"output_indices={output}." + ) + else: + indices_to_n_pdims[index] = len(pdims) + p_to_vs = ... 
# p_to_vs_from_v_to_ps(t.v_to_ps) + for indices_ in p_to_vs: + # elements in indices[indices_] map to the same dimension, they should be clustered + # together + group_indices([(indices[i], sub_i) for i, sub_i in indices_]) + # record the physical dimensions, index[v] for v in vs will end-up mapping to the same + # final dimension as they were just clustered, so we can take the first, which exists as + # t is a valid SST. + new_indices_pair.append([(indices[vs[0][0]], vs[0][1]) for vs in p_to_vs]) + + current = 0 + pair_to_int = dict[tuple[int, int], int]() + + def unique_int(pair: tuple[int, int]) -> int: + nonlocal current + if pair in pair_to_int: + return pair_to_int[pair] + pair_to_int[pair] = current + current += 1 + return pair_to_int[pair] + + new_indices = [ + [unique_int(get_representative(i)) for i in indices] for indices in new_indices_pair + ] + new_output = list[int]() + v_to_ps = list[list[int]]() + for i in output: + current_v_to_ps = [] + for j in range(indices_to_n_pdims[i]): + k = unique_int(get_representative((i, j))) + if k in new_output: + current_v_to_ps.append(new_output.index(k)) + else: + current_v_to_ps.append(len(new_output)) + new_output.append(k) + v_to_ps.append(current_v_to_ps) + + physical = torch.einsum(*[x for y in zip(physicals, new_indices) for x in y], new_output) + # Need to use the safe constructor, otherwise the dimensions may not be maximally grouped. + # Maybe there is a way to fix that though. + return to_most_efficient_tensor(physical, v_to_ps) + + +def prepare_for_elementwise_op( + t1: Tensor | int | float, t2: Tensor | int | float +) -> tuple[StructuredSparseTensor, StructuredSparseTensor]: + """ + Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being + a SST, Tensor, int or float. + """ + + assert isinstance(t1, StructuredSparseTensor) or isinstance(t2, StructuredSparseTensor) + + if isinstance(t1, int) or isinstance(t1, float): + t1_ = tensor(t1, device=t2.device) + else: + t1_ = t1 + + if isinstance(t2, int) or isinstance(t2, float): + t2_ = tensor(t2, device=t1.device) + else: + t2_ = t2 + + t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_]) + t1_ = to_structured_sparse_tensor(t1_) + t2_ = to_structured_sparse_tensor(t2_) + + return t1_, t2_ + + +@impl(aten.mul.Tensor) +def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + # Element-wise multiplication with broadcasting + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@impl(aten.div.Tensor) +def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.strides) + all_dims = list(range(t1_.ndim)) + return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) + + +@impl(aten.mul.Scalar) +def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor: + # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. 
Need to check + # that + + assert isinstance(t, StructuredSparseTensor) + new_physical = aten.mul.Scalar(t.physical, scalar) + return StructuredSparseTensor(new_physical, t.strides) + + +@impl(aten.add.Tensor) +def add_Tensor( + t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 +) -> StructuredSparseTensor: + t1_, t2_ = prepare_for_elementwise_op(t1, t2) + + if torch.equal(t1_.strides, t2_.strides): + new_physical = t1_.physical + t2_.physical * alpha + return StructuredSparseTensor(new_physical, t1_.strides) + else: + raise NotImplementedError() + + +@impl(aten.bmm.default) +def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) + assert ( + mat1.ndim == 3 + and mat2.ndim == 3 + and mat1.shape[0] == mat2.shape[0] + and mat1.shape[2] == mat2.shape[1] + ) + + mat1_ = to_structured_sparse_tensor(mat1) + mat2_ = to_structured_sparse_tensor(mat2) + + # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes + # decompositions. If not, can reshape to common decomposition? + return einsum((mat1_, [0, 1, 2]), (mat2_, [0, 2, 3]), output=[0, 1, 3]) + + +@impl(aten.mm.default) +def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor: + assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) + assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] + + mat1_ = to_structured_sparse_tensor(mat1) + mat2_ = to_structured_sparse_tensor(mat2) + + return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2]) + + +@impl(aten.mean.default) +def mean_default(t: StructuredSparseTensor) -> Tensor: + assert isinstance(t, StructuredSparseTensor) + return aten.sum.default(t.physical) / t.numel() + + +@impl(aten.sum.default) +def sum_default(t: StructuredSparseTensor) -> Tensor: + assert isinstance(t, StructuredSparseTensor) + return aten.sum.default(t.physical) + + +@impl(aten.sum.dim_IntList) +def sum_dim_IntList( + t: StructuredSparseTensor, dim: list[int], keepdim: bool = False, dtype=None +) -> Tensor: + assert isinstance(t, StructuredSparseTensor) + + if dtype: + raise NotImplementedError() + + all_dims = list(range(t.ndim)) + result = einsum((t, all_dims), output=[d for d in all_dims if d not in dim]) + + if keepdim: + for d in dim: + result = result.unsqueeze(d) + + return result diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py new file mode 100644 index 000000000..9d389c10b --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -0,0 +1,125 @@ +from torch.ops import aten # type: ignore + +from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl + +# pointwise functions applied to one Tensor with `0.0 → 0` +_POINTWISE_FUNCTIONS = [ + aten.abs.default, + aten.absolute.default, + aten.asin.default, + aten.asinh.default, + aten.atan.default, + aten.atanh.default, + aten.ceil.default, + aten.erf.default, + aten.erfinv.default, + aten.expm1.default, + aten.fix.default, + aten.floor.default, + aten.hardtanh.default, + aten.leaky_relu.default, + aten.log1p.default, + aten.neg.default, + aten.negative.default, + aten.positive.default, + aten.relu.default, + aten.round.default, + aten.sgn.default, + aten.sign.default, + aten.sin.default, + aten.sinh.default, + aten.sqrt.default, + aten.square.default, + aten.tan.default, + aten.tanh.default, + aten.trunc.default, +] + 
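+# A sanity check of the zero-preservation assumption behind these overrides (mirroring
+# test_pointwise in the unit tests added by this patch): since every unstored virtual entry of an
+# SST is 0 and every op listed above maps 0.0 to 0.0, applying the op to the physical data only
+# must agree with applying it to the dense tensor. For example (assuming torch and
+# StructuredSparseTensor are imported):
+#
+#   sst = StructuredSparseTensor(torch.tensor([1.0, -2.0, 3.0]), torch.tensor([[1], [1]]))
+#   assert torch.equal(torch.abs(sst).to_dense(), torch.abs(sst.to_dense()))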
+_IN_PLACE_POINTWISE_FUNCTIONS = [ + aten.abs_.default, + aten.absolute_.default, + aten.asin_.default, + aten.asinh_.default, + aten.atan_.default, + aten.atanh_.default, + aten.ceil_.default, + aten.erf_.default, + aten.erfinv_.default, + aten.expm1_.default, + aten.fix_.default, + aten.floor_.default, + aten.hardtanh_.default, + aten.leaky_relu_.default, + aten.log1p_.default, + aten.neg_.default, + aten.negative_.default, + aten.relu_.default, + aten.round_.default, + aten.sgn_.default, + aten.sign_.default, + aten.sin_.default, + aten.sinh_.default, + aten.sqrt_.default, + aten.square_.default, + aten.tan_.default, + aten.tanh_.default, + aten.trunc_.default, +] + + +def _override_pointwise(op): + @impl(op) + def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + return StructuredSparseTensor(op(t.physical), t.strides) + + return func_ + + +def _override_inplace_pointwise(op): + @impl(op) + def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + op(t.physical) + return t + + +for pointwise_func in _POINTWISE_FUNCTIONS: + _override_pointwise(pointwise_func) + +for pointwise_func in _IN_PLACE_POINTWISE_FUNCTIONS: + _override_inplace_pointwise(pointwise_func) + + +@impl(aten.pow.Tensor_Scalar) +def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + + if exponent <= 0.0: + # Need to densify because we don't have pow(0.0, exponent) = 0.0 + return aten.pow.Tensor_Scalar(t.to_dense(), exponent) + + new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) + return StructuredSparseTensor(new_physical, t.strides) + + +# Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. +@impl(aten.pow_.Scalar) +def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + + if exponent <= 0.0: + # Need to densify because we don't have pow(0.0, exponent) = 0.0 + # Note sure if it's even possible to densify in-place, so let's just raise an error. + raise ValueError(f"in-place pow with an exponent of {exponent} (<= 0) is not supported.") + + aten.pow_.Scalar(t.physical, exponent) + return t + + +@impl(aten.div.Scalar) +def div_Scalar(t: StructuredSparseTensor, divisor: float) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + + new_physical = aten.div.Scalar(t.physical, divisor) + return StructuredSparseTensor(new_physical, t.strides) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py new file mode 100644 index 000000000..a4c255607 --- /dev/null +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -0,0 +1,289 @@ +import operator +from itertools import accumulate +from math import prod +from typing import cast + +import torch +from torch import Tensor, arange, tensor +from torch.ops import aten # type: ignore + +from torchjd.sparse._structured_sparse_tensor import ( + StructuredSparseTensor, + impl, + print_fallback, + to_most_efficient_tensor, + unwrap_to_dense, +) + + +@impl(aten.view.default) +def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: + """ + The main condition that we want to respect is that the indexing in the flattened virtual + tensor should remain the same before and after the reshape, i.e. 
+ + c.T S = c'.T S' (1) + where: + * c is the reversed vector of cumulative physical shape before the reshape, i.e. + c.T = [prod(t.shape[1:]), prod(t.shape[2:]), ..., t.shape[-1], 1] + * c' is the same thing but after the reshape, i.e. + c'.T = [prod(shape[1:]), prod(shape[2:]), ..., shape[-1], 1] + * S is the original matrix of strides (t.strides) + * S' is the matrix of strides after reshaping. + + For u, v in Z^m and c in Z, say that u ≡ v (mod c) if u_i ≡ v_i (mod c) for all i. + Note that c'.T S' ≡ S'[-1] (mod shape[-1]) + So if we set S'[-1] = c.T S % shape[-1], we have c.T S ≡ c'.T S' (mod shape[-1]) + + (c'.T S' - S'[-1]) // shape[-1] ≡ S'[-1] (mod shape[-1]) + ... + """ + + assert isinstance(t, StructuredSparseTensor) + + shape = infer_shape(shape, t.numel()) + + if prod(shape) != t.numel(): + raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}") + + S = t.strides + vshape = list(t.shape) + c = _reverse_cumulative_product(vshape) + c_prime = _reverse_cumulative_product(shape) + new_strides = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1) + return to_most_efficient_tensor(t.physical, new_strides) + + +def _reverse_cumulative_product(values: list[int]) -> Tensor: + return tensor(list(accumulate((values[1:] + [1])[::-1], operator.mul))[::-1]) + + +def infer_shape(shape: list[int], numel: int) -> list[int]: + if shape.count(-1) > 1: + raise ValueError("Only one dimension can be inferred") + known = 1 + for s in shape: + if s != -1: + known *= s + inferred = numel // known + return [inferred if s == -1 else s for s in shape] + + +def unsquash_pdim( + physical: Tensor, strides: Tensor, pdim: int, new_pdim_shape: list[int] +) -> tuple[Tensor, Tensor]: + """ + EXAMPLE: + + physical = [ + [1, 2, 3, 4, 5, 6], + [7, 8, 9, 10, 11, 12], + [13, 14, 15, 16, 17, 18], + ] + strides = [ + [1, 1], + [0, 2], + ] + + dim = 1 + shape = [2, 3] + + new_physical = [[ + [1, 2, 3], + [4, 5, 6], + ], [ + [7, 8, 9], + [10, 11, 12], + ], [ + [13, 14, 15], + [16, 17, 18], + ]] + + new_strides = [ + [1, 3, 1], + [0, 6, 2] + """ + + # TODO: handle working with multiple dimensions at once + + old_shape = list(physical.shape) + new_shape = old_shape[:pdim] + new_pdim_shape + old_shape[pdim + 1 :] + new_physical = physical.reshape(new_shape) + + stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))]) + + new_strides = torch.concat( + [ + strides[:, :pdim], + torch.outer(strides[:, pdim], stride_multipliers), + strides[:, pdim + 1 :], + ], + dim=1, + ) + + return new_physical, new_strides + + +@impl(aten._unsafe_view.default) +def _unsafe_view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: + return view_default( + t, shape + ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp + + +@impl(aten.unsqueeze.default) +def unsqueeze_default(t: StructuredSparseTensor, dim: int) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + assert -t.ndim - 1 <= dim < t.ndim + 1 + + if dim < 0: + dim = t.ndim + dim + 1 + + new_strides = torch.concatenate( + [t.strides[:dim], torch.zeros(1, t.strides.shape[1], dtype=torch.int64), t.strides[dim:]] + ) + return StructuredSparseTensor(t.physical, new_strides) + + +@impl(aten.squeeze.dims) +def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Tensor: + assert isinstance(t, StructuredSparseTensor) + + if dims is None: + excluded = set(range(t.ndim)) + elif 
isinstance(dims, int): + excluded = {dims} + else: + excluded = set(dims) + + is_row_kept = [i not in excluded for i in range(t.ndim)] + new_strides = t.strides[is_row_kept] + return to_most_efficient_tensor(t.physical, new_strides) + + +@impl(aten.permute.default) +def permute_default(t: StructuredSparseTensor, dims: list[int]) -> StructuredSparseTensor: + new_strides = t.strides[torch.tensor(dims)] + return StructuredSparseTensor(t.physical, new_strides) + + +@impl(aten.cat.default) +def cat_default(tensors: list[Tensor], dim: int) -> Tensor: + if any(not isinstance(t, StructuredSparseTensor) for t in tensors): + print_fallback(aten.cat.default, (tensors, dim), {}) + return aten.cat.default([unwrap_to_dense(t) for t in tensors]) + + tensors_ = [cast(StructuredSparseTensor, t) for t in tensors] + ref_tensor = tensors_[0] + ref_strides = ref_tensor.strides + if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]): + raise NotImplementedError( + "Override for aten.cat.default does not support SSTs that do not all have the same " + f"strides. Found the following tensors:\n{[t.debug_info() for t in tensors_]} and the " + f"following dim: {dim}." + ) + + # We need to try to find the (pretty sure it either does not exist or is unique) physical + # dimension that makes us only move on virtual dimension dim. It also needs to be such that + # traversing it entirely brings us exactly to the end of virtual dimension dim. + + ref_virtual_dim_size = ref_tensor.shape[dim] + indices = torch.argwhere( + torch.eq(ref_strides[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) + & torch.eq(ref_strides.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) + ) + assert len(indices) <= 1 + + if len(indices) == 0: + # Add a physical dimension pdim on which we can concatenate the physicals such that this + # translates into a concatenation of the virtuals on virtual dimension dim. + + pdim = ref_tensor.physical.ndim + physicals = [t.physical.unsqueeze(-1) for t in tensors_] + new_stride_column = torch.zeros(ref_tensor.ndim, 1, dtype=torch.int64) + new_stride_column[dim, 0] = ref_virtual_dim_size + new_strides = torch.concatenate([ref_tensor.strides, new_stride_column], dim=1) + else: + # Such a physical dimension already exists. Note that an alternative implementation would be + # to simply always add the physical dimension, and squash it if it ends up being not needed. + physicals = [t.physical for t in tensors_] + pdim = indices[0][0] + new_strides = ref_tensor.strides + + new_physical = aten.cat.default(physicals, dim=pdim) + return StructuredSparseTensor(new_physical, new_strides) + + +@impl(aten.expand.default) +def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSparseTensor: + # note that sizes could also be just an int, or a torch.Size i think + assert isinstance(t, StructuredSparseTensor) + assert isinstance(sizes, list) + assert len(sizes) >= t.ndim + + # Add as many dimensions as needed at the beginning of the tensor (as torch.expand works) + for _ in range(len(sizes) - t.ndim): + t = t.unsqueeze(0) + + # Try to expand each dimension to its new size + new_physical = t.physical + new_strides = t.strides + for d, (vstride, orig_size, new_size) in enumerate(zip(t.strides, t.shape, sizes, strict=True)): + if vstride.sum() > 0 and orig_size != new_size and new_size != -1: + raise ValueError( + f"Cannot expand dim {d} of size != 1. Found size {orig_size} and target size " + f"{new_size}." 
+ ) + + if vstride.sum() == 0 and new_size != 1 and new_size != -1: + # Add a dimension of size new_size at the end of the physical tensor. + new_physical_shape = list(new_physical.shape) + [new_size] + new_physical = new_physical.unsqueeze(-1).expand(new_physical_shape) + + # Make this new physical dimension have a stride of 1 at virtual dimension d and 0 at + # every other virtual dimension + new_stride_column = torch.zeros(t.ndim, 1, dtype=torch.int64) + new_stride_column[d, 0] = 1 + new_strides = torch.cat([new_strides, new_stride_column], dim=1) + + return StructuredSparseTensor(new_physical, new_strides) + + +@impl(aten.broadcast_tensors.default) +def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: + if len(tensors) != 2: + raise NotImplementedError() + + t1, t2 = tensors + + if t1.shape == t2.shape: + return t1, t2 + + a = t1 if t1.ndim >= t2.ndim else t2 + b = t2 if t1.ndim >= t2.ndim else t1 + + a_shape = list(a.shape) + padded_b_shape = [1] * (a.ndim - b.ndim) + list(b.shape) + + new_shape = list[int]() + + for s_a, s_b in zip(a_shape, padded_b_shape): + if s_a != 1 and s_b != 1 and s_a != s_b: + raise ValueError("Incompatible shapes for broadcasting") + else: + new_shape.append(max(s_a, s_b)) + + return aten.expand.default(t1, new_shape), aten.expand.default(t2, new_shape) + + +@impl(aten.transpose.int) +def transpose_int(t: StructuredSparseTensor, dim0: int, dim1: int) -> StructuredSparseTensor: + assert isinstance(t, StructuredSparseTensor) + return StructuredSparseTensor(t.physical, _swap_rows(t.strides, dim0, dim1)) + + +def _swap_rows(matrix: Tensor, c0: int, c1: int) -> Tensor: + index = arange(matrix.shape[0]) + index[c0] = c1 + index[c1] = c0 + return matrix[index] diff --git a/src/torchjd/sparse/_coalesce.py b/src/torchjd/sparse/_coalesce.py new file mode 100644 index 000000000..0da8c777d --- /dev/null +++ b/src/torchjd/sparse/_coalesce.py @@ -0,0 +1,19 @@ +import torch +from torch import Tensor + + +def fix_zero_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: + """ + Remove columns of strides that are all 0 and sum the corresponding elements in the physical + tensor. + """ + + are_columns_zero = (strides == 0).all(dim=0) + + if not are_columns_zero.any(): + return physical, strides + + zero_column_indices = torch.arange(len(are_columns_zero))[are_columns_zero].tolist() + physical = physical.sum(dim=zero_column_indices) + strides = strides[:, ~are_columns_zero] + return physical, strides diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py new file mode 100644 index 000000000..6168904b9 --- /dev/null +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -0,0 +1,278 @@ +import itertools +import operator +from functools import wraps +from itertools import accumulate +from math import prod + +import torch +from torch import Tensor, arange, meshgrid, stack, tensor, tensordot, zeros +from torch.utils._pytree import tree_map + + +class StructuredSparseTensor(Tensor): + _HANDLED_FUNCTIONS = dict() + + @staticmethod + def __new__(cls, physical: Tensor, strides: Tensor): + assert strides.dtype == torch.int64 + + # Note [Passing requires_grad=true tensors to subclasses] + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Calling _make_subclass directly in an autograd context is + # never the right thing to do, as this will detach you from + # the autograd graph. 
You must create an autograd function + # representing the "constructor" (NegativeView, in this case) + # and call that instead. This assert helps prevent direct usage + # (which is bad!) + assert not physical.requires_grad or not torch.is_grad_enabled() + + pshape = tensor(physical.shape, dtype=torch.int64) + vshape = strides @ (pshape - 1) + 1 + return Tensor._make_wrapper_subclass( + cls, tuple(vshape.tolist()), dtype=physical.dtype, device=physical.device + ) + + def __init__(self, physical: Tensor, strides: Tensor): + """ + This constructor is made for specifying physical and strides exactly. It should not modify + it. + + For this reason, another constructor will be made to either modify the physical / strides to + simplify the result, or to create a dense tensor directly if it's already dense. + + :param physical: The dense tensor holding the actual data. + :param strides: Integer (int64) tensor of shape [virtual_ndim, physical_ndim], representing + the linear transformation between an index in the physical tensor and the corresponding + index in the virtual tensor, i.e. v_index = strides @ p_index. + """ + + if any(s == 1 for s in physical.shape): + raise ValueError( + "physical must not contain any dimension of size 1. Found physical.shape=" + f"{physical.shape}." + ) + if strides.dtype is not torch.int64: + raise ValueError( + f"strides should be of int64 dtype. Found strides.dtype={strides.dtype}." + ) + if not (strides >= 0).all(): + raise ValueError(f"All strides must be non-negative. Found strides={strides}.") + if strides.shape[1] != physical.ndim: + raise ValueError( + f"strides should have 1 column per physical dimension. Found strides={strides} and " + f"physical.shape={physical.shape}." + ) + if (strides.sum(dim=0) == 0).any(): + raise ValueError( + f"strides should not have any column full of zeros. Found strides={strides}." + ) + groups = get_groupings(list(physical.shape), strides) + if any(len(group) != 1 for group in groups): + raise ValueError( + f"Dimensions must be maximally grouped. 
Found strides={strides} and " + f"groups={groups}" + ) + + self.physical = physical + self.strides = strides + + def to_dense( + self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None + ) -> Tensor: + assert dtype is None # We may add support for this later + assert masked_grad is None # We may add support for this later + + if self.physical.ndim == 0: + return self.physical + + p_index_ranges = [arange(s) for s in self.physical.shape] + p_indices_grid = stack(meshgrid(*p_index_ranges, indexing="ij")) + + # addmm_cuda not implemented for Long tensors => gotta have these tensors on cpu + v_indices_grid = tensordot(self.strides, p_indices_grid, dims=1) + res = zeros(self.shape, device=self.device, dtype=self.dtype) + res[tuple(v_indices_grid)] = self.physical + return res + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func in cls._HANDLED_FUNCTIONS: + return cls._HANDLED_FUNCTIONS[func](*args, **kwargs) + + print_fallback(func, args, kwargs) + unwrapped_args = tree_map(unwrap_to_dense, args) + unwrapped_kwargs = tree_map(unwrap_to_dense, kwargs) + return func(*unwrapped_args, **unwrapped_kwargs) + + def __repr__(self, *, tensor_contents=None) -> str: + return f"StructuredSparseTensor(physical={self.physical}, strides={self.strides})" + + def debug_info(self) -> str: + info = ( + f"vshape: {self.shape}\n" + f"pshape: {self.physical.shape}\n" + f"strides: {self.strides}\n" + ) + return info + + @classmethod + def implements(cls, torch_function): + """Register a torch function override for ScalarTensor""" + + @wraps(torch_function) + def decorator(func): + cls._HANDLED_FUNCTIONS[torch_function] = func + return func + + return decorator + + +impl = StructuredSparseTensor.implements + + +def print_fallback(func, args, kwargs) -> None: + def tensor_to_str(t: Tensor) -> str: + result = f"{t.__class__.__name__} - vshape: {t.shape}" + if isinstance(t, StructuredSparseTensor): + result += f" - pshape: {t.physical.shape} - strides: {t.strides}" + + return result + + print(f"Falling back to dense for {func.__name__}") + if len(args) > 0: + print("* args:") + for arg in args: + if isinstance(arg, Tensor): + print(f" > {tensor_to_str(arg)}") + elif isinstance(arg, list) and len(arg) > 0 and isinstance(arg[0], Tensor): + list_content = "\n ".join([tensor_to_str(t) for t in arg]) + print(f" > [{list_content}]") + else: + print(f" > {arg}") + if len(kwargs) > 0: + print("* kwargs:") + for k, v in kwargs.items(): + print(f" > {k}: {v}") + print() + + +def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]: + """ + From a list of physical dimensions corresponding to a virtual dimension, and from the physical + shape, get the stride indicating how moving on each physical dimension makes you move on the + virtual dimension. + + Example: + Imagine a vector of size 3, and of value [1, 2, 3]. + Imagine a SST t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps. + t.to_dense() is [1, 0, 0, 0, 2, 0, 0, 0, 3] (it's the flattening of the diagonal matrix + [[1, 0, 0], [0, 2, 0], [0, 0, 3]]). + When you move by 1 on physical dimension 0, you move by 4 on virtual dimension 0, i.e. + strides_v2([0, 0], [3]) = 4 + In the 2D view, you'd move by 1 row (3 indices) and 1 column (1 index). + + Example: + strides_v2([0, 0, 1], [3,4]) # [16, 1] + Moving by 1 on physical dimension 0 makes you move by 16 on the virtual dimension. 
Moving by + 1 on physical dimension 1 makes you move by 1 on the virtual dimension. + """ + + strides_v1 = list(accumulate([1] + [physical_shape[d] for d in p_dims[:0:-1]], operator.mul))[ + ::-1 + ] + result = [0 for _ in range(len(physical_shape))] + for i, d in enumerate(p_dims): + result[d] += strides_v1[i] + return result + + +def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]: + strides_time_pshape = strides * tensor(pshape, dtype=torch.int64) + groups = {i: {i} for i, column in enumerate(strides.T)} + group_ids = [i for i in range(len(strides.T))] + for i1, i2 in itertools.combinations(range(strides.shape[1]), 2): + if torch.equal(strides[:, i1], strides_time_pshape[:, i2]): + groups[group_ids[i1]].update(groups[group_ids[i2]]) + group_ids[i2] = group_ids[i1] + + new_columns = [sorted(groups[group_id]) for group_id in sorted(set(group_ids))] + + if len(new_columns) != len(pshape): + print(f"Combined pshape with the following new columns: {new_columns}.") + + return new_columns + + +def to_structured_sparse_tensor(t: Tensor) -> StructuredSparseTensor: + if isinstance(t, StructuredSparseTensor): + return t + else: + return make_sst(physical=t, strides=torch.eye(t.ndim, dtype=torch.int64)) + + +def to_most_efficient_tensor(physical: Tensor, strides: Tensor) -> Tensor: + physical, strides = fix_dim_of_size_1(physical, strides) + physical, strides = fix_ungrouped_dims(physical, strides) + + if (strides.sum(dim=0) == 1).all(): + # TODO: this can be done more efficiently (without even creating the SST) + return StructuredSparseTensor(physical, strides).to_dense() + else: + return StructuredSparseTensor(physical, strides) + + +def unwrap_to_dense(t: Tensor): + if isinstance(t, StructuredSparseTensor): + return t.to_dense() + else: + return t + + +def get_full_source(source: list[int], destination: list[int], ndim: int) -> list[int]: + """ + Doing a movedim with source and destination is always equivalent to doing a movedim with + [0, 1, ..., ndim-1] (aka "full_destination") as destination, and the "full_source" as source. + + This function computes the full_source based on a source and destination. 
+ + Example: + source=[2, 4] + destination=[0, 3] + ndim=5 + + full_source = [2, 0, 1, 4, 3] + full_destination = [0, 1, 2, 3, 4] + """ + + idx = torch.full((ndim,), -1, dtype=torch.int64) + idx[destination] = tensor(source, dtype=torch.int64) + source_set = set(source) + idx[idx.eq(-1)] = tensor([i for i in range(ndim) if i not in source_set], dtype=torch.int64) + + return idx.tolist() + + +def fix_dim_of_size_1(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: + is_of_size_1 = tensor([s == 1 for s in physical.shape], dtype=torch.bool) + return physical.squeeze(), strides[:, ~is_of_size_1] + + +def fix_ungrouped_dims(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: + groups = get_groupings(list(physical.shape), strides) + nphysical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups]) + stride_mapping = torch.zeros(physical.ndim, nphysical.ndim, dtype=torch.int64) + for j, group in enumerate(groups): + stride_mapping[group[-1], j] = 1 + + new_strides = strides @ stride_mapping + return nphysical, new_strides + + +def make_sst(physical: Tensor, strides: Tensor) -> StructuredSparseTensor: + """Fix physical and strides and create a StructuredSparseTensor with them.""" + + physical, strides = fix_dim_of_size_1(physical, strides) + physical, strides = fix_ungrouped_dims(physical, strides) + return StructuredSparseTensor(physical, strides) diff --git a/tests/unit/sparse/__init__.py b/tests/unit/sparse/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py new file mode 100644 index 000000000..e0e404149 --- /dev/null +++ b/tests/unit/sparse/test_structured_sparse_tensor.py @@ -0,0 +1,422 @@ +import torch +from pytest import mark +from torch import Tensor, tensor +from torch.ops import aten # type: ignore +from torch.testing import assert_close +from utils.tensors import randn_, tensor_, zeros_ + +from torchjd.sparse._aten_function_overrides.einsum import einsum +from torchjd.sparse._aten_function_overrides.pointwise import ( + _IN_PLACE_POINTWISE_FUNCTIONS, + _POINTWISE_FUNCTIONS, +) +from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim +from torchjd.sparse._coalesce import fix_zero_stride_columns +from torchjd.sparse._structured_sparse_tensor import ( + StructuredSparseTensor, + fix_ungrouped_dims, + get_full_source, + get_groupings, +) + + +def test_to_dense(): + n = 2 + m = 3 + a = randn_([n, m]) + b = StructuredSparseTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]])) + c = b.to_dense() + + for i in range(n): + for j in range(m): + assert c[i, j, j, i] == a[i, j] + + +def test_to_dense2(): + a = tensor_([1.0, 2.0, 3.0]) + b = StructuredSparseTensor(a, tensor([[4]])) + c = b.to_dense() + expected = tensor_([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]) + assert torch.all(torch.eq(c, expected)) + + +@mark.parametrize( + ["a_pshape", "a_strides", "b_pshape", "b_strides", "a_indices", "b_indices", "output_indices"], + [ + ( + [4, 5], + tensor([[1, 0], [1, 0], [0, 1]]), + [4, 5], + tensor([[1, 0], [0, 1], [0, 1]]), + [0, 1, 2], + [0, 2, 3], + [0, 1, 3], + ), + ( + [2, 3, 5], + tensor([[3, 1, 0], [1, 0, 2]]), + [10, 3], + tensor([[1, 0], [0, 1]]), + [0, 1], + [1, 2], + [0, 2], + ), + ( + [6, 2, 3], + tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), + [2, 3], + tensor([[3, 1], [1, 0], [0, 1]]), + [0, 1, 2], + [0, 1, 2], + [0, 1, 2], + ), + ], +) +def test_einsum( + a_pshape: list[int], + a_strides: Tensor, + 
b_pshape: list[int], + b_strides: Tensor, + a_indices: list[int], + b_indices: list[int], + output_indices: list[int], +): + a = StructuredSparseTensor(randn_(a_pshape), a_strides) + b = StructuredSparseTensor(randn_(b_pshape), b_strides) + + res = einsum((a, a_indices), (b, b_indices), output=output_indices) + + expected = torch.einsum(a.to_dense(), a_indices, b.to_dense(), b_indices, output_indices) + + assert isinstance(res, StructuredSparseTensor) + assert_close(res.to_dense(), expected) + + +@mark.parametrize( + "shape", + [ + [], + [2], + [2, 3], + [2, 3, 4], + ], +) +def test_structured_sparse_tensor_scalar(shape: list[int]): + a = randn_(shape) + b = StructuredSparseTensor(a, torch.eye(len(shape), dtype=torch.int64)) + + assert_close(a, b.to_dense()) + + +@mark.parametrize("dim", [2, 3, 4, 5, 10]) +def test_diag_equivalence(dim: int): + a = randn_([dim]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) + + diag_a = torch.diag(a) + + assert_close(b.to_dense(), diag_a) + + +def test_three_virtual_single_physical(): + dim = 10 + a = randn_([dim]) + b = StructuredSparseTensor(a, tensor([[1], [1], [1]])) + + expected = zeros_([dim, dim, dim]) + for i in range(dim): + expected[i, i, i] = a[i] + + assert_close(b.to_dense(), expected) + + +@mark.parametrize("func", _POINTWISE_FUNCTIONS) +def test_pointwise(func): + dim = 10 + a = randn_([dim]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) + c = b.to_dense() + res = func(b) + assert isinstance(res, StructuredSparseTensor) + + assert_close(res.to_dense(), func(c), equal_nan=True) + + +@mark.parametrize("func", _IN_PLACE_POINTWISE_FUNCTIONS) +def test_inplace_pointwise(func): + dim = 10 + a = randn_([dim]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) + c = b.to_dense() + func(b) + assert isinstance(b, StructuredSparseTensor) + + assert_close(b.to_dense(), func(c), equal_nan=True) + + +@mark.parametrize("func", [torch.mean, torch.sum]) +def test_unary(func): + dim = 10 + a = randn_([dim]) + b = StructuredSparseTensor(a, tensor([[1], [1]])) + c = b.to_dense() + + res = func(b) + assert_close(res.to_dense(), func(c)) + + +@mark.parametrize( + ["physical_shape", "strides", "target_shape", "expected_physical_shape", "expected_strides"], + [ + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # no change of shape + ( + [2, 3], + tensor([[1, 0], [3, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # no change of shape + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # squashing 2 dims + ( + [2, 3], + tensor([[1, 0], [3, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # unsquashing into 2 dims + ( + [2, 3], + tensor([[9, 1]]), + [2, 6], + [2, 3], + tensor([[1, 0], [3, 1]]), + ), # unsquashing into 2 dims + ( + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + [12], + [2, 3], + tensor([[9, 1]]), + ), # squashing 3 dims + ( + [2, 3], + tensor([[9, 1]]), + [2, 2, 3], + [2, 3], + tensor([[1, 0], [1, 0], [0, 1]]), + ), # unsquashing into 3 dims + ( + [4], + tensor([[1], [1]]), + [2, 2, 4], + [2, 2], + tensor([[1, 0], [0, 1], [2, 1]]), + ), # unsquashing physical dim + ( + [4], + tensor([[1], [1]]), + [4, 2, 2], + [2, 2], + tensor([[2, 1], [1, 0], [0, 1]]), + ), # unsquashing physical dim + ( + [2, 3, 4], + tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]), + [4, 12], + [2, 12], + tensor([[3, 0], [0, 1]]), + ), # world boss + ( + [2, 12], + tensor([[3, 0], [0, 1]]), + [2, 2, 3, 4], + [2, 3, 4], + 
tensor([[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]]), + ), # world boss + ], +) +def test_view( + physical_shape: list[int], + strides: Tensor, + target_shape: list[int], + expected_physical_shape: list[int], + expected_strides: Tensor, +): + a = randn_(tuple(physical_shape)) + t = StructuredSparseTensor(a, strides) + + result = aten.view.default(t, target_shape) + expected = t.to_dense().reshape(target_shape) + + assert isinstance(result, StructuredSparseTensor) + assert list(result.physical.shape) == expected_physical_shape + assert torch.equal(result.strides, expected_strides) + assert torch.all(torch.eq(result.to_dense(), expected)) + + +@mark.parametrize( + ["pshape", "strides", "expected"], + [ + ( + [[32, 2, 3, 4, 5]], + torch.tensor([[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, 0, 0, 0, 0], [0, 60, 20, 5, 1]]), + [[0], [1, 2, 3, 4]], + ) + ], +) +def test_get_groupings(pshape: list[int], strides: torch.Tensor, expected: list[list[int]]): + result = get_groupings(pshape, strides) + assert result == expected + + +@mark.parametrize( + ["physical_shape", "strides", "expected_physical_shape", "expected_strides"], + [ + ( + [3, 4, 5], + tensor([[20, 5, 1], [4, 1, 12], [0, 0, 1]]), + [12, 5], + tensor([[5, 1], [1, 12], [0, 1]]), + ), + ( + [32, 20, 8], + tensor([[1, 0, 0], [1, 32, 0], [0, 0, 1]]), + [32, 20, 8], + tensor([[1, 0, 0], [1, 32, 0], [0, 0, 1]]), + ), + ([3, 3, 4], tensor([[3, 1, 0], [0, 4, 1]]), [3, 3, 4], tensor([[3, 1, 0], [0, 4, 1]])), + ], +) +def test_fix_ungrouped_dims( + physical_shape: list[int], + strides: Tensor, + expected_physical_shape: list[int], + expected_strides: Tensor, +): + physical = randn_(physical_shape) + fixed_physical, fixed_strides = fix_ungrouped_dims(physical, strides) + + assert list(fixed_physical.shape) == expected_physical_shape + assert torch.equal(fixed_strides, expected_strides) + + +@mark.parametrize( + [ + "physical_shape", + "strides", + "pdim", + "new_pdim_shape", + "expected_physical_shape", + "expected_strides", + ], + [ + ([4], tensor([[1], [2]]), 0, [4], [4], tensor([[1], [2]])), # trivial + ([4], tensor([[1], [2]]), 0, [2, 2], [2, 2], tensor([[2, 1], [4, 2]])), + ( + [3, 4, 5], + tensor([[1, 2, 0], [1, 0, 1], [0, 1, 1]]), + 1, + [2, 1, 1, 2], + [3, 2, 1, 1, 2, 5], + tensor([[1, 4, 4, 4, 2, 0], [1, 0, 0, 0, 0, 1], [0, 2, 2, 2, 1, 1]]), + ), + ], +) +def test_unsquash_pdim( + physical_shape: list[int], + strides: Tensor, + pdim: int, + new_pdim_shape: list[int], + expected_physical_shape: list[int], + expected_strides: Tensor, +): + physical = randn_(physical_shape) + new_physical, new_strides = unsquash_pdim(physical, strides, pdim, new_pdim_shape) + + assert list(new_physical.shape) == expected_physical_shape + assert torch.equal(new_strides, expected_strides) + + +@mark.parametrize( + [ + "source", + "destination", + "ndim", + ], + [ + ([2, 4], [0, 3], 5), + ([5, 3, 6], [2, 0, 5], 8), + ], +) +def test_get_column_indices(source: list[int], destination: list[int], ndim: int): + # TODO: this test should be improved / removed. It creates quite big tensors for nothing. 
+ + t = randn_(list(torch.randint(3, 8, size=(ndim,)))) + full_destination = list(range(ndim)) + full_source = get_full_source(source, destination, ndim) + assert torch.equal(t.movedim(full_source, full_destination), t.movedim(source, destination)) + + +@mark.parametrize( + ["sst_args", "dim"], + [ + ([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1), + ([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1), + ], +) +def test_concatenate( + sst_args: list[tuple[list[int], Tensor]], + dim: int, +): + tensors = [StructuredSparseTensor(randn_(pshape), strides) for pshape, strides in sst_args] + res = aten.cat.default(tensors, dim) + expected = aten.cat.default([t.to_dense() for t in tensors], dim) + + assert isinstance(res, StructuredSparseTensor) + assert torch.all(torch.eq(res.to_dense(), expected)) + + +@mark.parametrize( + ["physical", "strides", "expected_physical", "expected_strides"], + [ + ( + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 0], [1, 0], [2, 0]]), + tensor_([6, 15]), + tensor([[1], [1], [2]]), + ), + ( + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 1], [1, 0], [2, 0]]), + tensor_([[1, 2, 3], [4, 5, 6]]), + tensor([[1, 1], [1, 0], [2, 0]]), + ), + ( + tensor_([[3, 2, 1], [6, 5, 4]]), + tensor([[0, 0], [0, 0], [0, 0]]), + tensor_(21), + tensor([[], [], []], dtype=torch.int64), + ), + ], +) +def test_fix_zero_stride_columns( + physical: Tensor, + strides: Tensor, + expected_physical: Tensor, + expected_strides: Tensor, +): + physical, strides = fix_zero_stride_columns(physical, strides) + assert torch.equal(physical, expected_physical) + assert torch.equal(strides, expected_strides) From 1c86b7942c65fc5b8a757f4939f8648a2d03a01f Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 17 Nov 2025 20:34:32 +0100 Subject: [PATCH 02/42] Fix some Mypy errors. --- src/torchjd/sparse/_aten_function_overrides/shape.py | 2 +- src/torchjd/sparse/_structured_sparse_tensor.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index a4c255607..07ed16bbc 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -207,7 +207,7 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: # Such a physical dimension already exists. Note that an alternative implementation would be # to simply always add the physical dimension, and squash it if it ends up being not needed. 
physicals = [t.physical for t in tensors_] - pdim = indices[0][0] + pdim = cast(int, indices[0, 0].item()) new_strides = ref_tensor.strides new_physical = aten.cat.default(physicals, dim=pdim) diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_structured_sparse_tensor.py index 6168904b9..6b7c396be 100644 --- a/src/torchjd/sparse/_structured_sparse_tensor.py +++ b/src/torchjd/sparse/_structured_sparse_tensor.py @@ -1,5 +1,6 @@ import itertools import operator +from collections.abc import Callable from functools import wraps from itertools import accumulate from math import prod @@ -10,7 +11,7 @@ class StructuredSparseTensor(Tensor): - _HANDLED_FUNCTIONS = dict() + _HANDLED_FUNCTIONS = dict[Callable, Callable]() @staticmethod def __new__(cls, physical: Tensor, strides: Tensor): From 76bb476e66e068239803fcf657653bc182354435 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Wed, 19 Nov 2025 17:02:28 +0100 Subject: [PATCH 03/42] Add `solve_int` --- src/torchjd/sparse/linalg.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 100644 src/torchjd/sparse/linalg.py diff --git a/src/torchjd/sparse/linalg.py b/src/torchjd/sparse/linalg.py new file mode 100644 index 000000000..6dea141a2 --- /dev/null +++ b/src/torchjd/sparse/linalg.py @@ -0,0 +1,24 @@ +import torch +from torch import Tensor + + +def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None: + """ + Solve A X = B where A, B and X have integer dtype. + Return X if such a matrix exists and otherwise None. + """ + + A_ = A.to(torch.float64) + B_ = B.to(torch.float64) + + try: + X = torch.linalg.solve(A_, B_) + except RuntimeError: + return None + + X_rounded = X.round() + if not torch.all(torch.isclose(X, X_rounded, atol=tol)): + return None + + # TODO: Verify that the round operation cannot fail + return X_rounded.to(torch.int64) From 96c54e4ddb89b929f22e6e38a1c895f8a5682ffc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 19 Nov 2025 19:52:00 +0100 Subject: [PATCH 04/42] Make linalg protected --- src/torchjd/sparse/{linalg.py => _linalg.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename src/torchjd/sparse/{linalg.py => _linalg.py} (100%) diff --git a/src/torchjd/sparse/linalg.py b/src/torchjd/sparse/_linalg.py similarity index 100% rename from src/torchjd/sparse/linalg.py rename to src/torchjd/sparse/_linalg.py From 3a8e684235df51b5a26b91f2ad49455827fc2020 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 19 Nov 2025 19:54:42 +0100 Subject: [PATCH 05/42] Add intdiv_c and mod_c * Note that in python, x % 0 raises ZeroDivisionError. The implementation of mod_c matches this behavior when t2 is the zero vector. --- src/torchjd/sparse/_linalg.py | 47 ++++++++++++++++++ .../sparse/test_structured_sparse_tensor.py | 49 ++++++++++++++++++- 2 files changed, 95 insertions(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 6dea141a2..a178e1e98 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -22,3 +22,50 @@ def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None: # TODO: Verify that the round operation cannot fail return X_rounded.to(torch.int64) + + +def mod_c(t1: Tensor, t2: Tensor) -> Tensor: + """ + Computes the combined modulo r = t1 %c t2, such that + t1 = d * t2 + r with d = t1 //c t2 and + 0 <= r[i] <= t1[i] for all i. + + :param t1: Non-negative integer vector. + :param t2: Non-negative integer vector. 
+ + Examples: + [8, 12]^T %c [2, 3]^T = [0, 0]^T + [8, 12]^T %c [2, 4]^T = [2, 0]^T + [8, 12]^T %c [3, 3]^T = [2, 6]^T + [8, 12]^T %c [2, 0]^T = [0, 12]^T + [8, 12]^T %c [0, 2]^T = [8, 0]^T + [8, 12]^T %c [0, 0]^T => ZeroDivisionError + """ + + return t1 - intdiv_c(t1, t2) * t2 + + +def intdiv_c(t1: Tensor, t2: Tensor) -> Tensor: + """ + Computes the combined integer division d = t1 // t2, such that + t1 = d * t2 + r with r = t1 %c t2 + 0 <= r[i] <= t1[i] for all i. + + :param t1: Non-negative integer vector. + :param t2: Non-negative integer vector. + + Examples: + [8, 12]^T //c [2, 3]^T = 4 + [8, 12]^T //c [2, 4]^T = 3 + [8, 12]^T //c [3, 3]^T = 2 + [8, 12]^T //c [2, 0]^T = 4 + [8, 12]^T //c [0, 2]^T = 6 + [8, 12]^T //c [0, 0]^T => ZeroDivisionError + """ + + non_zero_indices = torch.nonzero(t2) + if len(non_zero_indices) == 0: + raise ZeroDivisionError("division by zero") + else: + min_divider = (t1[non_zero_indices] // t2[non_zero_indices]).min() + return min_divider diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py index e0e404149..a2f800212 100644 --- a/tests/unit/sparse/test_structured_sparse_tensor.py +++ b/tests/unit/sparse/test_structured_sparse_tensor.py @@ -1,5 +1,5 @@ import torch -from pytest import mark +from pytest import mark, raises from torch import Tensor, tensor from torch.ops import aten # type: ignore from torch.testing import assert_close @@ -12,6 +12,7 @@ ) from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim from torchjd.sparse._coalesce import fix_zero_stride_columns +from torchjd.sparse._linalg import intdiv_c, mod_c from torchjd.sparse._structured_sparse_tensor import ( StructuredSparseTensor, fix_ungrouped_dims, @@ -420,3 +421,49 @@ def test_fix_zero_stride_columns( physical, strides = fix_zero_stride_columns(physical, strides) assert torch.equal(physical, expected_physical) assert torch.equal(strides, expected_strides) + + +@mark.parametrize( + ["t1", "t2", "expected"], + [ + (tensor([8, 12]), tensor([2, 3]), tensor([0, 0])), + (tensor([8, 12]), tensor([2, 4]), tensor([2, 0])), + (tensor([8, 12]), tensor([3, 3]), tensor([2, 6])), + (tensor([8, 12]), tensor([2, 0]), tensor([0, 12])), + (tensor([8, 12]), tensor([0, 2]), tensor([8, 0])), + ], +) +def test_mod_c( + t1: Tensor, + t2: Tensor, + expected: Tensor, +): + assert torch.equal(mod_c(t1, t2), expected) + + +def test_mod_c_by_0_raises(): + with raises(ZeroDivisionError): + mod_c(tensor([3, 4]), tensor([0, 0])) + + +@mark.parametrize( + ["t1", "t2", "expected"], + [ + (tensor([8, 12]), tensor([2, 3]), 4), + (tensor([8, 12]), tensor([2, 4]), 3), + (tensor([8, 12]), tensor([3, 3]), 2), + (tensor([8, 12]), tensor([2, 0]), 4), + (tensor([8, 12]), tensor([0, 2]), 6), + ], +) +def test_intdiv_c( + t1: Tensor, + t2: Tensor, + expected: Tensor, +): + assert intdiv_c(t1, t2) == expected + + +def test_intdiv_c_by_0_raises(): + with raises(ZeroDivisionError): + intdiv_c(tensor([3, 4]), tensor([0, 0])) From ba6e65f86a4785e0dc9c6d95d438cd9b90a04daf Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 21 Nov 2025 09:09:13 +0100 Subject: [PATCH 06/42] Remove mod_c and div_c --- src/torchjd/sparse/_linalg.py | 47 ----------------------------------- 1 file changed, 47 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index a178e1e98..6dea141a2 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -22,50 +22,3 @@ def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | 
None: # TODO: Verify that the round operation cannot fail return X_rounded.to(torch.int64) - - -def mod_c(t1: Tensor, t2: Tensor) -> Tensor: - """ - Computes the combined modulo r = t1 %c t2, such that - t1 = d * t2 + r with d = t1 //c t2 and - 0 <= r[i] <= t1[i] for all i. - - :param t1: Non-negative integer vector. - :param t2: Non-negative integer vector. - - Examples: - [8, 12]^T %c [2, 3]^T = [0, 0]^T - [8, 12]^T %c [2, 4]^T = [2, 0]^T - [8, 12]^T %c [3, 3]^T = [2, 6]^T - [8, 12]^T %c [2, 0]^T = [0, 12]^T - [8, 12]^T %c [0, 2]^T = [8, 0]^T - [8, 12]^T %c [0, 0]^T => ZeroDivisionError - """ - - return t1 - intdiv_c(t1, t2) * t2 - - -def intdiv_c(t1: Tensor, t2: Tensor) -> Tensor: - """ - Computes the combined integer division d = t1 // t2, such that - t1 = d * t2 + r with r = t1 %c t2 - 0 <= r[i] <= t1[i] for all i. - - :param t1: Non-negative integer vector. - :param t2: Non-negative integer vector. - - Examples: - [8, 12]^T //c [2, 3]^T = 4 - [8, 12]^T //c [2, 4]^T = 3 - [8, 12]^T //c [3, 3]^T = 2 - [8, 12]^T //c [2, 0]^T = 4 - [8, 12]^T //c [0, 2]^T = 6 - [8, 12]^T //c [0, 0]^T => ZeroDivisionError - """ - - non_zero_indices = torch.nonzero(t2) - if len(non_zero_indices) == 0: - raise ZeroDivisionError("division by zero") - else: - min_divider = (t1[non_zero_indices] // t2[non_zero_indices]).min() - return min_divider From f00377a127f2751d3a9c7020793357c0521134d7 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 21 Nov 2025 09:24:56 +0100 Subject: [PATCH 07/42] Add HNF decomposition, LCM and GCD. --- src/torchjd/sparse/_linalg.py | 205 ++++++++++++++++++++++++++++++++++ 1 file changed, 205 insertions(+) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 6dea141a2..7d1b973e6 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -1,6 +1,8 @@ import torch from torch import Tensor +# TODO: Implement in C everything in this file. + def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None: """ @@ -22,3 +24,206 @@ def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None: # TODO: Verify that the round operation cannot fail return X_rounded.to(torch.int64) + + +def extended_gcd(a: int, b: int) -> tuple[int, int, int]: + """ + Extended Euclidean Algorithm (Python integers). + Returns (g, x, y) such that a*x + b*y = g. + """ + # We perform the logic in standard Python int for speed on scalars + # then cast back to torch tensors if needed, or return python ints. + if a == 0: + return b, 0, 1 + else: + g, y, x = extended_gcd(b % a, a) + return g, x - (b // a) * y, y + + +def hnf_decomposition(A: Tensor) -> tuple[Tensor, Tensor, Tensor]: + """ + Computes the Hermite Normal Form decomposition using PyTorch. + + Args: + A: (m x n) torch.Tensor (dtype=torch.long) + + Returns: + H: (m x n) Canonical Lower Triangular HNF + U: (n x n) Unimodular transform (A @ U = H) + V: (n x n) Inverse Unimodular transform (H @ V = A) + """ + + H = A.clone().to(dtype=torch.long) + m, n = H.shape + + U = torch.eye(n, dtype=torch.long) + V = torch.eye(n, dtype=torch.long) + + row = 0 + col = 0 + + while row < m and col < n: + # --- 1. Pivot Selection --- + # Find first non-zero entry in current row from col onwards + pivot_idx = -1 + + # We extract the row slice to CPU for faster scalar checks if on GPU + # or just iterate. For HNF, strictly sequential loop is often easiest. 
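+        # Note: every column operation performed in the elimination and reduction steps
+        # below is applied to both H and U, while V receives the corresponding inverse
+        # row operation, so the invariants A @ U == H and H @ V == A hold throughout.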
+ for j in range(col, n): + if H[row, j] != 0: + pivot_idx = j + break + + if pivot_idx == -1: + row += 1 + continue + + # Swap to current column + if pivot_idx != col: + # Swap Columns in H and U + H[:, [col, pivot_idx]] = H[:, [pivot_idx, col]] + U[:, [col, pivot_idx]] = U[:, [pivot_idx, col]] + # Swap ROWS in V + V[[col, pivot_idx], :] = V[[pivot_idx, col], :] + + # --- 2. Gaussian Elimination via GCD --- + for j in range(col + 1, n): + if H[row, j] != 0: + # Extract values as python ints for GCD logic + a_val = H[row, col].item() + b_val = H[row, j].item() + + g, x, y = extended_gcd(a_val, b_val) + + # Bezout: a*x + b*y = g + # c1 = -b // g, c2 = a // g + c1 = -b_val // g + c2 = a_val // g + + # --- Update H (Column Ops) --- + # Important: Clone columns to avoid in-place modification issues during calc + col_c = H[:, col].clone() + col_j = H[:, j].clone() + + H[:, col] = col_c * x + col_j * y + H[:, j] = col_c * c1 + col_j * c2 + + # --- Update U (Column Ops) --- + u_c = U[:, col].clone() + u_j = U[:, j].clone() + U[:, col] = u_c * x + u_j * y + U[:, j] = u_c * c1 + u_j * c2 + + # --- Update V (Inverse Row Ops) --- + # Inverse of [[x, c1], [y, c2]] is [[c2, -c1], [-y, x]] + v_r_c = V[col, :].clone() + v_r_j = V[j, :].clone() + V[col, :] = v_r_c * c2 - v_r_j * c1 + V[j, :] = v_r_c * (-y) + v_r_j * x + + # --- 3. Enforce Positive Diagonal --- + if H[row, col] < 0: + H[:, col] *= -1 + U[:, col] *= -1 + V[col, :] *= -1 + + # --- 4. Canonical Reduction (Modulo) --- + # Ensure 0 <= H[row, k] < H[row, col] for k < col + pivot_val = H[row, col].clone() + if pivot_val != 0: + for j in range(col): + # floor division + factor = torch.div(H[row, j], pivot_val, rounding_mode="floor") + + if factor != 0: + H[:, j] -= factor * H[:, col] + U[:, j] -= factor * U[:, col] + V[col, :] += factor * V[j, :] + + row += 1 + col += 1 + + return H, U, V + + +def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: + """ + Computes the GCD and the projection factors. i.e. + S1 = G @ K1 + S2 = G @ K2 + + Args: + S1, S2: torch.Tensors (m x n1), (m x n2) + + Returns: + G: (m x m) The Matrix GCD (Canonical Base) + K1: (m x n1) Factors for S1 + K2: (m x n2) Factors for S2 + """ + assert S1.shape[0] == S2.shape[0], "Virtual dimension mismatch" + m = S1.shape[0] + n1 = S1.shape[1] + + # 1. Stack: [S1 | S2] + A = torch.cat([S1, S2], dim=1) + + # 2. Decompose + H, U, V = hnf_decomposition(A) + + # 3. Extract G (First m columns of H) + G = H[:, :m] + + # 4. Extract Factors from V + # S = G @ V_top. + # V tracks the inverse transforms, so it contains the coefficients K directly. + V_active = V[:m, :] # Top m rows + + K1 = V_active[:, :n1] + K2 = V_active[:, n1:] + + return G, K1, K2 + + +def compute_lcm(S1, S2): + """ + Computes the Matrix LCM (L) and the Multiples (M1, M2), i.e. + L = S1 @ M1 = S2 @ M2 + + Returns: + L: (m x m) The Matrix LCM + M1: (n1 x m) Factor such that L = S1 @ M1 + M2: (n2 x m) Factor such that L = S2 @ M2 + """ + m = S1.shape[0] + n1 = S1.shape[1] + + # 1. Kernel Setup: [S1 | -S2] + B = torch.cat([S1, -S2], dim=1) + + # 2. Decompose to find Kernel + H_B, U_B, _ = hnf_decomposition(B) + + # 3. Find Zero Columns in H_B (Kernel basis) + # Sum abs values down columns + col_mags = torch.sum(torch.abs(H_B), dim=0) + zero_indices = torch.nonzero(col_mags == 0, as_tuple=True)[0] + + if len(zero_indices) == 0: + return torch.zeros((m, m), dtype=torch.long) + + # 4. 
Extract Kernel Basis + # U_B columns corresponding to H_B zeros are the kernel generators + kernel_basis = U_B[:, zero_indices] + + # 5. Map back to Image Space + # The kernel vector is [u; v]. We need u (top n1 rows). + # Intersection = S1 @ u + u_parts = kernel_basis[:n1, :] + L_generators = S1 @ u_parts + + # 6. Canonicalize L + # The generators might be redundant or non-square. + # Run HNF one last time to get the unique square LCM matrix. + L, _, _ = hnf_decomposition(L_generators) + + return L[:, :m] From 4f19317e575415dc2147ac6163ea50e73fb21b22 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 21 Nov 2025 09:35:04 +0100 Subject: [PATCH 08/42] Improve GCD for tall stride matrices. --- src/torchjd/sparse/_linalg.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 7d1b973e6..163cc6ed9 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -170,8 +170,18 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: # 2. Decompose H, U, V = hnf_decomposition(A) - # 3. Extract G (First m columns of H) - G = H[:, :m] + col_magnitudes = torch.sum(torch.abs(H), dim=0) + # Find the last index that is non-zero. + non_zero_indices = torch.nonzero(col_magnitudes, as_tuple=True)[0] + + if len(non_zero_indices) == 0: + rank = 0 + else: + rank = non_zero_indices.max().item() + 1 + + # 3. Extract G (Compact Basis) + # We only take the first 'rank' columns. + G = H[:, :rank] # 4. Extract Factors from V # S = G @ V_top. From 4dbce6d9d7b92d065d248ae3137ba710621636cb Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 21 Nov 2025 15:53:38 +0100 Subject: [PATCH 09/42] Revamp `compute_gcd` --- src/torchjd/sparse/_linalg.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 163cc6ed9..26f60a2f3 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -161,17 +161,28 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: K2: (m x n2) Factors for S2 """ assert S1.shape[0] == S2.shape[0], "Virtual dimension mismatch" - m = S1.shape[0] - n1 = S1.shape[1] + m, n1 = S1.shape - # 1. Stack: [S1 | S2] A = torch.cat([S1, S2], dim=1) - - # 2. Decompose H, U, V = hnf_decomposition(A) + # H = [S1 | S2] @ U + # [S1 | S2] = H @ V + # + # S1 = H @ V[:, :m1] + # S2 = H @ V[:, m1:] + # + # K1 = V[:, :m1] + # K2 = V[:, m1:] + # G = H + # + # S1 = G @ K1 + # S2 = G @ K2 + # + # SST(p1, S1) = SST(SST(p1, K1), G) + # SST(p2, S2) = SST(SST(p2, K2), G) + col_magnitudes = torch.sum(torch.abs(H), dim=0) - # Find the last index that is non-zero. non_zero_indices = torch.nonzero(col_magnitudes, as_tuple=True)[0] if len(non_zero_indices) == 0: @@ -179,14 +190,8 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: else: rank = non_zero_indices.max().item() + 1 - # 3. Extract G (Compact Basis) - # We only take the first 'rank' columns. G = H[:, :rank] - - # 4. Extract Factors from V - # S = G @ V_top. - # V tracks the inverse transforms, so it contains the coefficients K directly. 
- V_active = V[:m, :] # Top m rows + V_active = V[:rank, :] K1 = V_active[:, :n1] K2 = V_active[:, n1:] From 35522f7fe01f677d2a2708802516d383b9354424 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 22 Nov 2025 17:38:59 +0100 Subject: [PATCH 10/42] Remove mod_c and intdiv_c tests --- .../sparse/test_structured_sparse_tensor.py | 49 +------------------ 1 file changed, 1 insertion(+), 48 deletions(-) diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_structured_sparse_tensor.py index a2f800212..e0e404149 100644 --- a/tests/unit/sparse/test_structured_sparse_tensor.py +++ b/tests/unit/sparse/test_structured_sparse_tensor.py @@ -1,5 +1,5 @@ import torch -from pytest import mark, raises +from pytest import mark from torch import Tensor, tensor from torch.ops import aten # type: ignore from torch.testing import assert_close @@ -12,7 +12,6 @@ ) from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim from torchjd.sparse._coalesce import fix_zero_stride_columns -from torchjd.sparse._linalg import intdiv_c, mod_c from torchjd.sparse._structured_sparse_tensor import ( StructuredSparseTensor, fix_ungrouped_dims, @@ -421,49 +420,3 @@ def test_fix_zero_stride_columns( physical, strides = fix_zero_stride_columns(physical, strides) assert torch.equal(physical, expected_physical) assert torch.equal(strides, expected_strides) - - -@mark.parametrize( - ["t1", "t2", "expected"], - [ - (tensor([8, 12]), tensor([2, 3]), tensor([0, 0])), - (tensor([8, 12]), tensor([2, 4]), tensor([2, 0])), - (tensor([8, 12]), tensor([3, 3]), tensor([2, 6])), - (tensor([8, 12]), tensor([2, 0]), tensor([0, 12])), - (tensor([8, 12]), tensor([0, 2]), tensor([8, 0])), - ], -) -def test_mod_c( - t1: Tensor, - t2: Tensor, - expected: Tensor, -): - assert torch.equal(mod_c(t1, t2), expected) - - -def test_mod_c_by_0_raises(): - with raises(ZeroDivisionError): - mod_c(tensor([3, 4]), tensor([0, 0])) - - -@mark.parametrize( - ["t1", "t2", "expected"], - [ - (tensor([8, 12]), tensor([2, 3]), 4), - (tensor([8, 12]), tensor([2, 4]), 3), - (tensor([8, 12]), tensor([3, 3]), 2), - (tensor([8, 12]), tensor([2, 0]), 4), - (tensor([8, 12]), tensor([0, 2]), 6), - ], -) -def test_intdiv_c( - t1: Tensor, - t2: Tensor, - expected: Tensor, -): - assert intdiv_c(t1, t2) == expected - - -def test_intdiv_c_by_0_raises(): - with raises(ZeroDivisionError): - intdiv_c(tensor([3, 4]), tensor([0, 0])) From 131fbb4d2dd698743d251d2cdf6238db3c381def Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 22 Nov 2025 17:44:20 +0100 Subject: [PATCH 11/42] Rename SST to SparseLatticedTensor --- src/torchjd/sparse/__init__.py | 2 +- .../_aten_function_overrides/backward.py | 22 +++---- .../sparse/_aten_function_overrides/einsum.py | 58 +++++++++---------- .../_aten_function_overrides/pointwise.py | 28 ++++----- .../sparse/_aten_function_overrides/shape.py | 42 +++++++------- ...e_tensor.py => _sparse_latticed_tensor.py} | 24 ++++---- ...nsor.py => test_sparse_latticed_tensor.py} | 40 ++++++------- 7 files changed, 108 insertions(+), 108 deletions(-) rename src/torchjd/sparse/{_structured_sparse_tensor.py => _sparse_latticed_tensor.py} (93%) rename tests/unit/sparse/{test_structured_sparse_tensor.py => test_sparse_latticed_tensor.py} (90%) diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py index 7a161b6ad..071ad680a 100644 --- a/src/torchjd/sparse/__init__.py +++ b/src/torchjd/sparse/__init__.py @@ -1,3 +1,3 @@ # Need to import this to execute the 
code inside and thus to override the functions from . import _aten_function_overrides -from ._structured_sparse_tensor import StructuredSparseTensor, make_sst +from ._sparse_latticed_tensor import SparseLatticedTensor, make_sst diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py index 9168c7653..ce44a3e4f 100644 --- a/src/torchjd/sparse/_aten_function_overrides/backward.py +++ b/src/torchjd/sparse/_aten_function_overrides/backward.py @@ -1,36 +1,36 @@ from torch import Tensor from torch.ops import aten # type: ignore -from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl +from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor, impl @impl(aten.threshold_backward.default) def threshold_backward_default( - grad_output: StructuredSparseTensor, self: Tensor, threshold -) -> StructuredSparseTensor: + grad_output: SparseLatticedTensor, self: Tensor, threshold +) -> SparseLatticedTensor: new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) - return StructuredSparseTensor(new_physical, grad_output.strides) + return SparseLatticedTensor(new_physical, grad_output.strides) @impl(aten.hardtanh_backward.default) def hardtanh_backward_default( - grad_output: StructuredSparseTensor, + grad_output: SparseLatticedTensor, self: Tensor, min_val: Tensor | int | float, max_val: Tensor | int | float, -) -> StructuredSparseTensor: - if isinstance(self, StructuredSparseTensor): +) -> SparseLatticedTensor: + if isinstance(self, SparseLatticedTensor): raise NotImplementedError() new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) - return StructuredSparseTensor(new_physical, grad_output.strides) + return SparseLatticedTensor(new_physical, grad_output.strides) @impl(aten.hardswish_backward.default) -def hardswish_backward_default(grad_output: StructuredSparseTensor, self: Tensor): - if isinstance(self, StructuredSparseTensor): +def hardswish_backward_default(grad_output: SparseLatticedTensor, self: Tensor): + if isinstance(self, SparseLatticedTensor): raise NotImplementedError() new_physical = aten.hardswish_backward.default(grad_output.physical, self) - return StructuredSparseTensor(new_physical, grad_output.strides) + return SparseLatticedTensor(new_physical, grad_output.strides) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index 9775cccc3..6081e2fb5 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -2,15 +2,15 @@ from torch import Tensor, tensor from torch.ops import aten # type: ignore -from torchjd.sparse._structured_sparse_tensor import ( - StructuredSparseTensor, +from torchjd.sparse._sparse_latticed_tensor import ( + SparseLatticedTensor, impl, to_most_efficient_tensor, - to_structured_sparse_tensor, + to_sparse_latticed_tensor, ) -def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) -> Tensor: +def einsum(*args: tuple[SparseLatticedTensor, list[int]], output: list[int]) -> Tensor: raise NotImplementedError() # First part of the algorithm, determine how to cluster physical indices as well as the common @@ -39,13 +39,13 @@ def einsum(*args: tuple[StructuredSparseTensor, list[int]], output: list[int]) - # [p_1, ..., p_k], then we have to create fresh sub-indices for each dimension. 
# For this reason, an index is decomposed into sub-indices that are then independently # clustered. - # So if an index i in args for some StructuredSparseTensor corresponds to a v_to_ps [j, k, l], + # So if an index i in args for some SparseLatticedTensor corresponds to a v_to_ps [j, k, l], # We will consider three indices (i, 0), (i, 1) and (i, 2). # If furthermore [k] correspond to the v_to_ps of some other tensor with index j, then # (i, 1) and (j, 0) will be clustered together (and end up being mapped to the same indice in # the resulting einsum). # Note that this is a problem if two virtual dimensions (from possibly different - # StructuredSparseTensors) have the same size but not the same decomposition into physical + # SparseLatticedTensors) have the same size but not the same decomposition into physical # dimension sizes. For now lets leave the responsibility to care about that in the calling # functions, if we can factor code later on we will. @@ -71,7 +71,7 @@ def group_indices(indices: list[tuple[int, int]]) -> None: physicals = list[Tensor]() indices_to_n_pdims = dict[int, int]() for t, indices in args: - assert isinstance(t, StructuredSparseTensor) + assert isinstance(t, SparseLatticedTensor) physicals.append(t.physical) for pdims, index in zip(t.v_to_ps, indices): if index in indices_to_n_pdims: @@ -129,13 +129,13 @@ def unique_int(pair: tuple[int, int]) -> int: def prepare_for_elementwise_op( t1: Tensor | int | float, t2: Tensor | int | float -) -> tuple[StructuredSparseTensor, StructuredSparseTensor]: +) -> tuple[SparseLatticedTensor, SparseLatticedTensor]: """ Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being a SST, Tensor, int or float. """ - assert isinstance(t1, StructuredSparseTensor) or isinstance(t2, StructuredSparseTensor) + assert isinstance(t1, SparseLatticedTensor) or isinstance(t2, SparseLatticedTensor) if isinstance(t1, int) or isinstance(t1, float): t1_ = tensor(t1, device=t2.device) @@ -148,8 +148,8 @@ def prepare_for_elementwise_op( t2_ = t2 t1_, t2_ = aten.broadcast_tensors.default([t1_, t2_]) - t1_ = to_structured_sparse_tensor(t1_) - t2_ = to_structured_sparse_tensor(t2_) + t1_ = to_sparse_latticed_tensor(t1_) + t2_ = to_sparse_latticed_tensor(t2_) return t1_, t2_ @@ -165,37 +165,37 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: @impl(aten.div.Tensor) def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) - t2_ = StructuredSparseTensor(1.0 / t2_.physical, t2_.strides) + t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.strides) all_dims = list(range(t1_.ndim)) return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) @impl(aten.mul.Scalar) -def mul_Scalar(t: StructuredSparseTensor, scalar) -> StructuredSparseTensor: +def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor: # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. 
Need to check # that - assert isinstance(t, StructuredSparseTensor) + assert isinstance(t, SparseLatticedTensor) new_physical = aten.mul.Scalar(t.physical, scalar) - return StructuredSparseTensor(new_physical, t.strides) + return SparseLatticedTensor(new_physical, t.strides) @impl(aten.add.Tensor) def add_Tensor( t1: Tensor | int | float, t2: Tensor | int | float, alpha: Tensor | float = 1.0 -) -> StructuredSparseTensor: +) -> SparseLatticedTensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) if torch.equal(t1_.strides, t2_.strides): new_physical = t1_.physical + t2_.physical * alpha - return StructuredSparseTensor(new_physical, t1_.strides) + return SparseLatticedTensor(new_physical, t1_.strides) else: raise NotImplementedError() @impl(aten.bmm.default) def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: - assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) + assert isinstance(mat1, SparseLatticedTensor) or isinstance(mat2, SparseLatticedTensor) assert ( mat1.ndim == 3 and mat2.ndim == 3 @@ -203,8 +203,8 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: and mat1.shape[2] == mat2.shape[1] ) - mat1_ = to_structured_sparse_tensor(mat1) - mat2_ = to_structured_sparse_tensor(mat2) + mat1_ = to_sparse_latticed_tensor(mat1) + mat2_ = to_sparse_latticed_tensor(mat2) # TODO: Verify that the dimension `0` of mat1_ and mat2_ have the same physical dimension sizes # decompositions. If not, can reshape to common decomposition? @@ -213,32 +213,32 @@ def bmm_default(mat1: Tensor, mat2: Tensor) -> Tensor: @impl(aten.mm.default) def mm_default(mat1: Tensor, mat2: Tensor) -> Tensor: - assert isinstance(mat1, StructuredSparseTensor) or isinstance(mat2, StructuredSparseTensor) + assert isinstance(mat1, SparseLatticedTensor) or isinstance(mat2, SparseLatticedTensor) assert mat1.ndim == 2 and mat2.ndim == 2 and mat1.shape[1] == mat2.shape[0] - mat1_ = to_structured_sparse_tensor(mat1) - mat2_ = to_structured_sparse_tensor(mat2) + mat1_ = to_sparse_latticed_tensor(mat1) + mat2_ = to_sparse_latticed_tensor(mat2) return einsum((mat1_, [0, 1]), (mat2_, [1, 2]), output=[0, 2]) @impl(aten.mean.default) -def mean_default(t: StructuredSparseTensor) -> Tensor: - assert isinstance(t, StructuredSparseTensor) +def mean_default(t: SparseLatticedTensor) -> Tensor: + assert isinstance(t, SparseLatticedTensor) return aten.sum.default(t.physical) / t.numel() @impl(aten.sum.default) -def sum_default(t: StructuredSparseTensor) -> Tensor: - assert isinstance(t, StructuredSparseTensor) +def sum_default(t: SparseLatticedTensor) -> Tensor: + assert isinstance(t, SparseLatticedTensor) return aten.sum.default(t.physical) @impl(aten.sum.dim_IntList) def sum_dim_IntList( - t: StructuredSparseTensor, dim: list[int], keepdim: bool = False, dtype=None + t: SparseLatticedTensor, dim: list[int], keepdim: bool = False, dtype=None ) -> Tensor: - assert isinstance(t, StructuredSparseTensor) + assert isinstance(t, SparseLatticedTensor) if dtype: raise NotImplementedError() diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py index 9d389c10b..540b86fbc 100644 --- a/src/torchjd/sparse/_aten_function_overrides/pointwise.py +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -1,6 +1,6 @@ from torch.ops import aten # type: ignore -from torchjd.sparse._structured_sparse_tensor import StructuredSparseTensor, impl +from torchjd.sparse._sparse_latticed_tensor import SparseLatticedTensor, impl # pointwise functions 
applied to one Tensor with `0.0 → 0` _POINTWISE_FUNCTIONS = [ @@ -69,17 +69,17 @@ def _override_pointwise(op): @impl(op) - def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: - assert isinstance(t, StructuredSparseTensor) - return StructuredSparseTensor(op(t.physical), t.strides) + def func_(t: SparseLatticedTensor) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) + return SparseLatticedTensor(op(t.physical), t.strides) return func_ def _override_inplace_pointwise(op): @impl(op) - def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: - assert isinstance(t, StructuredSparseTensor) + def func_(t: SparseLatticedTensor) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) op(t.physical) return t @@ -92,21 +92,21 @@ def func_(t: StructuredSparseTensor) -> StructuredSparseTensor: @impl(aten.pow.Tensor_Scalar) -def pow_Tensor_Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor: - assert isinstance(t, StructuredSparseTensor) +def pow_Tensor_Scalar(t: SparseLatticedTensor, exponent: float) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) if exponent <= 0.0: # Need to densify because we don't have pow(0.0, exponent) = 0.0 return aten.pow.Tensor_Scalar(t.to_dense(), exponent) new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) - return StructuredSparseTensor(new_physical, t.strides) + return SparseLatticedTensor(new_physical, t.strides) # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. @impl(aten.pow_.Scalar) -def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseTensor: - assert isinstance(t, StructuredSparseTensor) +def pow__Scalar(t: SparseLatticedTensor, exponent: float) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) if exponent <= 0.0: # Need to densify because we don't have pow(0.0, exponent) = 0.0 @@ -118,8 +118,8 @@ def pow__Scalar(t: StructuredSparseTensor, exponent: float) -> StructuredSparseT @impl(aten.div.Scalar) -def div_Scalar(t: StructuredSparseTensor, divisor: float) -> StructuredSparseTensor: - assert isinstance(t, StructuredSparseTensor) +def div_Scalar(t: SparseLatticedTensor, divisor: float) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) new_physical = aten.div.Scalar(t.physical, divisor) - return StructuredSparseTensor(new_physical, t.strides) + return SparseLatticedTensor(new_physical, t.strides) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 07ed16bbc..e69eb667b 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -7,8 +7,8 @@ from torch import Tensor, arange, tensor from torch.ops import aten # type: ignore -from torchjd.sparse._structured_sparse_tensor import ( - StructuredSparseTensor, +from torchjd.sparse._sparse_latticed_tensor import ( + SparseLatticedTensor, impl, print_fallback, to_most_efficient_tensor, @@ -17,7 +17,7 @@ @impl(aten.view.default) -def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: +def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: """ The main condition that we want to respect is that the indexing in the flattened virtual tensor should remain the same before and after the reshape, i.e. @@ -39,7 +39,7 @@ def view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: ... 
""" - assert isinstance(t, StructuredSparseTensor) + assert isinstance(t, SparseLatticedTensor) shape = infer_shape(shape, t.numel()) @@ -125,15 +125,15 @@ def unsquash_pdim( @impl(aten._unsafe_view.default) -def _unsafe_view_default(t: StructuredSparseTensor, shape: list[int]) -> Tensor: +def _unsafe_view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: return view_default( t, shape ) # We don't do the optimizations that they do in https://github.com/pytorch/pytorch/blame/main/aten/src/ATen/native/TensorShape.cpp @impl(aten.unsqueeze.default) -def unsqueeze_default(t: StructuredSparseTensor, dim: int) -> StructuredSparseTensor: - assert isinstance(t, StructuredSparseTensor) +def unsqueeze_default(t: SparseLatticedTensor, dim: int) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) assert -t.ndim - 1 <= dim < t.ndim + 1 if dim < 0: @@ -142,12 +142,12 @@ def unsqueeze_default(t: StructuredSparseTensor, dim: int) -> StructuredSparseTe new_strides = torch.concatenate( [t.strides[:dim], torch.zeros(1, t.strides.shape[1], dtype=torch.int64), t.strides[dim:]] ) - return StructuredSparseTensor(t.physical, new_strides) + return SparseLatticedTensor(t.physical, new_strides) @impl(aten.squeeze.dims) -def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Tensor: - assert isinstance(t, StructuredSparseTensor) +def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tensor: + assert isinstance(t, SparseLatticedTensor) if dims is None: excluded = set(range(t.ndim)) @@ -162,18 +162,18 @@ def squeeze_dims(t: StructuredSparseTensor, dims: list[int] | int | None) -> Ten @impl(aten.permute.default) -def permute_default(t: StructuredSparseTensor, dims: list[int]) -> StructuredSparseTensor: +def permute_default(t: SparseLatticedTensor, dims: list[int]) -> SparseLatticedTensor: new_strides = t.strides[torch.tensor(dims)] - return StructuredSparseTensor(t.physical, new_strides) + return SparseLatticedTensor(t.physical, new_strides) @impl(aten.cat.default) def cat_default(tensors: list[Tensor], dim: int) -> Tensor: - if any(not isinstance(t, StructuredSparseTensor) for t in tensors): + if any(not isinstance(t, SparseLatticedTensor) for t in tensors): print_fallback(aten.cat.default, (tensors, dim), {}) return aten.cat.default([unwrap_to_dense(t) for t in tensors]) - tensors_ = [cast(StructuredSparseTensor, t) for t in tensors] + tensors_ = [cast(SparseLatticedTensor, t) for t in tensors] ref_tensor = tensors_[0] ref_strides = ref_tensor.strides if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]): @@ -211,13 +211,13 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: new_strides = ref_tensor.strides new_physical = aten.cat.default(physicals, dim=pdim) - return StructuredSparseTensor(new_physical, new_strides) + return SparseLatticedTensor(new_physical, new_strides) @impl(aten.expand.default) -def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSparseTensor: +def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedTensor: # note that sizes could also be just an int, or a torch.Size i think - assert isinstance(t, StructuredSparseTensor) + assert isinstance(t, SparseLatticedTensor) assert isinstance(sizes, list) assert len(sizes) >= t.ndim @@ -246,7 +246,7 @@ def expand_default(t: StructuredSparseTensor, sizes: list[int]) -> StructuredSpa new_stride_column[d, 0] = 1 new_strides = torch.cat([new_strides, new_stride_column], dim=1) - return 
StructuredSparseTensor(new_physical, new_strides) + return SparseLatticedTensor(new_physical, new_strides) @impl(aten.broadcast_tensors.default) @@ -277,9 +277,9 @@ def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: @impl(aten.transpose.int) -def transpose_int(t: StructuredSparseTensor, dim0: int, dim1: int) -> StructuredSparseTensor: - assert isinstance(t, StructuredSparseTensor) - return StructuredSparseTensor(t.physical, _swap_rows(t.strides, dim0, dim1)) +def transpose_int(t: SparseLatticedTensor, dim0: int, dim1: int) -> SparseLatticedTensor: + assert isinstance(t, SparseLatticedTensor) + return SparseLatticedTensor(t.physical, _swap_rows(t.strides, dim0, dim1)) def _swap_rows(matrix: Tensor, c0: int, c1: int) -> Tensor: diff --git a/src/torchjd/sparse/_structured_sparse_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py similarity index 93% rename from src/torchjd/sparse/_structured_sparse_tensor.py rename to src/torchjd/sparse/_sparse_latticed_tensor.py index 6b7c396be..2f18b5336 100644 --- a/src/torchjd/sparse/_structured_sparse_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -10,7 +10,7 @@ from torch.utils._pytree import tree_map -class StructuredSparseTensor(Tensor): +class SparseLatticedTensor(Tensor): _HANDLED_FUNCTIONS = dict[Callable, Callable]() @staticmethod @@ -108,7 +108,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return func(*unwrapped_args, **unwrapped_kwargs) def __repr__(self, *, tensor_contents=None) -> str: - return f"StructuredSparseTensor(physical={self.physical}, strides={self.strides})" + return f"SparseLatticedTensor(physical={self.physical}, strides={self.strides})" def debug_info(self) -> str: info = ( @@ -130,13 +130,13 @@ def decorator(func): return decorator -impl = StructuredSparseTensor.implements +impl = SparseLatticedTensor.implements def print_fallback(func, args, kwargs) -> None: def tensor_to_str(t: Tensor) -> str: result = f"{t.__class__.__name__} - vshape: {t.shape}" - if isinstance(t, StructuredSparseTensor): + if isinstance(t, SparseLatticedTensor): result += f" - pshape: {t.physical.shape} - strides: {t.strides}" return result @@ -206,8 +206,8 @@ def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]: return new_columns -def to_structured_sparse_tensor(t: Tensor) -> StructuredSparseTensor: - if isinstance(t, StructuredSparseTensor): +def to_sparse_latticed_tensor(t: Tensor) -> SparseLatticedTensor: + if isinstance(t, SparseLatticedTensor): return t else: return make_sst(physical=t, strides=torch.eye(t.ndim, dtype=torch.int64)) @@ -219,13 +219,13 @@ def to_most_efficient_tensor(physical: Tensor, strides: Tensor) -> Tensor: if (strides.sum(dim=0) == 1).all(): # TODO: this can be done more efficiently (without even creating the SST) - return StructuredSparseTensor(physical, strides).to_dense() + return SparseLatticedTensor(physical, strides).to_dense() else: - return StructuredSparseTensor(physical, strides) + return SparseLatticedTensor(physical, strides) def unwrap_to_dense(t: Tensor): - if isinstance(t, StructuredSparseTensor): + if isinstance(t, SparseLatticedTensor): return t.to_dense() else: return t @@ -271,9 +271,9 @@ def fix_ungrouped_dims(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tenso return nphysical, new_strides -def make_sst(physical: Tensor, strides: Tensor) -> StructuredSparseTensor: - """Fix physical and strides and create a StructuredSparseTensor with them.""" +def make_sst(physical: Tensor, strides: Tensor) -> 
SparseLatticedTensor: + """Fix physical and strides and create a SparseLatticedTensor with them.""" physical, strides = fix_dim_of_size_1(physical, strides) physical, strides = fix_ungrouped_dims(physical, strides) - return StructuredSparseTensor(physical, strides) + return SparseLatticedTensor(physical, strides) diff --git a/tests/unit/sparse/test_structured_sparse_tensor.py b/tests/unit/sparse/test_sparse_latticed_tensor.py similarity index 90% rename from tests/unit/sparse/test_structured_sparse_tensor.py rename to tests/unit/sparse/test_sparse_latticed_tensor.py index e0e404149..684f2f32b 100644 --- a/tests/unit/sparse/test_structured_sparse_tensor.py +++ b/tests/unit/sparse/test_sparse_latticed_tensor.py @@ -12,8 +12,8 @@ ) from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim from torchjd.sparse._coalesce import fix_zero_stride_columns -from torchjd.sparse._structured_sparse_tensor import ( - StructuredSparseTensor, +from torchjd.sparse._sparse_latticed_tensor import ( + SparseLatticedTensor, fix_ungrouped_dims, get_full_source, get_groupings, @@ -24,7 +24,7 @@ def test_to_dense(): n = 2 m = 3 a = randn_([n, m]) - b = StructuredSparseTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]])) + b = SparseLatticedTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]])) c = b.to_dense() for i in range(n): @@ -34,7 +34,7 @@ def test_to_dense(): def test_to_dense2(): a = tensor_([1.0, 2.0, 3.0]) - b = StructuredSparseTensor(a, tensor([[4]])) + b = SparseLatticedTensor(a, tensor([[4]])) c = b.to_dense() expected = tensor_([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]) assert torch.all(torch.eq(c, expected)) @@ -81,14 +81,14 @@ def test_einsum( b_indices: list[int], output_indices: list[int], ): - a = StructuredSparseTensor(randn_(a_pshape), a_strides) - b = StructuredSparseTensor(randn_(b_pshape), b_strides) + a = SparseLatticedTensor(randn_(a_pshape), a_strides) + b = SparseLatticedTensor(randn_(b_pshape), b_strides) res = einsum((a, a_indices), (b, b_indices), output=output_indices) expected = torch.einsum(a.to_dense(), a_indices, b.to_dense(), b_indices, output_indices) - assert isinstance(res, StructuredSparseTensor) + assert isinstance(res, SparseLatticedTensor) assert_close(res.to_dense(), expected) @@ -101,9 +101,9 @@ def test_einsum( [2, 3, 4], ], ) -def test_structured_sparse_tensor_scalar(shape: list[int]): +def test_sparse_latticed_tensor_scalar(shape: list[int]): a = randn_(shape) - b = StructuredSparseTensor(a, torch.eye(len(shape), dtype=torch.int64)) + b = SparseLatticedTensor(a, torch.eye(len(shape), dtype=torch.int64)) assert_close(a, b.to_dense()) @@ -111,7 +111,7 @@ def test_structured_sparse_tensor_scalar(shape: list[int]): @mark.parametrize("dim", [2, 3, 4, 5, 10]) def test_diag_equivalence(dim: int): a = randn_([dim]) - b = StructuredSparseTensor(a, tensor([[1], [1]])) + b = SparseLatticedTensor(a, tensor([[1], [1]])) diag_a = torch.diag(a) @@ -121,7 +121,7 @@ def test_diag_equivalence(dim: int): def test_three_virtual_single_physical(): dim = 10 a = randn_([dim]) - b = StructuredSparseTensor(a, tensor([[1], [1], [1]])) + b = SparseLatticedTensor(a, tensor([[1], [1], [1]])) expected = zeros_([dim, dim, dim]) for i in range(dim): @@ -134,10 +134,10 @@ def test_three_virtual_single_physical(): def test_pointwise(func): dim = 10 a = randn_([dim]) - b = StructuredSparseTensor(a, tensor([[1], [1]])) + b = SparseLatticedTensor(a, tensor([[1], [1]])) c = b.to_dense() res = func(b) - assert isinstance(res, StructuredSparseTensor) + assert isinstance(res, 
SparseLatticedTensor) assert_close(res.to_dense(), func(c), equal_nan=True) @@ -146,10 +146,10 @@ def test_pointwise(func): def test_inplace_pointwise(func): dim = 10 a = randn_([dim]) - b = StructuredSparseTensor(a, tensor([[1], [1]])) + b = SparseLatticedTensor(a, tensor([[1], [1]])) c = b.to_dense() func(b) - assert isinstance(b, StructuredSparseTensor) + assert isinstance(b, SparseLatticedTensor) assert_close(b.to_dense(), func(c), equal_nan=True) @@ -158,7 +158,7 @@ def test_inplace_pointwise(func): def test_unary(func): dim = 10 a = randn_([dim]) - b = StructuredSparseTensor(a, tensor([[1], [1]])) + b = SparseLatticedTensor(a, tensor([[1], [1]])) c = b.to_dense() res = func(b) @@ -255,12 +255,12 @@ def test_view( expected_strides: Tensor, ): a = randn_(tuple(physical_shape)) - t = StructuredSparseTensor(a, strides) + t = SparseLatticedTensor(a, strides) result = aten.view.default(t, target_shape) expected = t.to_dense().reshape(target_shape) - assert isinstance(result, StructuredSparseTensor) + assert isinstance(result, SparseLatticedTensor) assert list(result.physical.shape) == expected_physical_shape assert torch.equal(result.strides, expected_strides) assert torch.all(torch.eq(result.to_dense(), expected)) @@ -380,11 +380,11 @@ def test_concatenate( sst_args: list[tuple[list[int], Tensor]], dim: int, ): - tensors = [StructuredSparseTensor(randn_(pshape), strides) for pshape, strides in sst_args] + tensors = [SparseLatticedTensor(randn_(pshape), strides) for pshape, strides in sst_args] res = aten.cat.default(tensors, dim) expected = aten.cat.default([t.to_dense() for t in tensors], dim) - assert isinstance(res, StructuredSparseTensor) + assert isinstance(res, SparseLatticedTensor) assert torch.all(torch.eq(res.to_dense(), expected)) From 63549ca214ee441498c9d73d83c04ce6541121c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 22 Nov 2025 18:20:15 +0100 Subject: [PATCH 12/42] Rename stride to basis --- src/torchjd/autogram/_engine.py | 4 +- .../_aten_function_overrides/backward.py | 6 +- .../sparse/_aten_function_overrides/einsum.py | 8 +- .../_aten_function_overrides/pointwise.py | 6 +- .../sparse/_aten_function_overrides/shape.py | 84 +++++++------- src/torchjd/sparse/_coalesce.py | 17 ++- src/torchjd/sparse/_sparse_latticed_tensor.py | 105 ++++++++---------- .../sparse/test_sparse_latticed_tensor.py | 64 +++++------ 8 files changed, 143 insertions(+), 151 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 964b94a67..d560b4efc 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -176,8 +176,8 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: output_dims = list(range(output.ndim)) identity = torch.eye(output.ndim, dtype=torch.int64) - strides = torch.concatenate([identity, identity], dim=0) - jac_output = make_sst(torch.ones_like(output), strides) + basis = torch.concatenate([identity, identity], dim=0) + jac_output = make_sst(torch.ones_like(output), basis) vmapped_diff = differentiation for _ in output_dims: diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py index ce44a3e4f..f8d9716bc 100644 --- a/src/torchjd/sparse/_aten_function_overrides/backward.py +++ b/src/torchjd/sparse/_aten_function_overrides/backward.py @@ -10,7 +10,7 @@ def threshold_backward_default( ) -> SparseLatticedTensor: new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) - return 
SparseLatticedTensor(new_physical, grad_output.strides) + return SparseLatticedTensor(new_physical, grad_output.basis) @impl(aten.hardtanh_backward.default) @@ -24,7 +24,7 @@ def hardtanh_backward_default( raise NotImplementedError() new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) - return SparseLatticedTensor(new_physical, grad_output.strides) + return SparseLatticedTensor(new_physical, grad_output.basis) @impl(aten.hardswish_backward.default) @@ -33,4 +33,4 @@ def hardswish_backward_default(grad_output: SparseLatticedTensor, self: Tensor): raise NotImplementedError() new_physical = aten.hardswish_backward.default(grad_output.physical, self) - return SparseLatticedTensor(new_physical, grad_output.strides) + return SparseLatticedTensor(new_physical, grad_output.basis) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index 6081e2fb5..c5bc2b273 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -165,7 +165,7 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: @impl(aten.div.Tensor) def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) - t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.strides) + t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.basis) all_dims = list(range(t1_.ndim)) return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) @@ -177,7 +177,7 @@ def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) new_physical = aten.mul.Scalar(t.physical, scalar) - return SparseLatticedTensor(new_physical, t.strides) + return SparseLatticedTensor(new_physical, t.basis) @impl(aten.add.Tensor) @@ -186,9 +186,9 @@ def add_Tensor( ) -> SparseLatticedTensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) - if torch.equal(t1_.strides, t2_.strides): + if torch.equal(t1_.basis, t2_.basis): new_physical = t1_.physical + t2_.physical * alpha - return SparseLatticedTensor(new_physical, t1_.strides) + return SparseLatticedTensor(new_physical, t1_.basis) else: raise NotImplementedError() diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py index 540b86fbc..51dbd020b 100644 --- a/src/torchjd/sparse/_aten_function_overrides/pointwise.py +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -71,7 +71,7 @@ def _override_pointwise(op): @impl(op) def func_(t: SparseLatticedTensor) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) - return SparseLatticedTensor(op(t.physical), t.strides) + return SparseLatticedTensor(op(t.physical), t.basis) return func_ @@ -100,7 +100,7 @@ def pow_Tensor_Scalar(t: SparseLatticedTensor, exponent: float) -> SparseLattice return aten.pow.Tensor_Scalar(t.to_dense(), exponent) new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) - return SparseLatticedTensor(new_physical, t.strides) + return SparseLatticedTensor(new_physical, t.basis) # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. 
@@ -122,4 +122,4 @@ def div_Scalar(t: SparseLatticedTensor, divisor: float) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) new_physical = aten.div.Scalar(t.physical, divisor) - return SparseLatticedTensor(new_physical, t.strides) + return SparseLatticedTensor(new_physical, t.basis) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index e69eb667b..34272c8be 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -28,8 +28,8 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: c.T = [prod(t.shape[1:]), prod(t.shape[2:]), ..., t.shape[-1], 1] * c' is the same thing but after the reshape, i.e. c'.T = [prod(shape[1:]), prod(shape[2:]), ..., shape[-1], 1] - * S is the original matrix of strides (t.strides) - * S' is the matrix of strides after reshaping. + * S is the original basis matrix (t.basis) + * S' is the basis matrix after reshaping. For u, v in Z^m and c in Z, say that u ≡ v (mod c) if u_i ≡ v_i (mod c) for all i. Note that c'.T S' ≡ S'[-1] (mod shape[-1]) @@ -46,12 +46,12 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: if prod(shape) != t.numel(): raise ValueError(f"shape '{shape}' is invalid for input of size {t.numel()}") - S = t.strides + S = t.basis vshape = list(t.shape) c = _reverse_cumulative_product(vshape) c_prime = _reverse_cumulative_product(shape) - new_strides = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1) - return to_most_efficient_tensor(t.physical, new_strides) + new_basis = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1) + return to_most_efficient_tensor(t.physical, new_basis) def _reverse_cumulative_product(values: list[int]) -> Tensor: @@ -70,7 +70,7 @@ def infer_shape(shape: list[int], numel: int) -> list[int]: def unsquash_pdim( - physical: Tensor, strides: Tensor, pdim: int, new_pdim_shape: list[int] + physical: Tensor, basis: Tensor, pdim: int, new_pdim_shape: list[int] ) -> tuple[Tensor, Tensor]: """ EXAMPLE: @@ -80,7 +80,7 @@ def unsquash_pdim( [7, 8, 9, 10, 11, 12], [13, 14, 15, 16, 17, 18], ] - strides = [ + basis = [ [1, 1], [0, 2], ] @@ -99,7 +99,7 @@ def unsquash_pdim( [16, 17, 18], ]] - new_strides = [ + new_basis = [ [1, 3, 1], [0, 6, 2] """ @@ -110,18 +110,18 @@ def unsquash_pdim( new_shape = old_shape[:pdim] + new_pdim_shape + old_shape[pdim + 1 :] new_physical = physical.reshape(new_shape) - stride_multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))]) + multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))]) - new_strides = torch.concat( + new_basis = torch.concat( [ - strides[:, :pdim], - torch.outer(strides[:, pdim], stride_multipliers), - strides[:, pdim + 1 :], + basis[:, :pdim], + torch.outer(basis[:, pdim], multipliers), + basis[:, pdim + 1 :], ], dim=1, ) - return new_physical, new_strides + return new_physical, new_basis @impl(aten._unsafe_view.default) @@ -139,10 +139,10 @@ def unsqueeze_default(t: SparseLatticedTensor, dim: int) -> SparseLatticedTensor if dim < 0: dim = t.ndim + dim + 1 - new_strides = torch.concatenate( - [t.strides[:dim], torch.zeros(1, t.strides.shape[1], dtype=torch.int64), t.strides[dim:]] + new_basis = torch.concatenate( + [t.basis[:dim], torch.zeros(1, t.basis.shape[1], dtype=torch.int64), t.basis[dim:]] ) - return SparseLatticedTensor(t.physical, new_strides) + return 
SparseLatticedTensor(t.physical, new_basis) @impl(aten.squeeze.dims) @@ -157,14 +157,14 @@ def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tenso excluded = set(dims) is_row_kept = [i not in excluded for i in range(t.ndim)] - new_strides = t.strides[is_row_kept] - return to_most_efficient_tensor(t.physical, new_strides) + new_basis = t.basis[is_row_kept] + return to_most_efficient_tensor(t.physical, new_basis) @impl(aten.permute.default) def permute_default(t: SparseLatticedTensor, dims: list[int]) -> SparseLatticedTensor: - new_strides = t.strides[torch.tensor(dims)] - return SparseLatticedTensor(t.physical, new_strides) + new_basis = t.basis[torch.tensor(dims)] + return SparseLatticedTensor(t.physical, new_basis) @impl(aten.cat.default) @@ -175,11 +175,11 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: tensors_ = [cast(SparseLatticedTensor, t) for t in tensors] ref_tensor = tensors_[0] - ref_strides = ref_tensor.strides - if any(not torch.equal(t.strides, ref_strides) for t in tensors_[1:]): + ref_basis = ref_tensor.basis + if any(not torch.equal(t.basis, ref_basis) for t in tensors_[1:]): raise NotImplementedError( "Override for aten.cat.default does not support SSTs that do not all have the same " - f"strides. Found the following tensors:\n{[t.debug_info() for t in tensors_]} and the " + f"basis. Found the following tensors:\n{[t.debug_info() for t in tensors_]} and the " f"following dim: {dim}." ) @@ -189,8 +189,8 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: ref_virtual_dim_size = ref_tensor.shape[dim] indices = torch.argwhere( - torch.eq(ref_strides[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) - & torch.eq(ref_strides.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) + torch.eq(ref_basis[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) + & torch.eq(ref_basis.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) ) assert len(indices) <= 1 @@ -200,18 +200,18 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: pdim = ref_tensor.physical.ndim physicals = [t.physical.unsqueeze(-1) for t in tensors_] - new_stride_column = torch.zeros(ref_tensor.ndim, 1, dtype=torch.int64) - new_stride_column[dim, 0] = ref_virtual_dim_size - new_strides = torch.concatenate([ref_tensor.strides, new_stride_column], dim=1) + new_basis_vector = torch.zeros(ref_tensor.ndim, 1, dtype=torch.int64) + new_basis_vector[dim, 0] = ref_virtual_dim_size + new_basis = torch.concatenate([ref_tensor.basis, new_basis_vector], dim=1) else: # Such a physical dimension already exists. Note that an alternative implementation would be # to simply always add the physical dimension, and squash it if it ends up being not needed. 
physicals = [t.physical for t in tensors_] pdim = cast(int, indices[0, 0].item()) - new_strides = ref_tensor.strides + new_basis = ref_tensor.basis new_physical = aten.cat.default(physicals, dim=pdim) - return SparseLatticedTensor(new_physical, new_strides) + return SparseLatticedTensor(new_physical, new_basis) @impl(aten.expand.default) @@ -227,26 +227,26 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT # Try to expand each dimension to its new size new_physical = t.physical - new_strides = t.strides - for d, (vstride, orig_size, new_size) in enumerate(zip(t.strides, t.shape, sizes, strict=True)): - if vstride.sum() > 0 and orig_size != new_size and new_size != -1: + new_basis = t.basis + for d, (v, orig_size, new_size) in enumerate(zip(t.basis, t.shape, sizes, strict=True)): + if v.sum() > 0 and orig_size != new_size and new_size != -1: raise ValueError( f"Cannot expand dim {d} of size != 1. Found size {orig_size} and target size " f"{new_size}." ) - if vstride.sum() == 0 and new_size != 1 and new_size != -1: + if v.sum() == 0 and new_size != 1 and new_size != -1: # Add a dimension of size new_size at the end of the physical tensor. new_physical_shape = list(new_physical.shape) + [new_size] new_physical = new_physical.unsqueeze(-1).expand(new_physical_shape) - # Make this new physical dimension have a stride of 1 at virtual dimension d and 0 at - # every other virtual dimension - new_stride_column = torch.zeros(t.ndim, 1, dtype=torch.int64) - new_stride_column[d, 0] = 1 - new_strides = torch.cat([new_strides, new_stride_column], dim=1) + # Make the basis vector of this new physical dimension be 1 at virtual dimension d and 0 + # at every other virtual dimension + new_basis_vector = torch.zeros(t.ndim, 1, dtype=torch.int64) + new_basis_vector[d, 0] = 1 + new_basis = torch.cat([new_basis, new_basis_vector], dim=1) - return SparseLatticedTensor(new_physical, new_strides) + return SparseLatticedTensor(new_physical, new_basis) @impl(aten.broadcast_tensors.default) @@ -279,7 +279,7 @@ def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: @impl(aten.transpose.int) def transpose_int(t: SparseLatticedTensor, dim0: int, dim1: int) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) - return SparseLatticedTensor(t.physical, _swap_rows(t.strides, dim0, dim1)) + return SparseLatticedTensor(t.physical, _swap_rows(t.basis, dim0, dim1)) def _swap_rows(matrix: Tensor, c0: int, c1: int) -> Tensor: diff --git a/src/torchjd/sparse/_coalesce.py b/src/torchjd/sparse/_coalesce.py index 0da8c777d..ac9c8368f 100644 --- a/src/torchjd/sparse/_coalesce.py +++ b/src/torchjd/sparse/_coalesce.py @@ -2,18 +2,17 @@ from torch import Tensor -def fix_zero_stride_columns(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: +def fix_zero_basis_vectors(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]: """ - Remove columns of strides that are all 0 and sum the corresponding elements in the physical - tensor. + Remove basis vectors that are all 0 and sum the corresponding elements in the physical tensor. 
""" - are_columns_zero = (strides == 0).all(dim=0) + are_vectors_zero = (basis == 0).all(dim=0) - if not are_columns_zero.any(): - return physical, strides + if not are_vectors_zero.any(): + return physical, basis - zero_column_indices = torch.arange(len(are_columns_zero))[are_columns_zero].tolist() + zero_column_indices = torch.arange(len(are_vectors_zero))[are_vectors_zero].tolist() physical = physical.sum(dim=zero_column_indices) - strides = strides[:, ~are_columns_zero] - return physical, strides + basis = basis[:, ~are_vectors_zero] + return physical, basis diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py index 2f18b5336..e58f1f04e 100644 --- a/src/torchjd/sparse/_sparse_latticed_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -14,8 +14,8 @@ class SparseLatticedTensor(Tensor): _HANDLED_FUNCTIONS = dict[Callable, Callable]() @staticmethod - def __new__(cls, physical: Tensor, strides: Tensor): - assert strides.dtype == torch.int64 + def __new__(cls, physical: Tensor, basis: Tensor): + assert basis.dtype == torch.int64 # Note [Passing requires_grad=true tensors to subclasses] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -28,23 +28,23 @@ def __new__(cls, physical: Tensor, strides: Tensor): assert not physical.requires_grad or not torch.is_grad_enabled() pshape = tensor(physical.shape, dtype=torch.int64) - vshape = strides @ (pshape - 1) + 1 + vshape = basis @ (pshape - 1) + 1 return Tensor._make_wrapper_subclass( cls, tuple(vshape.tolist()), dtype=physical.dtype, device=physical.device ) - def __init__(self, physical: Tensor, strides: Tensor): + def __init__(self, physical: Tensor, basis: Tensor): """ - This constructor is made for specifying physical and strides exactly. It should not modify + This constructor is made for specifying physical and basis exactly. It should not modify it. - For this reason, another constructor will be made to either modify the physical / strides to + For this reason, another constructor will be made to either modify the physical / basis to simplify the result, or to create a dense tensor directly if it's already dense. :param physical: The dense tensor holding the actual data. - :param strides: Integer (int64) tensor of shape [virtual_ndim, physical_ndim], representing + :param basis: Integer (int64) tensor of shape [virtual_ndim, physical_ndim], representing the linear transformation between an index in the physical tensor and the corresponding - index in the virtual tensor, i.e. v_index = strides @ p_index. + index in the virtual tensor, i.e. v_index = basis @ p_index. """ if any(s == 1 for s in physical.shape): @@ -52,30 +52,27 @@ def __init__(self, physical: Tensor, strides: Tensor): "physical must not contain any dimension of size 1. Found physical.shape=" f"{physical.shape}." ) - if strides.dtype is not torch.int64: + if basis.dtype is not torch.int64: + raise ValueError(f"basis should be of int64 dtype. Found basis.dtype={basis.dtype}.") + if not (basis >= 0).all(): + raise ValueError(f"All basis vectors must be non-negative. Found basis={basis}.") + if basis.shape[1] != physical.ndim: raise ValueError( - f"strides should be of int64 dtype. Found strides.dtype={strides.dtype}." - ) - if not (strides >= 0).all(): - raise ValueError(f"All strides must be non-negative. Found strides={strides}.") - if strides.shape[1] != physical.ndim: - raise ValueError( - f"strides should have 1 column per physical dimension. 
Found strides={strides} and " + f"basis should have 1 column per physical dimension. Found basis={basis} and " f"physical.shape={physical.shape}." ) - if (strides.sum(dim=0) == 0).any(): + if (basis.sum(dim=0) == 0).any(): raise ValueError( - f"strides should not have any column full of zeros. Found strides={strides}." + f"basis should not have any column full of zeros. Found basis={basis}." ) - groups = get_groupings(list(physical.shape), strides) + groups = get_groupings(list(physical.shape), basis) if any(len(group) != 1 for group in groups): raise ValueError( - f"Dimensions must be maximally grouped. Found strides={strides} and " - f"groups={groups}" + f"Dimensions must be maximally grouped. Found basis={basis} and " f"groups={groups}" ) self.physical = physical - self.strides = strides + self.basis = basis def to_dense( self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None @@ -90,7 +87,7 @@ def to_dense( p_indices_grid = stack(meshgrid(*p_index_ranges, indexing="ij")) # addmm_cuda not implemented for Long tensors => gotta have these tensors on cpu - v_indices_grid = tensordot(self.strides, p_indices_grid, dims=1) + v_indices_grid = tensordot(self.basis, p_indices_grid, dims=1) res = zeros(self.shape, device=self.device, dtype=self.dtype) res[tuple(v_indices_grid)] = self.physical return res @@ -108,14 +105,10 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return func(*unwrapped_args, **unwrapped_kwargs) def __repr__(self, *, tensor_contents=None) -> str: - return f"SparseLatticedTensor(physical={self.physical}, strides={self.strides})" + return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis})" def debug_info(self) -> str: - info = ( - f"vshape: {self.shape}\n" - f"pshape: {self.physical.shape}\n" - f"strides: {self.strides}\n" - ) + info = f"vshape: {self.shape}\n" f"pshape: {self.physical.shape}\n" f"basis: {self.basis}\n" return info @classmethod @@ -137,7 +130,7 @@ def print_fallback(func, args, kwargs) -> None: def tensor_to_str(t: Tensor) -> str: result = f"{t.__class__.__name__} - vshape: {t.shape}" if isinstance(t, SparseLatticedTensor): - result += f" - pshape: {t.physical.shape} - strides: {t.strides}" + result += f" - pshape: {t.physical.shape} - basis: {t.basis}" return result @@ -189,12 +182,12 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]: return result -def get_groupings(pshape: list[int], strides: Tensor) -> list[list[int]]: - strides_time_pshape = strides * tensor(pshape, dtype=torch.int64) - groups = {i: {i} for i, column in enumerate(strides.T)} - group_ids = [i for i in range(len(strides.T))] - for i1, i2 in itertools.combinations(range(strides.shape[1]), 2): - if torch.equal(strides[:, i1], strides_time_pshape[:, i2]): +def get_groupings(pshape: list[int], basis: Tensor) -> list[list[int]]: + basis_time_pshape = basis * tensor(pshape, dtype=torch.int64) + groups = {i: {i} for i, column in enumerate(basis.T)} + group_ids = [i for i in range(len(basis.T))] + for i1, i2 in itertools.combinations(range(basis.shape[1]), 2): + if torch.equal(basis[:, i1], basis_time_pshape[:, i2]): groups[group_ids[i1]].update(groups[group_ids[i2]]) group_ids[i2] = group_ids[i1] @@ -210,18 +203,18 @@ def to_sparse_latticed_tensor(t: Tensor) -> SparseLatticedTensor: if isinstance(t, SparseLatticedTensor): return t else: - return make_sst(physical=t, strides=torch.eye(t.ndim, dtype=torch.int64)) + return make_sst(physical=t, basis=torch.eye(t.ndim, dtype=torch.int64)) -def 
to_most_efficient_tensor(physical: Tensor, strides: Tensor) -> Tensor: - physical, strides = fix_dim_of_size_1(physical, strides) - physical, strides = fix_ungrouped_dims(physical, strides) +def to_most_efficient_tensor(physical: Tensor, basis: Tensor) -> Tensor: + physical, basis = fix_dim_of_size_1(physical, basis) + physical, basis = fix_ungrouped_dims(physical, basis) - if (strides.sum(dim=0) == 1).all(): + if (basis.sum(dim=0) == 1).all(): # TODO: this can be done more efficiently (without even creating the SST) - return SparseLatticedTensor(physical, strides).to_dense() + return SparseLatticedTensor(physical, basis).to_dense() else: - return SparseLatticedTensor(physical, strides) + return SparseLatticedTensor(physical, basis) def unwrap_to_dense(t: Tensor): @@ -255,25 +248,25 @@ def get_full_source(source: list[int], destination: list[int], ndim: int) -> lis return idx.tolist() -def fix_dim_of_size_1(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: +def fix_dim_of_size_1(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]: is_of_size_1 = tensor([s == 1 for s in physical.shape], dtype=torch.bool) - return physical.squeeze(), strides[:, ~is_of_size_1] + return physical.squeeze(), basis[:, ~is_of_size_1] -def fix_ungrouped_dims(physical: Tensor, strides: Tensor) -> tuple[Tensor, Tensor]: - groups = get_groupings(list(physical.shape), strides) +def fix_ungrouped_dims(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]: + groups = get_groupings(list(physical.shape), basis) nphysical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups]) - stride_mapping = torch.zeros(physical.ndim, nphysical.ndim, dtype=torch.int64) + basis_mapping = torch.zeros(physical.ndim, nphysical.ndim, dtype=torch.int64) for j, group in enumerate(groups): - stride_mapping[group[-1], j] = 1 + basis_mapping[group[-1], j] = 1 - new_strides = strides @ stride_mapping - return nphysical, new_strides + new_basis = basis @ basis_mapping + return nphysical, new_basis -def make_sst(physical: Tensor, strides: Tensor) -> SparseLatticedTensor: - """Fix physical and strides and create a SparseLatticedTensor with them.""" +def make_sst(physical: Tensor, basis: Tensor) -> SparseLatticedTensor: + """Fix physical and basis and create a SparseLatticedTensor with them.""" - physical, strides = fix_dim_of_size_1(physical, strides) - physical, strides = fix_ungrouped_dims(physical, strides) - return SparseLatticedTensor(physical, strides) + physical, basis = fix_dim_of_size_1(physical, basis) + physical, basis = fix_ungrouped_dims(physical, basis) + return SparseLatticedTensor(physical, basis) diff --git a/tests/unit/sparse/test_sparse_latticed_tensor.py b/tests/unit/sparse/test_sparse_latticed_tensor.py index 684f2f32b..fe9a95fb5 100644 --- a/tests/unit/sparse/test_sparse_latticed_tensor.py +++ b/tests/unit/sparse/test_sparse_latticed_tensor.py @@ -11,7 +11,7 @@ _POINTWISE_FUNCTIONS, ) from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim -from torchjd.sparse._coalesce import fix_zero_stride_columns +from torchjd.sparse._coalesce import fix_zero_basis_vectors from torchjd.sparse._sparse_latticed_tensor import ( SparseLatticedTensor, fix_ungrouped_dims, @@ -41,7 +41,7 @@ def test_to_dense2(): @mark.parametrize( - ["a_pshape", "a_strides", "b_pshape", "b_strides", "a_indices", "b_indices", "output_indices"], + ["a_pshape", "a_basis", "b_pshape", "b_basis", "a_indices", "b_indices", "output_indices"], [ ( [4, 5], @@ -74,15 +74,15 @@ def test_to_dense2(): ) def 
test_einsum( a_pshape: list[int], - a_strides: Tensor, + a_basis: Tensor, b_pshape: list[int], - b_strides: Tensor, + b_basis: Tensor, a_indices: list[int], b_indices: list[int], output_indices: list[int], ): - a = SparseLatticedTensor(randn_(a_pshape), a_strides) - b = SparseLatticedTensor(randn_(b_pshape), b_strides) + a = SparseLatticedTensor(randn_(a_pshape), a_basis) + b = SparseLatticedTensor(randn_(b_pshape), b_basis) res = einsum((a, a_indices), (b, b_indices), output=output_indices) @@ -166,7 +166,7 @@ def test_unary(func): @mark.parametrize( - ["physical_shape", "strides", "target_shape", "expected_physical_shape", "expected_strides"], + ["physical_shape", "basis", "target_shape", "expected_physical_shape", "expected_basis"], [ ( [2, 3], @@ -249,25 +249,25 @@ def test_unary(func): ) def test_view( physical_shape: list[int], - strides: Tensor, + basis: Tensor, target_shape: list[int], expected_physical_shape: list[int], - expected_strides: Tensor, + expected_basis: Tensor, ): a = randn_(tuple(physical_shape)) - t = SparseLatticedTensor(a, strides) + t = SparseLatticedTensor(a, basis) result = aten.view.default(t, target_shape) expected = t.to_dense().reshape(target_shape) assert isinstance(result, SparseLatticedTensor) assert list(result.physical.shape) == expected_physical_shape - assert torch.equal(result.strides, expected_strides) + assert torch.equal(result.basis, expected_basis) assert torch.all(torch.eq(result.to_dense(), expected)) @mark.parametrize( - ["pshape", "strides", "expected"], + ["pshape", "basis", "expected"], [ ( [[32, 2, 3, 4, 5]], @@ -276,13 +276,13 @@ def test_view( ) ], ) -def test_get_groupings(pshape: list[int], strides: torch.Tensor, expected: list[list[int]]): - result = get_groupings(pshape, strides) +def test_get_groupings(pshape: list[int], basis: torch.Tensor, expected: list[list[int]]): + result = get_groupings(pshape, basis) assert result == expected @mark.parametrize( - ["physical_shape", "strides", "expected_physical_shape", "expected_strides"], + ["physical_shape", "basis", "expected_physical_shape", "expected_basis"], [ ( [3, 4, 5], @@ -301,25 +301,25 @@ def test_get_groupings(pshape: list[int], strides: torch.Tensor, expected: list[ ) def test_fix_ungrouped_dims( physical_shape: list[int], - strides: Tensor, + basis: Tensor, expected_physical_shape: list[int], - expected_strides: Tensor, + expected_basis: Tensor, ): physical = randn_(physical_shape) - fixed_physical, fixed_strides = fix_ungrouped_dims(physical, strides) + fixed_physical, fixed_basis = fix_ungrouped_dims(physical, basis) assert list(fixed_physical.shape) == expected_physical_shape - assert torch.equal(fixed_strides, expected_strides) + assert torch.equal(fixed_basis, expected_basis) @mark.parametrize( [ "physical_shape", - "strides", + "basis", "pdim", "new_pdim_shape", "expected_physical_shape", - "expected_strides", + "expected_basis", ], [ ([4], tensor([[1], [2]]), 0, [4], [4], tensor([[1], [2]])), # trivial @@ -336,17 +336,17 @@ def test_fix_ungrouped_dims( ) def test_unsquash_pdim( physical_shape: list[int], - strides: Tensor, + basis: Tensor, pdim: int, new_pdim_shape: list[int], expected_physical_shape: list[int], - expected_strides: Tensor, + expected_basis: Tensor, ): physical = randn_(physical_shape) - new_physical, new_strides = unsquash_pdim(physical, strides, pdim, new_pdim_shape) + new_physical, new_basis = unsquash_pdim(physical, basis, pdim, new_pdim_shape) assert list(new_physical.shape) == expected_physical_shape - assert torch.equal(new_strides, 
expected_strides) + assert torch.equal(new_basis, expected_basis) @mark.parametrize( @@ -380,7 +380,7 @@ def test_concatenate( sst_args: list[tuple[list[int], Tensor]], dim: int, ): - tensors = [SparseLatticedTensor(randn_(pshape), strides) for pshape, strides in sst_args] + tensors = [SparseLatticedTensor(randn_(pshape), basis) for pshape, basis in sst_args] res = aten.cat.default(tensors, dim) expected = aten.cat.default([t.to_dense() for t in tensors], dim) @@ -389,7 +389,7 @@ def test_concatenate( @mark.parametrize( - ["physical", "strides", "expected_physical", "expected_strides"], + ["physical", "basis", "expected_physical", "expected_basis"], [ ( tensor_([[1, 2, 3], [4, 5, 6]]), @@ -411,12 +411,12 @@ def test_concatenate( ), ], ) -def test_fix_zero_stride_columns( +def test_fix_zero_basis_vectors( physical: Tensor, - strides: Tensor, + basis: Tensor, expected_physical: Tensor, - expected_strides: Tensor, + expected_basis: Tensor, ): - physical, strides = fix_zero_stride_columns(physical, strides) + physical, basis = fix_zero_basis_vectors(physical, basis) assert torch.equal(physical, expected_physical) - assert torch.equal(strides, expected_strides) + assert torch.equal(basis, expected_basis) From 2e641c751cf178c785510c61f76079bf44a7df96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Sat, 22 Nov 2025 18:22:31 +0100 Subject: [PATCH 13/42] Rename SST to SLT --- src/torchjd/autogram/_engine.py | 4 ++-- src/torchjd/sparse/__init__.py | 2 +- src/torchjd/sparse/_aten_function_overrides/einsum.py | 6 +++--- src/torchjd/sparse/_aten_function_overrides/shape.py | 2 +- src/torchjd/sparse/_linalg.py | 4 ++-- src/torchjd/sparse/_sparse_latticed_tensor.py | 8 ++++---- tests/unit/sparse/test_sparse_latticed_tensor.py | 6 +++--- 7 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index d560b4efc..fac380b48 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -4,7 +4,7 @@ from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge -from torchjd.sparse import make_sst +from torchjd.sparse import make_slt from ._edge_registry import EdgeRegistry from ._gramian_accumulator import GramianAccumulator @@ -177,7 +177,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: output_dims = list(range(output.ndim)) identity = torch.eye(output.ndim, dtype=torch.int64) basis = torch.concatenate([identity, identity], dim=0) - jac_output = make_sst(torch.ones_like(output), basis) + jac_output = make_slt(torch.ones_like(output), basis) vmapped_diff = differentiation for _ in output_dims: diff --git a/src/torchjd/sparse/__init__.py b/src/torchjd/sparse/__init__.py index 071ad680a..537a29ce3 100644 --- a/src/torchjd/sparse/__init__.py +++ b/src/torchjd/sparse/__init__.py @@ -1,3 +1,3 @@ # Need to import this to execute the code inside and thus to override the functions from . 
import _aten_function_overrides -from ._sparse_latticed_tensor import SparseLatticedTensor, make_sst +from ._sparse_latticed_tensor import SparseLatticedTensor, make_slt diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index c5bc2b273..376ec2b70 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -131,8 +131,8 @@ def prepare_for_elementwise_op( t1: Tensor | int | float, t2: Tensor | int | float ) -> tuple[SparseLatticedTensor, SparseLatticedTensor]: """ - Prepares two SSTs of the same shape from two args, one of those being a SST, and the other being - a SST, Tensor, int or float. + Prepares two SLTs of the same shape from two args, one of those being a SLT, and the other being + a SLT, Tensor, int or float. """ assert isinstance(t1, SparseLatticedTensor) or isinstance(t2, SparseLatticedTensor) @@ -172,7 +172,7 @@ def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: @impl(aten.mul.Scalar) def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor: - # TODO: maybe it could be that scalar is a scalar SST and t is a normal tensor. Need to check + # TODO: maybe it could be that scalar is a scalar SLT and t is a normal tensor. Need to check # that assert isinstance(t, SparseLatticedTensor) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 34272c8be..8d18c3a65 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -178,7 +178,7 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: ref_basis = ref_tensor.basis if any(not torch.equal(t.basis, ref_basis) for t in tensors_[1:]): raise NotImplementedError( - "Override for aten.cat.default does not support SSTs that do not all have the same " + "Override for aten.cat.default does not support SLTs that do not all have the same " f"basis. Found the following tensors:\n{[t.debug_info() for t in tensors_]} and the " f"following dim: {dim}." ) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 26f60a2f3..82c054621 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -179,8 +179,8 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: # S1 = G @ K1 # S2 = G @ K2 # - # SST(p1, S1) = SST(SST(p1, K1), G) - # SST(p2, S2) = SST(SST(p2, K2), G) + # SLT(p1, S1) = SLT(SLT(p1, K1), G) + # SLT(p2, S2) = SLT(SLT(p2, K2), G) col_magnitudes = torch.sum(torch.abs(H), dim=0) non_zero_indices = torch.nonzero(col_magnitudes, as_tuple=True)[0] diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py index e58f1f04e..8646a4a7f 100644 --- a/src/torchjd/sparse/_sparse_latticed_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -160,7 +160,7 @@ def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]: Example: Imagine a vector of size 3, and of value [1, 2, 3]. - Imagine a SST t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps. + Imagine a SLT t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps. t.to_dense() is [1, 0, 0, 0, 2, 0, 0, 0, 3] (it's the flattening of the diagonal matrix [[1, 0, 0], [0, 2, 0], [0, 0, 3]]). When you move by 1 on physical dimension 0, you move by 4 on virtual dimension 0, i.e. 
@@ -203,7 +203,7 @@ def to_sparse_latticed_tensor(t: Tensor) -> SparseLatticedTensor: if isinstance(t, SparseLatticedTensor): return t else: - return make_sst(physical=t, basis=torch.eye(t.ndim, dtype=torch.int64)) + return make_slt(physical=t, basis=torch.eye(t.ndim, dtype=torch.int64)) def to_most_efficient_tensor(physical: Tensor, basis: Tensor) -> Tensor: @@ -211,7 +211,7 @@ def to_most_efficient_tensor(physical: Tensor, basis: Tensor) -> Tensor: physical, basis = fix_ungrouped_dims(physical, basis) if (basis.sum(dim=0) == 1).all(): - # TODO: this can be done more efficiently (without even creating the SST) + # TODO: this can be done more efficiently (without even creating the SLT) return SparseLatticedTensor(physical, basis).to_dense() else: return SparseLatticedTensor(physical, basis) @@ -264,7 +264,7 @@ def fix_ungrouped_dims(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor] return nphysical, new_basis -def make_sst(physical: Tensor, basis: Tensor) -> SparseLatticedTensor: +def make_slt(physical: Tensor, basis: Tensor) -> SparseLatticedTensor: """Fix physical and basis and create a SparseLatticedTensor with them.""" physical, basis = fix_dim_of_size_1(physical, basis) diff --git a/tests/unit/sparse/test_sparse_latticed_tensor.py b/tests/unit/sparse/test_sparse_latticed_tensor.py index fe9a95fb5..55b723c67 100644 --- a/tests/unit/sparse/test_sparse_latticed_tensor.py +++ b/tests/unit/sparse/test_sparse_latticed_tensor.py @@ -370,17 +370,17 @@ def test_get_column_indices(source: list[int], destination: list[int], ndim: int @mark.parametrize( - ["sst_args", "dim"], + ["slt_args", "dim"], [ ([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1), ([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1), ], ) def test_concatenate( - sst_args: list[tuple[list[int], Tensor]], + slt_args: list[tuple[list[int], Tensor]], dim: int, ): - tensors = [SparseLatticedTensor(randn_(pshape), basis) for pshape, basis in sst_args] + tensors = [SparseLatticedTensor(randn_(pshape), basis) for pshape, basis in slt_args] res = aten.cat.default(tensors, dim) expected = aten.cat.default([t.to_dense() for t in tensors], dim) From 3e9e7d496bc7c1ae88c0d686f7dbf73fbf7aa39f Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sun, 23 Nov 2025 11:18:44 +0100 Subject: [PATCH 14/42] Fix usage of `unsqueeze` on SLT to call `unsqueeze_default` instead --- src/torchjd/sparse/_aten_function_overrides/shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 8d18c3a65..c94cd3758 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -223,7 +223,7 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT # Add as many dimensions as needed at the beginning of the tensor (as torch.expand works) for _ in range(len(sizes) - t.ndim): - t = t.unsqueeze(0) + t = unsqueeze_default(t, 0) # Try to expand each dimension to its new size new_physical = t.physical From c6b19c769d858f62f3a4d162a99d305bee0e0645 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sun, 23 Nov 2025 11:45:52 +0100 Subject: [PATCH 15/42] Make `hnf_decomposition` return the reduced HNF rather than the HNF. 
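A minimal usage sketch of the contract documented below (the example matrix is arbitrary and for illustration only; the asserts simply restate the documented identities):

    import torch
    from torchjd.sparse._linalg import hnf_decomposition

    # Arbitrary 2x3 integer matrix of rank 2, used only for illustration.
    A = torch.tensor([[2, 4, 6], [0, 3, 9]], dtype=torch.int64)
    H, U, V = hnf_decomposition(A)
    r = H.shape[1]  # expected to equal the rank of A, i.e. 2 here

    assert torch.equal(A @ U, H)  # H = A U
    assert torch.equal(H @ V, A)  # A = H V
    assert torch.equal(V @ U, torch.eye(r, dtype=torch.int64))  # V U = I_r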
--- src/torchjd/sparse/_linalg.py | 50 ++++++++++++++++++++--------------- 1 file changed, 28 insertions(+), 22 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 82c054621..7e7c5f3fa 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -42,15 +42,19 @@ def extended_gcd(a: int, b: int) -> tuple[int, int, int]: def hnf_decomposition(A: Tensor) -> tuple[Tensor, Tensor, Tensor]: """ - Computes the Hermite Normal Form decomposition using PyTorch. + Computes the reduced Hermite Normal Form decomposition using PyTorch. For a matrix A (m x n) of + rank r, computes the matrices H (m x r), U (n x r) and V (r x n) such that + V U = I_r + A = H V + H = A U Args: A: (m x n) torch.Tensor (dtype=torch.long) Returns: - H: (m x n) Canonical Lower Triangular HNF - U: (n x n) Unimodular transform (A @ U = H) - V: (n x n) Inverse Unimodular transform (H @ V = A) + H: (m x r) Canonical Lower Triangular HNF + U: (n x r) Unimodular transform (A @ U = H) + V: (r x n) Right inverse Unimodular transform (H @ V = A) """ H = A.clone().to(dtype=torch.long) @@ -143,7 +147,19 @@ def hnf_decomposition(A: Tensor) -> tuple[Tensor, Tensor, Tensor]: row += 1 col += 1 - return H, U, V + col_magnitudes = torch.sum(torch.abs(H), dim=0) + non_zero_indices = torch.nonzero(col_magnitudes, as_tuple=True)[0] + + if len(non_zero_indices) == 0: + rank = 0 + else: + rank = non_zero_indices.max().item() + 1 + + reduced_H = H[:, :rank] + reduced_U = U[:, :rank] + reduced_V = V[:rank, :] + + return reduced_H, reduced_U, reduced_V def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: @@ -151,20 +167,21 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: Computes the GCD and the projection factors. i.e. S1 = G @ K1 S2 = G @ K2 + with G having minimal rank. 
Args: S1, S2: torch.Tensors (m x n1), (m x n2) Returns: - G: (m x m) The Matrix GCD (Canonical Base) - K1: (m x n1) Factors for S1 - K2: (m x n2) Factors for S2 + G: (m x r) The Matrix GCD (Canonical Base) + K1: (r x n1) Factors for S1 + K2: (r x n2) Factors for S2 """ assert S1.shape[0] == S2.shape[0], "Virtual dimension mismatch" m, n1 = S1.shape A = torch.cat([S1, S2], dim=1) - H, U, V = hnf_decomposition(A) + G, U, V = hnf_decomposition(A) # H = [S1 | S2] @ U # [S1 | S2] = H @ V @@ -182,19 +199,8 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: # SLT(p1, S1) = SLT(SLT(p1, K1), G) # SLT(p2, S2) = SLT(SLT(p2, K2), G) - col_magnitudes = torch.sum(torch.abs(H), dim=0) - non_zero_indices = torch.nonzero(col_magnitudes, as_tuple=True)[0] - - if len(non_zero_indices) == 0: - rank = 0 - else: - rank = non_zero_indices.max().item() + 1 - - G = H[:, :rank] - V_active = V[:rank, :] - - K1 = V_active[:, :n1] - K2 = V_active[:, n1:] + K1 = V[:, :n1] + K2 = V[:, n1:] return G, K1, K2 From 80ffb142d3ce7a6afffca5f076348f74b00d6519 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Sun, 23 Nov 2025 14:08:44 +0100 Subject: [PATCH 16/42] Improve `hnf_decomposition` and add a test for it (failing) --- src/torchjd/sparse/_linalg.py | 104 +++++++++++-------------------- tests/unit/sparse/test_linalg.py | 31 +++++++++ 2 files changed, 66 insertions(+), 69 deletions(-) create mode 100644 tests/unit/sparse/test_linalg.py diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 7e7c5f3fa..9c53fa64a 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -57,103 +57,69 @@ def hnf_decomposition(A: Tensor) -> tuple[Tensor, Tensor, Tensor]: V: (r x n) Right inverse Unimodular transform (H @ V = A) """ - H = A.clone().to(dtype=torch.long) + H = A.clone() m, n = H.shape - U = torch.eye(n, dtype=torch.long) - V = torch.eye(n, dtype=torch.long) + U = torch.eye(n, dtype=A.dtype) + V = torch.eye(n, dtype=A.dtype) - row = 0 col = 0 - while row < m and col < n: - # --- 1. Pivot Selection --- - # Find first non-zero entry in current row from col onwards - pivot_idx = -1 + for row in range(m): + if n <= col: + break + row_slice = H[row, col:n] + nonzero_indices = torch.nonzero(row_slice) - # We extract the row slice to CPU for faster scalar checks if on GPU - # or just iterate. For HNF, strictly sequential loop is often easiest. - for j in range(col, n): - if H[row, j] != 0: - pivot_idx = j - break - - if pivot_idx == -1: - row += 1 + if nonzero_indices.numel() > 0: + relative_pivot_idx = nonzero_indices[0][0].item() + pivot_idx = col + relative_pivot_idx + else: continue - # Swap to current column if pivot_idx != col: - # Swap Columns in H and U H[:, [col, pivot_idx]] = H[:, [pivot_idx, col]] U[:, [col, pivot_idx]] = U[:, [pivot_idx, col]] - # Swap ROWS in V V[[col, pivot_idx], :] = V[[pivot_idx, col], :] - # --- 2. 
Gaussian Elimination via GCD --- for j in range(col + 1, n): if H[row, j] != 0: - # Extract values as python ints for GCD logic a_val = H[row, col].item() b_val = H[row, j].item() g, x, y = extended_gcd(a_val, b_val) - # Bezout: a*x + b*y = g - # c1 = -b // g, c2 = a // g c1 = -b_val // g c2 = a_val // g - # --- Update H (Column Ops) --- - # Important: Clone columns to avoid in-place modification issues during calc - col_c = H[:, col].clone() - col_j = H[:, j].clone() - - H[:, col] = col_c * x + col_j * y - H[:, j] = col_c * c1 + col_j * c2 - - # --- Update U (Column Ops) --- - u_c = U[:, col].clone() - u_j = U[:, j].clone() - U[:, col] = u_c * x + u_j * y - U[:, j] = u_c * c1 + u_j * c2 - - # --- Update V (Inverse Row Ops) --- - # Inverse of [[x, c1], [y, c2]] is [[c2, -c1], [-y, x]] - v_r_c = V[col, :].clone() - v_r_j = V[j, :].clone() - V[col, :] = v_r_c * c2 - v_r_j * c1 - V[j, :] = v_r_c * (-y) + v_r_j * x - - # --- 3. Enforce Positive Diagonal --- - if H[row, col] < 0: - H[:, col] *= -1 - U[:, col] *= -1 - V[col, :] *= -1 - - # --- 4. Canonical Reduction (Modulo) --- - # Ensure 0 <= H[row, k] < H[row, col] for k < col - pivot_val = H[row, col].clone() - if pivot_val != 0: - for j in range(col): - # floor division - factor = torch.div(H[row, j], pivot_val, rounding_mode="floor") + H_col = H[:, col] + H_j = H[:, j] + + H[:, [col, j]] = torch.stack([H_col * x + H_j * y, H_col * c1 + H_j * c2], dim=1) + + U_col = U[:, col] + U_j = U[:, j] + U[:, [col, j]] = torch.stack([U_col * x + U_j * y, U_col * c1 + U_j * c2], dim=1) - if factor != 0: - H[:, j] -= factor * H[:, col] - U[:, j] -= factor * U[:, col] - V[col, :] += factor * V[j, :] + V_row_c = V[col, :] + V_row_j = V[j, :] + V[[col, j], :] = torch.stack( + [V_row_c * c2 - V_row_j * c1, V_row_c * (-y) + V_row_j * x], dim=0 + ) + + pivot_val = H[row, col] + + if pivot_val != 0: + H_row_prefix = H[row, 0:col] + factors = torch.div(H_row_prefix, pivot_val, rounding_mode="floor") + H[:, 0:col] -= factors.unsqueeze(0) * H[:, col].unsqueeze(1) + U[:, 0:col] -= factors.unsqueeze(0) * U[:, col].unsqueeze(1) + V[col, :] += factors @ V[0:col, :] - row += 1 col += 1 col_magnitudes = torch.sum(torch.abs(H), dim=0) - non_zero_indices = torch.nonzero(col_magnitudes, as_tuple=True)[0] - - if len(non_zero_indices) == 0: - rank = 0 - else: - rank = non_zero_indices.max().item() + 1 + rank = torch.count_nonzero(col_magnitudes).item() reduced_H = H[:, :rank] reduced_U = U[:, :rank] diff --git a/tests/unit/sparse/test_linalg.py b/tests/unit/sparse/test_linalg.py new file mode 100644 index 000000000..b3f0f4aab --- /dev/null +++ b/tests/unit/sparse/test_linalg.py @@ -0,0 +1,31 @@ +import torch +from pytest import mark + +from torchjd.sparse._linalg import hnf_decomposition + + +@mark.parametrize( + ["shape", "max_rank"], + [ + ([5, 7], 3), + ([1, 7], 1), + ([5, 1], 1), + ([7, 5], 2), + ([5, 7], 5), + ([7, 5], 5), + ], +) +def test_hnf_decomposition(shape: tuple[int, int], max_rank: int): + # Generate a matrix A of desired shape and rank max_rank with high probability and lower + # otherwise. 
+ U = torch.randint(-50, 51, [shape[0], max_rank], dtype=torch.int64) + V = torch.randint(-50, 51, [max_rank, shape[1]], dtype=torch.int64) + A = U @ V + H, U, V = hnf_decomposition(A) + + rank = H.shape[1] + + assert rank <= max_rank + assert torch.equal(V @ U, torch.eye(rank, dtype=torch.int64)) + assert torch.equal(H @ V, A) + assert torch.equal(A @ U, H) From ba9bf21762d55483fbc06670698e78307e2fb725 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 24 Nov 2025 14:31:37 +0100 Subject: [PATCH 17/42] Reduce range of basis values to make the test pass. --- tests/unit/sparse/test_linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/unit/sparse/test_linalg.py b/tests/unit/sparse/test_linalg.py index b3f0f4aab..ed8dad903 100644 --- a/tests/unit/sparse/test_linalg.py +++ b/tests/unit/sparse/test_linalg.py @@ -18,8 +18,8 @@ def test_hnf_decomposition(shape: tuple[int, int], max_rank: int): # Generate a matrix A of desired shape and rank max_rank with high probability and lower # otherwise. - U = torch.randint(-50, 51, [shape[0], max_rank], dtype=torch.int64) - V = torch.randint(-50, 51, [max_rank, shape[1]], dtype=torch.int64) + U = torch.randint(-10, 11, [shape[0], max_rank], dtype=torch.int64) + V = torch.randint(-10, 11, [max_rank, shape[1]], dtype=torch.int64) A = U @ V H, U, V = hnf_decomposition(A) From 20db2c0ca5615ebb9169afd7c5b888d7bdad74dc Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 24 Nov 2025 14:43:56 +0100 Subject: [PATCH 18/42] Test additional properties of H. --- tests/unit/sparse/test_linalg.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/unit/sparse/test_linalg.py b/tests/unit/sparse/test_linalg.py index ed8dad903..eacd79012 100644 --- a/tests/unit/sparse/test_linalg.py +++ b/tests/unit/sparse/test_linalg.py @@ -25,7 +25,18 @@ def test_hnf_decomposition(shape: tuple[int, int], max_rank: int): rank = H.shape[1] + # Note that with these asserts, the rank is typically correct as it is at most max_rank, which it + # is with high probability, and we can reconstruct A = H @ V, so the rank of H is at least that of + # A; similarly, the rank of H is at most that of A. assert rank <= max_rank + assert torch.equal(V @ U, torch.eye(rank, dtype=torch.int64)) + assert torch.equal(H @ V, A) + assert torch.equal(A @ U, H) + + # Check H is lower triangular + mask = torch.triu(torch.ones(shape[0], rank, dtype=torch.bool), diagonal=1) + assert torch.all(H[mask] == 0).item() + + # Check pivots are positive + pivots = H.diag()[:rank] + assert torch.all(pivots > 0).item() From 361a5f777b2d344e2913be419186b8d764c4903e Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 24 Nov 2025 14:56:45 +0100 Subject: [PATCH 19/42] Add implementation explanation in `compute_gcd` --- src/torchjd/sparse/_linalg.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 9c53fa64a..865884888 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -135,6 +135,16 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: S2 = G @ K2 with G having minimal rank. + Implementation logic: + The concatenated matrix [S1 | S2] spans exactly the sum of the lattices generated by S1 and S2. + This is because S1 @ u1 + S2 @ u2 = [S1 | S2] @ [u1.T | u2.T].T + The reduced HNF decomposition of [S1 | S2] yields G, U, V where G.shape[1] is the rank of + [S1 | S2] and [S1 | S2] = G @ V.
This means that + S1 = G @ V[:, :m1] + S2 = G @ V[:, m1:] + This is the target common factorization. It is the greatest as the lattice spanned by G is the + same as that spanned by [S1 | S2]. + Args: S1, S2: torch.Tensors (m x n1), (m x n2) @@ -149,22 +159,6 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: A = torch.cat([S1, S2], dim=1) G, U, V = hnf_decomposition(A) - # H = [S1 | S2] @ U - # [S1 | S2] = H @ V - # - # S1 = H @ V[:, :m1] - # S2 = H @ V[:, m1:] - # - # K1 = V[:, :m1] - # K2 = V[:, m1:] - # G = H - # - # S1 = G @ K1 - # S2 = G @ K2 - # - # SLT(p1, S1) = SLT(SLT(p1, K1), G) - # SLT(p2, S2) = SLT(SLT(p2, K2), G) - K1 = V[:, :n1] K2 = V[:, n1:] From f9c2cff13b715c4ac37519b09f42cc28d1a8f044 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 24 Nov 2025 15:33:29 +0100 Subject: [PATCH 20/42] Add the `reduced` parameter to `hnf_decomposition` --- src/torchjd/sparse/_linalg.py | 29 +++++++++++++++-------------- tests/unit/sparse/test_linalg.py | 2 +- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 865884888..542415210 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -40,16 +40,18 @@ def extended_gcd(a: int, b: int) -> tuple[int, int, int]: return g, x - (b // a) * y, y -def hnf_decomposition(A: Tensor) -> tuple[Tensor, Tensor, Tensor]: +def hnf_decomposition(A: Tensor, reduced: bool) -> tuple[Tensor, Tensor, Tensor]: """ - Computes the reduced Hermite Normal Form decomposition using PyTorch. For a matrix A (m x n) of - rank r, computes the matrices H (m x r), U (n x r) and V (r x n) such that + Computes the reduced Hermite Normal Form decomposition using PyTorch. For a matrix A (m x n) + computes the matrices H (m x r), U (n x r) and V (r x n) such that V U = I_r A = H V H = A U + where r is the rank of A if reduced is True, and otherwise r is n. Args: A: (m x n) torch.Tensor (dtype=torch.long) + reduced: Reduce to rank if True. Returns: H: (m x r) Canonical Lower Triangular HNF @@ -118,14 +120,15 @@ def hnf_decomposition(A: Tensor) -> tuple[Tensor, Tensor, Tensor]: col += 1 - col_magnitudes = torch.sum(torch.abs(H), dim=0) - rank = torch.count_nonzero(col_magnitudes).item() + if reduced: + col_magnitudes = torch.sum(torch.abs(H), dim=0) + rank = torch.count_nonzero(col_magnitudes).item() - reduced_H = H[:, :rank] - reduced_U = U[:, :rank] - reduced_V = V[:rank, :] + H = H[:, :rank] + U = U[:, :rank] + V = V[:rank, :] - return reduced_H, reduced_U, reduced_V + return H, U, V def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: @@ -157,7 +160,7 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: m, n1 = S1.shape A = torch.cat([S1, S2], dim=1) - G, U, V = hnf_decomposition(A) + G, _, V = hnf_decomposition(A, True) K1 = V[:, :n1] K2 = V[:, n1:] @@ -180,9 +183,7 @@ def compute_lcm(S1, S2): # 1. Kernel Setup: [S1 | -S2] B = torch.cat([S1, -S2], dim=1) - - # 2. Decompose to find Kernel - H_B, U_B, _ = hnf_decomposition(B) + H_B, U_B, _ = hnf_decomposition(B, False) # 3. Find Zero Columns in H_B (Kernel basis) # Sum abs values down columns @@ -205,6 +206,6 @@ def compute_lcm(S1, S2): # 6. Canonicalize L # The generators might be redundant or non-square. # Run HNF one last time to get the unique square LCM matrix. 
- L, _, _ = hnf_decomposition(L_generators) + L, _, _ = hnf_decomposition(L_generators, False) return L[:, :m] diff --git a/tests/unit/sparse/test_linalg.py b/tests/unit/sparse/test_linalg.py index eacd79012..9e298160b 100644 --- a/tests/unit/sparse/test_linalg.py +++ b/tests/unit/sparse/test_linalg.py @@ -21,7 +21,7 @@ def test_hnf_decomposition(shape: tuple[int, int], max_rank: int): U = torch.randint(-10, 11, [shape[0], max_rank], dtype=torch.int64) V = torch.randint(-10, 11, [max_rank, shape[1]], dtype=torch.int64) A = U @ V - H, U, V = hnf_decomposition(A) + H, U, V = hnf_decomposition(A, True) rank = H.shape[1] From 45d044d141105360f0f13ee0aacc581948cc6fff Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 24 Nov 2025 15:34:16 +0100 Subject: [PATCH 21/42] Improve documentation of `compute_gcd` --- src/torchjd/sparse/_linalg.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 542415210..1fd27d236 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -133,10 +133,18 @@ def hnf_decomposition(A: Tensor, reduced: bool) -> tuple[Tensor, Tensor, Tensor] def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: """ - Computes the GCD and the projection factors. i.e. + Computes a GCD and the projection factors, i.e. S1 = G @ K1 S2 = G @ K2 - with G having minimal rank. + with G having minimal rank r. + + Args: + S1, S2: torch.Tensors (m x n1), (m x n2) + + Returns: + G: (m x r) The Matrix GCD (Canonical Base) + K1: (r x n1) Factors for S1 + K2: (r x n2) Factors for S2 Implementation logic: The concatenated matrix [S1 | S2] spans exactly the sum of the lattices generated by S1 and S2. @@ -147,16 +155,8 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: S2 = G @ V[:, m1:] This is the target common factorization. It is the greatest as the lattice spanned by G is the same as that spanned by [S1 | S2]. - - Args: - S1, S2: torch.Tensors (m x n1), (m x n2) - - Returns: - G: (m x r) The Matrix GCD (Canonical Base) - K1: (r x n1) Factors for S1 - K2: (r x n2) Factors for S2 """ - assert S1.shape[0] == S2.shape[0], "Virtual dimension mismatch" + assert S1.shape[0] == S2.shape[0] m, n1 = S1.shape A = torch.cat([S1, S2], dim=1) From 9c2d6e7321dca50393d044897f5714d1ecdf0b59 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 24 Nov 2025 15:37:48 +0100 Subject: [PATCH 22/42] Add `get_hermit_factor_rank` --- src/torchjd/sparse/_linalg.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 1fd27d236..38b2d1a2a 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -40,6 +40,14 @@ def extended_gcd(a: int, b: int) -> tuple[int, int, int]: return g, x - (b // a) * y, y +def _get_hermite_factor_rank(H: Tensor) -> int: + """ + Computes the rank of a hermit factor matrix. + """ + col_magnitudes = torch.sum(torch.abs(H), dim=0) + return torch.count_nonzero(col_magnitudes).item() + + def hnf_decomposition(A: Tensor, reduced: bool) -> tuple[Tensor, Tensor, Tensor]: """ Computes the reduced Hermite Normal Form decomposition using PyTorch. 
For a matrix A (m x n) @@ -121,8 +129,7 @@ def hnf_decomposition(A: Tensor, reduced: bool) -> tuple[Tensor, Tensor, Tensor] col += 1 if reduced: - col_magnitudes = torch.sum(torch.abs(H), dim=0) - rank = torch.count_nonzero(col_magnitudes).item() + rank = _get_hermite_factor_rank(H) H = H[:, :rank] U = U[:, :rank] From 73347c513a77614d91963580bd05ff933d971cb8 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 24 Nov 2025 15:44:01 +0100 Subject: [PATCH 23/42] Test `reduced=False` in `hnf_decomposition` --- src/torchjd/sparse/_linalg.py | 3 ++- tests/unit/sparse/test_linalg.py | 18 +++++++++++------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 38b2d1a2a..548b344b8 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -55,7 +55,8 @@ def hnf_decomposition(A: Tensor, reduced: bool) -> tuple[Tensor, Tensor, Tensor] V U = I_r A = H V H = A U - where r is the rank of A if reduced is True, and otherwise r is n. + where r is the rank of A if reduced is True, and otherwise r is n. In the later case, this also + satisfies U V = I. Args: A: (m x n) torch.Tensor (dtype=torch.long) diff --git a/tests/unit/sparse/test_linalg.py b/tests/unit/sparse/test_linalg.py index 9e298160b..0a407d4f2 100644 --- a/tests/unit/sparse/test_linalg.py +++ b/tests/unit/sparse/test_linalg.py @@ -15,28 +15,32 @@ ([7, 5], 5), ], ) -def test_hnf_decomposition(shape: tuple[int, int], max_rank: int): +@mark.parametrize("reduced", [True, False]) +def test_hnf_decomposition(shape: tuple[int, int], max_rank: int, reduced: bool): # Generate a matrix A of desired shape and rank max_rank with high probability and lower # otherwise. U = torch.randint(-10, 11, [shape[0], max_rank], dtype=torch.int64) V = torch.randint(-10, 11, [max_rank, shape[1]], dtype=torch.int64) A = U @ V - H, U, V = hnf_decomposition(A, True) + H, U, V = hnf_decomposition(A, reduced) - rank = H.shape[1] + r = H.shape[1] # Note that with these assert, the rank is typically correct as it is at most max_rank, which it # is with high probability, and we can reconstruct A=H @ V, so the rank of H is at least that of # A, similarly, the rank of H is at most that of A. - assert rank <= max_rank - assert torch.equal(V @ U, torch.eye(rank, dtype=torch.int64)) + if reduced: + assert r <= max_rank + else: + assert torch.equal(U @ V, torch.eye(r, dtype=torch.int64)) + assert torch.equal(V @ U, torch.eye(r, dtype=torch.int64)) assert torch.equal(H @ V, A) assert torch.equal(A @ U, H) # Check H is upper triangular - mask = torch.triu(torch.ones(shape[0], rank, dtype=torch.bool), diagonal=1) + mask = torch.triu(torch.ones(shape[0], r, dtype=torch.bool), diagonal=1) assert torch.all(H[mask] == 0).item() # Check pivots are positive - pivots = H.diag()[:rank] + pivots = H.diag()[:r] return torch.all(pivots > 0).item() From 997acf7c64e13da129af139f0d3675dead33b28d Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Mon, 24 Nov 2025 16:05:45 +0100 Subject: [PATCH 24/42] Improve (or is it fix?) implementation of `compute_lcm` as well as improve documentation. 
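A minimal sketch of the contract described below, on arbitrary 1x1 lattice bases chosen only so that the intersection lattice is easy to read off (illustrative values, not taken from this patch):

    import torch
    from torchjd.sparse._linalg import compute_lcm

    # Arbitrary example: the lattices 2Z and 3Z intersect in 6Z, so L is expected
    # to be a single column generating 6Z (i.e. equal to 6 up to sign).
    S1 = torch.tensor([[2]], dtype=torch.int64)
    S2 = torch.tensor([[3]], dtype=torch.int64)

    L, M1, M2 = compute_lcm(S1, S2)
    assert torch.equal(S1 @ M1, L)  # L = S1 @ M1
    assert torch.equal(S2 @ M2, L)  # L = S2 @ M2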
--- src/torchjd/sparse/_linalg.py | 63 +++++++++++++++++------------------ 1 file changed, 30 insertions(+), 33 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 548b344b8..784f53965 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -129,6 +129,7 @@ def hnf_decomposition(A: Tensor, reduced: bool) -> tuple[Tensor, Tensor, Tensor] col += 1 + # TODO: Should actually make 2 functions, one for full and one for reduced if reduced: rank = _get_hermite_factor_rank(H) @@ -176,44 +177,40 @@ def compute_gcd(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: return G, K1, K2 -def compute_lcm(S1, S2): +def compute_lcm(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: """ - Computes the Matrix LCM (L) and the Multiples (M1, M2), i.e. + Computes a LCM and the projection multipliers, i.e. L = S1 @ M1 = S2 @ M2 + with L having maximal rank r. + + Args: + S1, S2: torch.Tensors (m x n1), (m x n2) Returns: - L: (m x m) The Matrix LCM - M1: (n1 x m) Factor such that L = S1 @ M1 - M2: (n2 x m) Factor such that L = S2 @ M2 + L: (m x r) The Matrix LCM + M1: (n1 x r) Multiplier for S1 + M2: (n2 x r) Multiplier for S2 + + Implementation logic: + The lattice kernel of the concatenated matrix [S1 | -S2] is the set of all vectors + [u1.T | u2.T].T such that S1 @ u1 - S2 @ u2 = 0, or equivalently S1 @ u1 = S2 @ u2. + This means that the image of the components of the kernel through S1 and S2 respectively are the + same which is exactly the intersection of the lattices generated by S1 and S2. + The full HNF decomposition of [S1 | -S2] yields H, U, V such that + H = [S1 | -S2] @ U + If [S1 | -S2] has rank r', then every column of H after the first r' contain only zeros, and + therefore U[:, r':] spans the kernel of [S1 | -S2]. We have + S1 @ U[:n1, r':] = S2 @ U[n1:, r':] + which yields the desired decomposition with r=n1+n2-r'. """ - m = S1.shape[0] - n1 = S1.shape[1] + assert S1.shape[0] == S2.shape[0] + m, n1 = S1.shape - # 1. Kernel Setup: [S1 | -S2] B = torch.cat([S1, -S2], dim=1) - H_B, U_B, _ = hnf_decomposition(B, False) - - # 3. Find Zero Columns in H_B (Kernel basis) - # Sum abs values down columns - col_mags = torch.sum(torch.abs(H_B), dim=0) - zero_indices = torch.nonzero(col_mags == 0, as_tuple=True)[0] - - if len(zero_indices) == 0: - return torch.zeros((m, m), dtype=torch.long) - - # 4. Extract Kernel Basis - # U_B columns corresponding to H_B zeros are the kernel generators - kernel_basis = U_B[:, zero_indices] - - # 5. Map back to Image Space - # The kernel vector is [u; v]. We need u (top n1 rows). - # Intersection = S1 @ u - u_parts = kernel_basis[:n1, :] - L_generators = S1 @ u_parts - - # 6. Canonicalize L - # The generators might be redundant or non-square. - # Run HNF one last time to get the unique square LCM matrix. 
- L, _, _ = hnf_decomposition(L_generators, False) + H, U, _ = hnf_decomposition(B, False) - return L[:, :m] + rank = _get_hermite_factor_rank(H) + M1 = U[:n1, rank:] + M2 = U[n1:, rank:] + L = S1 @ M1 + return L, M1, M2 From 8e17a77509affbb086b43af576313b04214572b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 24 Nov 2025 18:37:41 +0100 Subject: [PATCH 25/42] Remove strides_v2 --- src/torchjd/sparse/_sparse_latticed_tensor.py | 32 ------------------- 1 file changed, 32 deletions(-) diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py index 8646a4a7f..f7afb77ee 100644 --- a/src/torchjd/sparse/_sparse_latticed_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -1,8 +1,6 @@ import itertools -import operator from collections.abc import Callable from functools import wraps -from itertools import accumulate from math import prod import torch @@ -152,36 +150,6 @@ def tensor_to_str(t: Tensor) -> str: print() -def strides_v2(p_dims: list[int], physical_shape: list[int]) -> list[int]: - """ - From a list of physical dimensions corresponding to a virtual dimension, and from the physical - shape, get the stride indicating how moving on each physical dimension makes you move on the - virtual dimension. - - Example: - Imagine a vector of size 3, and of value [1, 2, 3]. - Imagine a SLT t of shape [3, 3] using this vector as physical and using [[0, 0]] as v_to_ps. - t.to_dense() is [1, 0, 0, 0, 2, 0, 0, 0, 3] (it's the flattening of the diagonal matrix - [[1, 0, 0], [0, 2, 0], [0, 0, 3]]). - When you move by 1 on physical dimension 0, you move by 4 on virtual dimension 0, i.e. - strides_v2([0, 0], [3]) = 4 - In the 2D view, you'd move by 1 row (3 indices) and 1 column (1 index). - - Example: - strides_v2([0, 0, 1], [3,4]) # [16, 1] - Moving by 1 on physical dimension 0 makes you move by 16 on the virtual dimension. Moving by - 1 on physical dimension 1 makes you move by 1 on the virtual dimension. 
- """ - - strides_v1 = list(accumulate([1] + [physical_shape[d] for d in p_dims[:0:-1]], operator.mul))[ - ::-1 - ] - result = [0 for _ in range(len(physical_shape))] - for i, d in enumerate(p_dims): - result[d] += strides_v1[i] - return result - - def get_groupings(pshape: list[int], basis: Tensor) -> list[list[int]]: basis_time_pshape = basis * tensor(pshape, dtype=torch.int64) groups = {i: {i} for i, column in enumerate(basis.T)} From c4f7dfc0b90cf489cfc401328ef75df370a20358 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 24 Nov 2025 18:41:15 +0100 Subject: [PATCH 26/42] Add docstring to fix_dim_of_size_1 and fix_ungrouped_dims --- src/torchjd/sparse/_sparse_latticed_tensor.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py index f7afb77ee..f4101031f 100644 --- a/src/torchjd/sparse/_sparse_latticed_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -217,11 +217,17 @@ def get_full_source(source: list[int], destination: list[int], ndim: int) -> lis def fix_dim_of_size_1(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]: + """ + Removes physical dimensions of size one and returns the corresponding new physical and new basis + """ + is_of_size_1 = tensor([s == 1 for s in physical.shape], dtype=torch.bool) return physical.squeeze(), basis[:, ~is_of_size_1] def fix_ungrouped_dims(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor]: + """Squash together physical dimensions that can be squashed.""" + groups = get_groupings(list(physical.shape), basis) nphysical = physical.reshape([prod([physical.shape[dim] for dim in group]) for group in groups]) basis_mapping = torch.zeros(physical.ndim, nphysical.ndim, dtype=torch.int64) From 731de15e76eaf58c504dbd19871c8354c58432e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 24 Nov 2025 19:08:45 +0100 Subject: [PATCH 27/42] Remove unsquash_pdim * It was unused and I think it will be replaced by functions that find divisors of the basis --- .../sparse/_aten_function_overrides/shape.py | 55 ------------------- .../sparse/test_sparse_latticed_tensor.py | 38 ------------- 2 files changed, 93 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index c94cd3758..8ae071440 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -69,61 +69,6 @@ def infer_shape(shape: list[int], numel: int) -> list[int]: return [inferred if s == -1 else s for s in shape] -def unsquash_pdim( - physical: Tensor, basis: Tensor, pdim: int, new_pdim_shape: list[int] -) -> tuple[Tensor, Tensor]: - """ - EXAMPLE: - - physical = [ - [1, 2, 3, 4, 5, 6], - [7, 8, 9, 10, 11, 12], - [13, 14, 15, 16, 17, 18], - ] - basis = [ - [1, 1], - [0, 2], - ] - - dim = 1 - shape = [2, 3] - - new_physical = [[ - [1, 2, 3], - [4, 5, 6], - ], [ - [7, 8, 9], - [10, 11, 12], - ], [ - [13, 14, 15], - [16, 17, 18], - ]] - - new_basis = [ - [1, 3, 1], - [0, 6, 2] - """ - - # TODO: handle working with multiple dimensions at once - - old_shape = list(physical.shape) - new_shape = old_shape[:pdim] + new_pdim_shape + old_shape[pdim + 1 :] - new_physical = physical.reshape(new_shape) - - multipliers = tensor([prod(new_pdim_shape[i + 1 :]) for i in range(len(new_pdim_shape))]) - - new_basis = torch.concat( - [ - basis[:, :pdim], - torch.outer(basis[:, pdim], multipliers), - basis[:, pdim 
+ 1 :], - ], - dim=1, - ) - - return new_physical, new_basis - - @impl(aten._unsafe_view.default) def _unsafe_view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: return view_default( diff --git a/tests/unit/sparse/test_sparse_latticed_tensor.py b/tests/unit/sparse/test_sparse_latticed_tensor.py index 55b723c67..3b7367b33 100644 --- a/tests/unit/sparse/test_sparse_latticed_tensor.py +++ b/tests/unit/sparse/test_sparse_latticed_tensor.py @@ -10,7 +10,6 @@ _IN_PLACE_POINTWISE_FUNCTIONS, _POINTWISE_FUNCTIONS, ) -from torchjd.sparse._aten_function_overrides.shape import unsquash_pdim from torchjd.sparse._coalesce import fix_zero_basis_vectors from torchjd.sparse._sparse_latticed_tensor import ( SparseLatticedTensor, @@ -312,43 +311,6 @@ def test_fix_ungrouped_dims( assert torch.equal(fixed_basis, expected_basis) -@mark.parametrize( - [ - "physical_shape", - "basis", - "pdim", - "new_pdim_shape", - "expected_physical_shape", - "expected_basis", - ], - [ - ([4], tensor([[1], [2]]), 0, [4], [4], tensor([[1], [2]])), # trivial - ([4], tensor([[1], [2]]), 0, [2, 2], [2, 2], tensor([[2, 1], [4, 2]])), - ( - [3, 4, 5], - tensor([[1, 2, 0], [1, 0, 1], [0, 1, 1]]), - 1, - [2, 1, 1, 2], - [3, 2, 1, 1, 2, 5], - tensor([[1, 4, 4, 4, 2, 0], [1, 0, 0, 0, 0, 1], [0, 2, 2, 2, 1, 1]]), - ), - ], -) -def test_unsquash_pdim( - physical_shape: list[int], - basis: Tensor, - pdim: int, - new_pdim_shape: list[int], - expected_physical_shape: list[int], - expected_basis: Tensor, -): - physical = randn_(physical_shape) - new_physical, new_basis = unsquash_pdim(physical, basis, pdim, new_pdim_shape) - - assert list(new_physical.shape) == expected_physical_shape - assert torch.equal(new_basis, expected_basis) - - @mark.parametrize( [ "source", From db571115ab2cec0abc7462d2cf316a83023cea55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 24 Nov 2025 20:25:44 +0100 Subject: [PATCH 28/42] Remove debug_info --- src/torchjd/sparse/_aten_function_overrides/shape.py | 4 ++-- src/torchjd/sparse/_sparse_latticed_tensor.py | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 8ae071440..8a88888e5 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -124,8 +124,8 @@ def cat_default(tensors: list[Tensor], dim: int) -> Tensor: if any(not torch.equal(t.basis, ref_basis) for t in tensors_[1:]): raise NotImplementedError( "Override for aten.cat.default does not support SLTs that do not all have the same " - f"basis. Found the following tensors:\n{[t.debug_info() for t in tensors_]} and the " - f"following dim: {dim}." + f"basis. Found the following tensors:\n{[repr(t) for t in tensors_]} and the following " + f"dim: {dim}." 
) # We need to try to find the (pretty sure it either does not exist or is unique) physical diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py index f4101031f..873e74ac9 100644 --- a/src/torchjd/sparse/_sparse_latticed_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -105,10 +105,6 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): def __repr__(self, *, tensor_contents=None) -> str: return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis})" - def debug_info(self) -> str: - info = f"vshape: {self.shape}\n" f"pshape: {self.physical.shape}\n" f"basis: {self.basis}\n" - return info - @classmethod def implements(cls, torch_function): """Register a torch function override for ScalarTensor""" From 29c4448ebdfe56cb1e143bfc4c4652b78d99d6b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 25 Nov 2025 17:56:28 +0100 Subject: [PATCH 29/42] WIP add offset and shape (still need to update tests, view, einsum functions and concatenate) --- src/torchjd/autogram/_engine.py | 2 +- .../_aten_function_overrides/backward.py | 12 +++- .../sparse/_aten_function_overrides/einsum.py | 12 ++-- .../_aten_function_overrides/pointwise.py | 6 +- .../sparse/_aten_function_overrides/shape.py | 40 +++++++---- src/torchjd/sparse/_sparse_latticed_tensor.py | 70 ++++++++++++++----- 6 files changed, 100 insertions(+), 42 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index fac380b48..95172522c 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -177,7 +177,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: output_dims = list(range(output.ndim)) identity = torch.eye(output.ndim, dtype=torch.int64) basis = torch.concatenate([identity, identity], dim=0) - jac_output = make_slt(torch.ones_like(output), basis) + jac_output = make_slt(torch.ones_like(output), basis, None, None) vmapped_diff = differentiation for _ in output_dims: diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py index f8d9716bc..7aa7c9dac 100644 --- a/src/torchjd/sparse/_aten_function_overrides/backward.py +++ b/src/torchjd/sparse/_aten_function_overrides/backward.py @@ -10,7 +10,9 @@ def threshold_backward_default( ) -> SparseLatticedTensor: new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) - return SparseLatticedTensor(new_physical, grad_output.basis) + return SparseLatticedTensor( + new_physical, grad_output.basis, grad_output.offset, grad_output.size + ) @impl(aten.hardtanh_backward.default) @@ -24,7 +26,9 @@ def hardtanh_backward_default( raise NotImplementedError() new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) - return SparseLatticedTensor(new_physical, grad_output.basis) + return SparseLatticedTensor( + new_physical, grad_output.basis, grad_output.offset, grad_output.size + ) @impl(aten.hardswish_backward.default) @@ -33,4 +37,6 @@ def hardswish_backward_default(grad_output: SparseLatticedTensor, self: Tensor): raise NotImplementedError() new_physical = aten.hardswish_backward.default(grad_output.physical, self) - return SparseLatticedTensor(new_physical, grad_output.basis) + return SparseLatticedTensor( + new_physical, grad_output.basis, grad_output.offset, grad_output.size + ) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py 
b/src/torchjd/sparse/_aten_function_overrides/einsum.py index 376ec2b70..fead4f4db 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -165,7 +165,7 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: @impl(aten.div.Tensor) def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) - t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.basis) + t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.basis, t2_.offset, t2_.size) all_dims = list(range(t1_.ndim)) return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) @@ -177,7 +177,7 @@ def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) new_physical = aten.mul.Scalar(t.physical, scalar) - return SparseLatticedTensor(new_physical, t.basis) + return SparseLatticedTensor(new_physical, t.basis, t.offset, t.size) @impl(aten.add.Tensor) @@ -186,9 +186,13 @@ def add_Tensor( ) -> SparseLatticedTensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) - if torch.equal(t1_.basis, t2_.basis): + if ( + torch.equal(t1_.basis, t2_.basis) + and torch.equal(t1_.offset, t2_.offset) + and torch.equal(t1_.size, t2_.size) + ): new_physical = t1_.physical + t2_.physical * alpha - return SparseLatticedTensor(new_physical, t1_.basis) + return SparseLatticedTensor(new_physical, t1_.basis, t1_.offset, t1_.size) else: raise NotImplementedError() diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py index 51dbd020b..e9babf5de 100644 --- a/src/torchjd/sparse/_aten_function_overrides/pointwise.py +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -71,7 +71,7 @@ def _override_pointwise(op): @impl(op) def func_(t: SparseLatticedTensor) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) - return SparseLatticedTensor(op(t.physical), t.basis) + return SparseLatticedTensor(op(t.physical), t.basis, t.offset, t.shape) return func_ @@ -100,7 +100,7 @@ def pow_Tensor_Scalar(t: SparseLatticedTensor, exponent: float) -> SparseLattice return aten.pow.Tensor_Scalar(t.to_dense(), exponent) new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) - return SparseLatticedTensor(new_physical, t.basis) + return SparseLatticedTensor(new_physical, t.basis, t.offset, t.shape) # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. 
@@ -122,4 +122,4 @@ def div_Scalar(t: SparseLatticedTensor, divisor: float) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) new_physical = aten.div.Scalar(t.physical, divisor) - return SparseLatticedTensor(new_physical, t.basis) + return SparseLatticedTensor(new_physical, t.basis, t.offset, t.shape) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 8a88888e5..d690b4a6d 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -4,7 +4,7 @@ from typing import cast import torch -from torch import Tensor, arange, tensor +from torch import Tensor, arange, cat, tensor from torch.ops import aten # type: ignore from torchjd.sparse._sparse_latticed_tensor import ( @@ -84,15 +84,17 @@ def unsqueeze_default(t: SparseLatticedTensor, dim: int) -> SparseLatticedTensor if dim < 0: dim = t.ndim + dim + 1 - new_basis = torch.concatenate( - [t.basis[:dim], torch.zeros(1, t.basis.shape[1], dtype=torch.int64), t.basis[dim:]] - ) - return SparseLatticedTensor(t.physical, new_basis) + pdims = t.basis.shape[1] + new_basis = cat([t.basis[:dim], torch.zeros(1, pdims, dtype=torch.int64), t.basis[dim:]]) + new_offset = cat([t.offset[:dim], torch.zeros(1, dtype=torch.int64), t.offset[dim:]]) + new_size = cat([t.size[:dim], torch.zeros(1, dtype=torch.int64), t.size[dim:]]) + return SparseLatticedTensor(t.physical, new_basis, new_offset, new_size) @impl(aten.squeeze.dims) def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tensor: assert isinstance(t, SparseLatticedTensor) + # TODO: verify that the specified dimensions are of size 1. if dims is None: excluded = set(range(t.ndim)) @@ -103,13 +105,17 @@ def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tenso is_row_kept = [i not in excluded for i in range(t.ndim)] new_basis = t.basis[is_row_kept] - return to_most_efficient_tensor(t.physical, new_basis) + new_offset = t.offset[is_row_kept] + new_size = t.size[is_row_kept] + return to_most_efficient_tensor(t.physical, new_basis, new_offset, new_size) @impl(aten.permute.default) def permute_default(t: SparseLatticedTensor, dims: list[int]) -> SparseLatticedTensor: - new_basis = t.basis[torch.tensor(dims)] - return SparseLatticedTensor(t.physical, new_basis) + new_basis = t.basis[dims] + new_offset = t.offset[dims] + new_size = t.size[dims] + return SparseLatticedTensor(t.physical, new_basis, new_offset, new_size) @impl(aten.cat.default) @@ -173,6 +179,7 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT # Try to expand each dimension to its new size new_physical = t.physical new_basis = t.basis + new_sizes = t.size for d, (v, orig_size, new_size) in enumerate(zip(t.basis, t.shape, sizes, strict=True)): if v.sum() > 0 and orig_size != new_size and new_size != -1: raise ValueError( @@ -190,8 +197,9 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT new_basis_vector = torch.zeros(t.ndim, 1, dtype=torch.int64) new_basis_vector[d, 0] = 1 new_basis = torch.cat([new_basis, new_basis_vector], dim=1) + new_sizes[d] = new_size - return SparseLatticedTensor(new_physical, new_basis) + return SparseLatticedTensor(new_physical, new_basis, t.offset, new_sizes) @impl(aten.broadcast_tensors.default) @@ -224,11 +232,13 @@ def broadcast_tensors_default(tensors: list[Tensor]) -> tuple[Tensor, Tensor]: @impl(aten.transpose.int) def transpose_int(t: 
SparseLatticedTensor, dim0: int, dim1: int) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) - return SparseLatticedTensor(t.physical, _swap_rows(t.basis, dim0, dim1)) + new_index = arange(t.basis.shape[0]) + new_index[dim0] = dim1 + new_index[dim1] = dim0 + + new_basis = t.basis[new_index] + new_offset = t.offset[new_index] + new_shape = list(tensor(t.shape, dtype=torch.int64)[new_index]) -def _swap_rows(matrix: Tensor, c0: int, c1: int) -> Tensor: - index = arange(matrix.shape[0]) - index[c0] = c1 - index[c1] = c0 - return matrix[index] + return SparseLatticedTensor(t.physical, new_basis, new_offset, new_shape) diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py index 873e74ac9..fe26853e3 100644 --- a/src/torchjd/sparse/_sparse_latticed_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -12,7 +12,13 @@ class SparseLatticedTensor(Tensor): _HANDLED_FUNCTIONS = dict[Callable, Callable]() @staticmethod - def __new__(cls, physical: Tensor, basis: Tensor): + def __new__( + cls, + physical: Tensor, + basis: Tensor, + offset: Tensor | None = None, + size: list[int] | tuple[int, ...] | torch.Size | Tensor | None = None, + ): assert basis.dtype == torch.int64 # Note [Passing requires_grad=true tensors to subclasses] @@ -25,13 +31,21 @@ def __new__(cls, physical: Tensor, basis: Tensor): # (which is bad!) assert not physical.requires_grad or not torch.is_grad_enabled() - pshape = tensor(physical.shape, dtype=torch.int64) - vshape = basis @ (pshape - 1) + 1 + if size is None: + pshape = tensor(physical.shape, dtype=torch.int64) + size = basis @ (pshape - 1) + 1 + return Tensor._make_wrapper_subclass( - cls, tuple(vshape.tolist()), dtype=physical.dtype, device=physical.device + cls, list(size), dtype=physical.dtype, device=physical.device ) - def __init__(self, physical: Tensor, basis: Tensor): + def __init__( + self, + physical: Tensor, + basis: Tensor, + offset: Tensor | None, + size: list[int] | tuple[int, ...] | torch.Size | Tensor | None, + ): """ This constructor is made for specifying physical and basis exactly. It should not modify it. @@ -42,9 +56,18 @@ def __init__(self, physical: Tensor, basis: Tensor): :param physical: The dense tensor holding the actual data. :param basis: Integer (int64) tensor of shape [virtual_ndim, physical_ndim], representing the linear transformation between an index in the physical tensor and the corresponding - index in the virtual tensor, i.e. v_index = basis @ p_index. + index in the virtual tensor, i.e. v_index = basis @ p_index + offset. + :param offset: Offset for the virtual index, i.e. v_index = basis @ p_index + offset. + :param size: Size of the sparse tensor. If not provided, the size will be inferred as the + minimum size big enough to hold all non-zero elements. + + # TODO: make a nicer interface where it's possible to provide lists or sizes instead of + always having to provide int tensors """ + if offset is None: + offset = torch.zeros(len(self.shape)) + if any(s == 1 for s in physical.shape): raise ValueError( "physical must not contain any dimension of size 1. 
Found physical.shape=" @@ -71,6 +94,8 @@ def __init__(self, physical: Tensor, basis: Tensor): self.physical = physical self.basis = basis + self.offset = offset + self.size = tensor(size, dtype=torch.int64) def to_dense( self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None @@ -85,7 +110,7 @@ def to_dense( p_indices_grid = stack(meshgrid(*p_index_ranges, indexing="ij")) # addmm_cuda not implemented for Long tensors => gotta have these tensors on cpu - v_indices_grid = tensordot(self.basis, p_indices_grid, dims=1) + v_indices_grid = tensordot(self.basis, p_indices_grid, dims=1) + self.offset res = zeros(self.shape, device=self.device, dtype=self.dtype) res[tuple(v_indices_grid)] = self.physical return res @@ -103,7 +128,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return func(*unwrapped_args, **unwrapped_kwargs) def __repr__(self, *, tensor_contents=None) -> str: - return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis})" + return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis}, offset={self.offset}, size={self.size})" @classmethod def implements(cls, torch_function): @@ -122,9 +147,9 @@ def decorator(func): def print_fallback(func, args, kwargs) -> None: def tensor_to_str(t: Tensor) -> str: - result = f"{t.__class__.__name__} - vshape: {t.shape}" + result = f"{t.__class__.__name__} - shape: {t.shape}" if isinstance(t, SparseLatticedTensor): - result += f" - pshape: {t.physical.shape} - basis: {t.basis}" + result += f" - pshape: {t.physical.shape} - basis: {t.basis} - offset: {t.offset}" return result @@ -167,18 +192,26 @@ def to_sparse_latticed_tensor(t: Tensor) -> SparseLatticedTensor: if isinstance(t, SparseLatticedTensor): return t else: - return make_slt(physical=t, basis=torch.eye(t.ndim, dtype=torch.int64)) + return make_slt(t, torch.eye(t.ndim, dtype=torch.int64), None, None) -def to_most_efficient_tensor(physical: Tensor, basis: Tensor) -> Tensor: +def to_most_efficient_tensor( + physical: Tensor, + basis: Tensor, + offset: Tensor | None, + size: list[int] | tuple[int, ...] | torch.Size | Tensor | None, +) -> Tensor: physical, basis = fix_dim_of_size_1(physical, basis) physical, basis = fix_ungrouped_dims(physical, basis) if (basis.sum(dim=0) == 1).all(): + print("Turning supposedly dense SLT into Tensor. This can be bugged and slow.") + # TODO: this condition is broken if basis is allowed to have negative values. It also only + # works when size is the default and offset is 0. # TODO: this can be done more efficiently (without even creating the SLT) - return SparseLatticedTensor(physical, basis).to_dense() + return SparseLatticedTensor(physical, basis, offset, size).to_dense() else: - return SparseLatticedTensor(physical, basis) + return SparseLatticedTensor(physical, basis, offset, size) def unwrap_to_dense(t: Tensor): @@ -234,9 +267,14 @@ def fix_ungrouped_dims(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor] return nphysical, new_basis -def make_slt(physical: Tensor, basis: Tensor) -> SparseLatticedTensor: +def make_slt( + physical: Tensor, + basis: Tensor, + offset: Tensor | None, + size: list[int] | tuple[int, ...] 
| torch.Size | Tensor | None, +) -> SparseLatticedTensor: """Fix physical and basis and create a SparseLatticedTensor with them.""" physical, basis = fix_dim_of_size_1(physical, basis) physical, basis = fix_ungrouped_dims(physical, basis) - return SparseLatticedTensor(physical, basis) + return SparseLatticedTensor(physical, basis, offset, size) From 112931f64f13092f7f842f1da9a4fea665a5b7f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 25 Nov 2025 20:07:18 +0100 Subject: [PATCH 30/42] Add default dim=0 in cat_default * Otherwise when specifying dim=0 the dispatcher calls the function without dim arg (this is very weird but it seems that this is necessary) --- src/torchjd/sparse/_aten_function_overrides/shape.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 8a88888e5..b5e793ee0 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -113,7 +113,7 @@ def permute_default(t: SparseLatticedTensor, dims: list[int]) -> SparseLatticedT @impl(aten.cat.default) -def cat_default(tensors: list[Tensor], dim: int) -> Tensor: +def cat_default(tensors: list[Tensor], dim: int = 0) -> Tensor: if any(not isinstance(t, SparseLatticedTensor) for t in tensors): print_fallback(aten.cat.default, (tensors, dim), {}) return aten.cat.default([unwrap_to_dense(t) for t in tensors]) From 840d035d68ca9fca0f3f623cdfafa3d5077b704b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Tue, 25 Nov 2025 20:07:50 +0100 Subject: [PATCH 31/42] Fix concat for cases where it has to densify --- .../sparse/_aten_function_overrides/shape.py | 11 +++++++++++ tests/unit/sparse/test_sparse_latticed_tensor.py | 14 ++++++++++---- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index b5e793ee0..e972d767d 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -127,6 +127,17 @@ def cat_default(tensors: list[Tensor], dim: int = 0) -> Tensor: f"basis. Found the following tensors:\n{[repr(t) for t in tensors_]} and the following " f"dim: {dim}." ) + if any(t.physical.shape != ref_tensor.physical.shape for t in tensors_[1:]): + # This can happen in the following example: + # t1 = SLT([1 2 3], [[2]]) + # t2 = SLT([4 5 6 7], [[2]]) + # The expected result would be 1 0 2 0 3 4 0 5 0 6 0 7, but this is not representable + # efficiently as an SLT (because there is no 0 between 3 and 4, and both physicals have a + # different shape so we can't just stack them). + + # TODO: Maybe a partial densify is possible rather than a full densify. + print_fallback(aten.cat.default, (tensors, dim), {}) + return aten.cat.default([unwrap_to_dense(t) for t in tensors]) # We need to try to find the (pretty sure it either does not exist or is unique) physical # dimension that makes us only move on virtual dimension dim. 
It also needs to be such that diff --git a/tests/unit/sparse/test_sparse_latticed_tensor.py b/tests/unit/sparse/test_sparse_latticed_tensor.py index 3b7367b33..a9900a4b4 100644 --- a/tests/unit/sparse/test_sparse_latticed_tensor.py +++ b/tests/unit/sparse/test_sparse_latticed_tensor.py @@ -332,21 +332,27 @@ def test_get_column_indices(source: list[int], destination: list[int], ndim: int @mark.parametrize( - ["slt_args", "dim"], + ["slt_args", "dim", "expected_densify"], [ - ([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1), - ([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1), + ([([3], tensor([[1], [1]])), ([3], tensor([[1], [1]]))], 1, False), + ([([3], tensor([[2]])), ([4], tensor([[2]]))], 0, True), + ([([3, 2], tensor([[1, 0], [1, 3]])), ([3, 2], tensor([[1, 0], [1, 3]]))], 1, False), ], ) def test_concatenate( slt_args: list[tuple[list[int], Tensor]], dim: int, + expected_densify: bool, ): tensors = [SparseLatticedTensor(randn_(pshape), basis) for pshape, basis in slt_args] res = aten.cat.default(tensors, dim) expected = aten.cat.default([t.to_dense() for t in tensors], dim) - assert isinstance(res, SparseLatticedTensor) + if expected_densify: + assert not isinstance(res, SparseLatticedTensor) + else: + assert isinstance(res, SparseLatticedTensor) + assert torch.all(torch.eq(res.to_dense(), expected)) From 298daaba53664a06589dd39266d8111b645244d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 26 Nov 2025 16:17:56 +0100 Subject: [PATCH 32/42] Fix test_hnf_decomposition --- tests/unit/sparse/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sparse/test_linalg.py b/tests/unit/sparse/test_linalg.py index 0a407d4f2..6826d5874 100644 --- a/tests/unit/sparse/test_linalg.py +++ b/tests/unit/sparse/test_linalg.py @@ -43,4 +43,4 @@ def test_hnf_decomposition(shape: tuple[int, int], max_rank: int, reduced: bool) # Check pivots are positive pivots = H.diag()[:r] - return torch.all(pivots > 0).item() + assert torch.all(pivots > 0).item() From 7fd6766a90929148cacc59c3d17c1f06c63d9e80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 26 Nov 2025 16:35:38 +0100 Subject: [PATCH 33/42] Fix comment about lower triangular check and improve code --- tests/unit/sparse/test_linalg.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/unit/sparse/test_linalg.py b/tests/unit/sparse/test_linalg.py index 6826d5874..604baa4f1 100644 --- a/tests/unit/sparse/test_linalg.py +++ b/tests/unit/sparse/test_linalg.py @@ -37,9 +37,8 @@ def test_hnf_decomposition(shape: tuple[int, int], max_rank: int, reduced: bool) assert torch.equal(H @ V, A) assert torch.equal(A @ U, H) - # Check H is upper triangular - mask = torch.triu(torch.ones(shape[0], r, dtype=torch.bool), diagonal=1) - assert torch.all(H[mask] == 0).item() + # Check H is lower triangular (its upper triangle must be zero) + assert torch.equal(torch.triu(H, diagonal=1), torch.zeros_like(H)) # Check pivots are positive pivots = H.diag()[:r] From c65069c7f27f66fd2c69efbf2001c79f3f61d93e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 26 Nov 2025 16:46:07 +0100 Subject: [PATCH 34/42] Remove check that pivots are positive (they aren't) --- tests/unit/sparse/test_linalg.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/unit/sparse/test_linalg.py b/tests/unit/sparse/test_linalg.py index 604baa4f1..939d625d5 100644 --- a/tests/unit/sparse/test_linalg.py +++ 
b/tests/unit/sparse/test_linalg.py @@ -39,7 +39,3 @@ def test_hnf_decomposition(shape: tuple[int, int], max_rank: int, reduced: bool) # Check H is lower triangular (its upper triangle must be zero) assert torch.equal(torch.triu(H, diagonal=1), torch.zeros_like(H)) - - # Check pivots are positive - pivots = H.diag()[:r] - assert torch.all(pivots > 0).item() From e3b687bcdbf8fe5ddac07829e2641e4e6cf242fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 26 Nov 2025 17:52:42 +0100 Subject: [PATCH 35/42] Add test_compute_lcm --- tests/unit/sparse/test_linalg.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/unit/sparse/test_linalg.py b/tests/unit/sparse/test_linalg.py index 939d625d5..0dac10b55 100644 --- a/tests/unit/sparse/test_linalg.py +++ b/tests/unit/sparse/test_linalg.py @@ -1,7 +1,8 @@ import torch from pytest import mark +from torch import Tensor, tensor -from torchjd.sparse._linalg import hnf_decomposition +from torchjd.sparse._linalg import compute_lcm, hnf_decomposition @mark.parametrize( @@ -39,3 +40,25 @@ def test_hnf_decomposition(shape: tuple[int, int], max_rank: int, reduced: bool) # Check H is lower triangular (its upper triangle must be zero) assert torch.equal(torch.triu(H, diagonal=1), torch.zeros_like(H)) + + +@mark.parametrize( + ["S1", "S2"], + [ + (tensor([[8]]), tensor([[12]])), + (tensor([[8, 2]]), tensor([[12, 3]])), + (tensor([[8], [2]]), tensor([[12], [3]])), + (tensor([[8, 5]]), tensor([[12, 9]])), + (tensor([[8, 6], [4, 2]]), tensor([[16, 4], [2, 2]])), + ], +) +def test_compute_lcm(S1: Tensor, S2: Tensor): + L, M1, M2 = compute_lcm(S1, S2) + + print() + print(L) + print(M1) + print(M2) + + assert torch.equal(S1 @ M1, L) + assert torch.equal(S2 @ M2, L) From 153e3f8768205bc316e0def1075eac3bba0ed7e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 26 Nov 2025 17:52:59 +0100 Subject: [PATCH 36/42] Fix compute_lcm (no idea what i'm doing but it seems to work) --- src/torchjd/sparse/_linalg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 784f53965..56378e24e 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -210,7 +210,7 @@ def compute_lcm(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: H, U, _ = hnf_decomposition(B, False) rank = _get_hermite_factor_rank(H) - M1 = U[:n1, rank:] - M2 = U[n1:, rank:] + M2 = U[n1:, -rank:] + M1 = U[:n1, -rank:] L = S1 @ M1 return L, M1, M2 From 6ada56495e03a434f43b7acdd60627bc8453780e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 3 Dec 2025 15:55:21 +0100 Subject: [PATCH 37/42] Finish switching to offset and shape --- .../_aten_function_overrides/backward.py | 6 +- .../sparse/_aten_function_overrides/einsum.py | 8 +- .../sparse/_aten_function_overrides/shape.py | 68 +++------------ src/torchjd/sparse/_sparse_latticed_tensor.py | 86 ++++++++++++++++--- .../sparse/test_sparse_latticed_tensor.py | 27 +++--- 5 files changed, 111 insertions(+), 84 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py index 7aa7c9dac..f1c3ceb38 100644 --- a/src/torchjd/sparse/_aten_function_overrides/backward.py +++ b/src/torchjd/sparse/_aten_function_overrides/backward.py @@ -11,7 +11,7 @@ def threshold_backward_default( new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) return 
SparseLatticedTensor( - new_physical, grad_output.basis, grad_output.offset, grad_output.size + new_physical, grad_output.basis, grad_output.offset, grad_output.shape_t ) @@ -27,7 +27,7 @@ def hardtanh_backward_default( new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) return SparseLatticedTensor( - new_physical, grad_output.basis, grad_output.offset, grad_output.size + new_physical, grad_output.basis, grad_output.offset, grad_output.shape_t ) @@ -38,5 +38,5 @@ def hardswish_backward_default(grad_output: SparseLatticedTensor, self: Tensor): new_physical = aten.hardswish_backward.default(grad_output.physical, self) return SparseLatticedTensor( - new_physical, grad_output.basis, grad_output.offset, grad_output.size + new_physical, grad_output.basis, grad_output.offset, grad_output.shape_t ) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index fead4f4db..a9691faea 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -165,7 +165,7 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: @impl(aten.div.Tensor) def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) - t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.basis, t2_.offset, t2_.size) + t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.basis, t2_.offset, t2_.shape_t) all_dims = list(range(t1_.ndim)) return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) @@ -177,7 +177,7 @@ def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) new_physical = aten.mul.Scalar(t.physical, scalar) - return SparseLatticedTensor(new_physical, t.basis, t.offset, t.size) + return SparseLatticedTensor(new_physical, t.basis, t.offset, t.shape_t) @impl(aten.add.Tensor) @@ -189,10 +189,10 @@ def add_Tensor( if ( torch.equal(t1_.basis, t2_.basis) and torch.equal(t1_.offset, t2_.offset) - and torch.equal(t1_.size, t2_.size) + and torch.equal(t1_.shape_t, t2_.shape_t) ): new_physical = t1_.physical + t2_.physical * alpha - return SparseLatticedTensor(new_physical, t1_.basis, t1_.offset, t1_.size) + return SparseLatticedTensor(new_physical, t1_.basis, t1_.offset, t1_.shape_t) else: raise NotImplementedError() diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 9ad509a2f..7cda9a57b 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -1,7 +1,6 @@ import operator from itertools import accumulate from math import prod -from typing import cast import torch from torch import Tensor, arange, cat, tensor @@ -41,6 +40,9 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: assert isinstance(t, SparseLatticedTensor) + if not torch.equal(t.padding, torch.zeros_like(t.padding)): + raise NotImplementedError() + shape = infer_shape(shape, t.numel()) if prod(shape) != t.numel(): @@ -51,7 +53,9 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: c = _reverse_cumulative_product(vshape) c_prime = _reverse_cumulative_product(shape) new_basis = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1) - return to_most_efficient_tensor(t.physical, new_basis) + + new_offset = torch.zeros(len(shape), dtype=torch.int64) + return 
to_most_efficient_tensor(t.physical, new_basis, new_offset, shape) def _reverse_cumulative_product(values: list[int]) -> Tensor: @@ -87,7 +91,7 @@ def unsqueeze_default(t: SparseLatticedTensor, dim: int) -> SparseLatticedTensor pdims = t.basis.shape[1] new_basis = cat([t.basis[:dim], torch.zeros(1, pdims, dtype=torch.int64), t.basis[dim:]]) new_offset = cat([t.offset[:dim], torch.zeros(1, dtype=torch.int64), t.offset[dim:]]) - new_size = cat([t.size[:dim], torch.zeros(1, dtype=torch.int64), t.size[dim:]]) + new_size = cat([t.shape_t[:dim], torch.ones(1, dtype=torch.int64), t.shape_t[dim:]]) return SparseLatticedTensor(t.physical, new_basis, new_offset, new_size) @@ -106,7 +110,7 @@ def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tenso is_row_kept = [i not in excluded for i in range(t.ndim)] new_basis = t.basis[is_row_kept] new_offset = t.offset[is_row_kept] - new_size = t.size[is_row_kept] + new_size = t.shape_t[is_row_kept] return to_most_efficient_tensor(t.physical, new_basis, new_offset, new_size) @@ -114,7 +118,7 @@ def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tenso def permute_default(t: SparseLatticedTensor, dims: list[int]) -> SparseLatticedTensor: new_basis = t.basis[dims] new_offset = t.offset[dims] - new_size = t.size[dims] + new_size = t.shape_t[dims] return SparseLatticedTensor(t.physical, new_basis, new_offset, new_size) @@ -124,56 +128,10 @@ def cat_default(tensors: list[Tensor], dim: int = 0) -> Tensor: print_fallback(aten.cat.default, (tensors, dim), {}) return aten.cat.default([unwrap_to_dense(t) for t in tensors]) - tensors_ = [cast(SparseLatticedTensor, t) for t in tensors] - ref_tensor = tensors_[0] - ref_basis = ref_tensor.basis - if any(not torch.equal(t.basis, ref_basis) for t in tensors_[1:]): - raise NotImplementedError( - "Override for aten.cat.default does not support SLTs that do not all have the same " - f"basis. Found the following tensors:\n{[repr(t) for t in tensors_]} and the following " - f"dim: {dim}." - ) - if any(t.physical.shape != ref_tensor.physical.shape for t in tensors_[1:]): - # This can happen in the following example: - # t1 = SLT([1 2 3], [[2]]) - # t2 = SLT([4 5 6 7], [[2]]) - # The expected result would be 1 0 2 0 3 4 0 5 0 6 0 7, but this is not representable - # efficiently as an SLT (because there is no 0 between 3 and 4, and both physicals have a - # different shape so we can't just stack them). - - # TODO: Maybe a partial densify is possible rather than a full densify. - print_fallback(aten.cat.default, (tensors, dim), {}) - return aten.cat.default([unwrap_to_dense(t) for t in tensors]) - - # We need to try to find the (pretty sure it either does not exist or is unique) physical - # dimension that makes us only move on virtual dimension dim. It also needs to be such that - # traversing it entirely brings us exactly to the end of virtual dimension dim. - - ref_virtual_dim_size = ref_tensor.shape[dim] - indices = torch.argwhere( - torch.eq(ref_basis[dim] * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) - & torch.eq(ref_basis.sum(dim=0) * tensor(ref_tensor.physical.shape), ref_virtual_dim_size) - ) - assert len(indices) <= 1 - - if len(indices) == 0: - # Add a physical dimension pdim on which we can concatenate the physicals such that this - # translates into a concatenation of the virtuals on virtual dimension dim. 
- - pdim = ref_tensor.physical.ndim - physicals = [t.physical.unsqueeze(-1) for t in tensors_] - new_basis_vector = torch.zeros(ref_tensor.ndim, 1, dtype=torch.int64) - new_basis_vector[dim, 0] = ref_virtual_dim_size - new_basis = torch.concatenate([ref_tensor.basis, new_basis_vector], dim=1) - else: - # Such a physical dimension already exists. Note that an alternative implementation would be - # to simply always add the physical dimension, and squash it if it ends up being not needed. - physicals = [t.physical for t in tensors_] - pdim = cast(int, indices[0, 0].item()) - new_basis = ref_tensor.basis + print_fallback(aten.cat.default, (tensors, dim), {}) + return aten.cat.default([unwrap_to_dense(t) for t in tensors]) - new_physical = aten.cat.default(physicals, dim=pdim) - return SparseLatticedTensor(new_physical, new_basis) + # TODO: add implementation based on adding some margin to tensors and summing them @impl(aten.expand.default) @@ -190,7 +148,7 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT # Try to expand each dimension to its new size new_physical = t.physical new_basis = t.basis - new_sizes = t.size + new_sizes = t.shape_t for d, (v, orig_size, new_size) in enumerate(zip(t.basis, t.shape, sizes, strict=True)): if v.sum() > 0 and orig_size != new_size and new_size != -1: raise ValueError( diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py index fe26853e3..e4a7fc87c 100644 --- a/src/torchjd/sparse/_sparse_latticed_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -17,7 +17,7 @@ def __new__( physical: Tensor, basis: Tensor, offset: Tensor | None = None, - size: list[int] | tuple[int, ...] | torch.Size | Tensor | None = None, + shape: list[int] | tuple[int, ...] | torch.Size | Tensor | None = None, ): assert basis.dtype == torch.int64 @@ -31,12 +31,12 @@ def __new__( # (which is bad!) assert not physical.requires_grad or not torch.is_grad_enabled() - if size is None: + if shape is None: pshape = tensor(physical.shape, dtype=torch.int64) - size = basis @ (pshape - 1) + 1 + shape = basis @ (pshape - 1) + 1 return Tensor._make_wrapper_subclass( - cls, list(size), dtype=physical.dtype, device=physical.device + cls, list(shape), dtype=physical.dtype, device=physical.device ) def __init__( @@ -44,7 +44,7 @@ def __init__( physical: Tensor, basis: Tensor, offset: Tensor | None, - size: list[int] | tuple[int, ...] | torch.Size | Tensor | None, + shape: list[int] | tuple[int, ...] | torch.Size | Tensor | None, ): """ This constructor is made for specifying physical and basis exactly. It should not modify @@ -58,7 +58,7 @@ def __init__( the linear transformation between an index in the physical tensor and the corresponding index in the virtual tensor, i.e. v_index = basis @ p_index + offset. :param offset: Offset for the virtual index, i.e. v_index = basis @ p_index + offset. - :param size: Size of the sparse tensor. If not provided, the size will be inferred as the + :param shape: Size of the sparse tensor. If not provided, the size will be inferred as the minimum size big enough to hold all non-zero elements. 
# TODO: make a nicer interface where it's possible to provide lists or sizes instead of @@ -66,7 +66,7 @@ def __init__( """ if offset is None: - offset = torch.zeros(len(self.shape)) + offset = torch.zeros(len(self.shape), dtype=torch.int64) if any(s == 1 for s in physical.shape): raise ValueError( @@ -95,7 +95,16 @@ def __init__( self.physical = physical self.basis = basis self.offset = offset - self.size = tensor(size, dtype=torch.int64) + + if shape is None: + pshape = tensor(physical.shape, dtype=torch.int64) + shape = basis @ (pshape - 1) + 1 + if isinstance(shape, torch.Tensor): + self.shape_t = shape + else: + self.shape_t = tensor(shape, dtype=torch.int64) + + self.pshape_t = tensor(physical.shape, dtype=torch.int64) def to_dense( self, dtype: torch.dtype | None = None, *, masked_grad: bool | None = None @@ -110,7 +119,9 @@ def to_dense( p_indices_grid = stack(meshgrid(*p_index_ranges, indexing="ij")) # addmm_cuda not implemented for Long tensors => gotta have these tensors on cpu - v_indices_grid = tensordot(self.basis, p_indices_grid, dims=1) + self.offset + reshaped_offset = self.offset.reshape([-1] + [1] * self.physical.ndim) + v_indices_grid = tensordot(self.basis, p_indices_grid, dims=1) + reshaped_offset + # v_indices_grid is of shape [n_virtual_dims] + physical_shape res = zeros(self.shape, device=self.device, dtype=self.dtype) res[tuple(v_indices_grid)] = self.physical return res @@ -128,7 +139,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return func(*unwrapped_args, **unwrapped_kwargs) def __repr__(self, *, tensor_contents=None) -> str: - return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis}, offset={self.offset}, size={self.size})" + return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis}, offset={self.offset}, size={self.shape_t})" @classmethod def implements(cls, torch_function): @@ -141,6 +152,61 @@ def decorator(func): return decorator + @property + def start_padding(self) -> Tensor: + """ + Returns the number of zeros of padding at the start of each virtual dimension. + + The result is an int tensor of shape [virtual_ndim]. + """ + + return self.offset + + @property + def end_padding(self) -> Tensor: + """ + Returns the number of zeros of padding at the end of each virtual dimension. + + The result is an int tensor of shape [virtual_ndim]. + """ + + return self.shape_t - self.physical_image_size - self.offset + + @property + def padding(self) -> Tensor: + """ + Returns the number of zeros of padding at the start and end of each virtual dimension. + + The result is an int tensor of shape [virtual_ndim, 2]. + """ + + return torch.stack([self.start_padding, self.end_padding], dim=1) + + @property + def min_natural_virtual_indices(self) -> Tensor: + # Basis where each positive element is replaced by 0 + non_positive_basis = torch.min(self.basis, torch.zeros_like(self.basis)) + max_physical_index = self.pshape_t - 1 + return (non_positive_basis * max_physical_index.unsqueeze(0)).sum(dim=1) + + @property + def max_natural_virtual_indices(self) -> Tensor: + # Basis where each negative element is replaced by 0 + non_negative = torch.max(self.basis, torch.zeros_like(self.basis)) + max_physical_index = self.pshape_t - 1 + return (non_negative * max_physical_index.unsqueeze(0)).sum(dim=1) + + @property + def physical_image_size(self) -> Tensor: + """ + Returns the shape of the image of the physical through the basis transform. + + The result is an int tensor of shape [virtual_ndim]. 
+ """ + + one = torch.ones(self.ndim, dtype=torch.int64) + return self.max_natural_virtual_indices - self.min_natural_virtual_indices + one + impl = SparseLatticedTensor.implements diff --git a/tests/unit/sparse/test_sparse_latticed_tensor.py b/tests/unit/sparse/test_sparse_latticed_tensor.py index a9900a4b4..c94bb9b09 100644 --- a/tests/unit/sparse/test_sparse_latticed_tensor.py +++ b/tests/unit/sparse/test_sparse_latticed_tensor.py @@ -23,7 +23,7 @@ def test_to_dense(): n = 2 m = 3 a = randn_([n, m]) - b = SparseLatticedTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]])) + b = SparseLatticedTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]]), offset=None, shape=None) c = b.to_dense() for i in range(n): @@ -33,7 +33,7 @@ def test_to_dense(): def test_to_dense2(): a = tensor_([1.0, 2.0, 3.0]) - b = SparseLatticedTensor(a, tensor([[4]])) + b = SparseLatticedTensor(a, tensor([[4]]), offset=None, shape=None) c = b.to_dense() expected = tensor_([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]) assert torch.all(torch.eq(c, expected)) @@ -80,8 +80,8 @@ def test_einsum( b_indices: list[int], output_indices: list[int], ): - a = SparseLatticedTensor(randn_(a_pshape), a_basis) - b = SparseLatticedTensor(randn_(b_pshape), b_basis) + a = SparseLatticedTensor(randn_(a_pshape), a_basis, offset=None, shape=None) + b = SparseLatticedTensor(randn_(b_pshape), b_basis, offset=None, shape=None) res = einsum((a, a_indices), (b, b_indices), output=output_indices) @@ -102,7 +102,7 @@ def test_einsum( ) def test_sparse_latticed_tensor_scalar(shape: list[int]): a = randn_(shape) - b = SparseLatticedTensor(a, torch.eye(len(shape), dtype=torch.int64)) + b = SparseLatticedTensor(a, torch.eye(len(shape), dtype=torch.int64), offset=None, shape=None) assert_close(a, b.to_dense()) @@ -110,7 +110,7 @@ def test_sparse_latticed_tensor_scalar(shape: list[int]): @mark.parametrize("dim", [2, 3, 4, 5, 10]) def test_diag_equivalence(dim: int): a = randn_([dim]) - b = SparseLatticedTensor(a, tensor([[1], [1]])) + b = SparseLatticedTensor(a, tensor([[1], [1]]), offset=None, shape=None) diag_a = torch.diag(a) @@ -120,7 +120,7 @@ def test_diag_equivalence(dim: int): def test_three_virtual_single_physical(): dim = 10 a = randn_([dim]) - b = SparseLatticedTensor(a, tensor([[1], [1], [1]])) + b = SparseLatticedTensor(a, tensor([[1], [1], [1]]), offset=None, shape=None) expected = zeros_([dim, dim, dim]) for i in range(dim): @@ -133,7 +133,7 @@ def test_three_virtual_single_physical(): def test_pointwise(func): dim = 10 a = randn_([dim]) - b = SparseLatticedTensor(a, tensor([[1], [1]])) + b = SparseLatticedTensor(a, tensor([[1], [1]]), offset=None, shape=None) c = b.to_dense() res = func(b) assert isinstance(res, SparseLatticedTensor) @@ -145,7 +145,7 @@ def test_pointwise(func): def test_inplace_pointwise(func): dim = 10 a = randn_([dim]) - b = SparseLatticedTensor(a, tensor([[1], [1]])) + b = SparseLatticedTensor(a, tensor([[1], [1]]), offset=None, shape=None) c = b.to_dense() func(b) assert isinstance(b, SparseLatticedTensor) @@ -157,7 +157,7 @@ def test_inplace_pointwise(func): def test_unary(func): dim = 10 a = randn_([dim]) - b = SparseLatticedTensor(a, tensor([[1], [1]])) + b = SparseLatticedTensor(a, tensor([[1], [1]]), offset=None, shape=None) c = b.to_dense() res = func(b) @@ -254,7 +254,7 @@ def test_view( expected_basis: Tensor, ): a = randn_(tuple(physical_shape)) - t = SparseLatticedTensor(a, basis) + t = SparseLatticedTensor(a, basis, offset=None, shape=None) result = aten.view.default(t, target_shape) expected = 
t.to_dense().reshape(target_shape) @@ -344,7 +344,10 @@ def test_concatenate( dim: int, expected_densify: bool, ): - tensors = [SparseLatticedTensor(randn_(pshape), basis) for pshape, basis in slt_args] + tensors = [ + SparseLatticedTensor(randn_(pshape), basis, offset=None, shape=None) + for pshape, basis in slt_args + ] res = aten.cat.default(tensors, dim) expected = aten.cat.default([t.to_dense() for t in tensors], dim) From 5956d825e2ab69988a97a66873ab3b3ec13a820b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 3 Dec 2025 16:00:55 +0100 Subject: [PATCH 38/42] Rename padding to margin --- .../sparse/_aten_function_overrides/shape.py | 2 +- src/torchjd/sparse/_sparse_latticed_tensor.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 7cda9a57b..03b1048b7 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -40,7 +40,7 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: assert isinstance(t, SparseLatticedTensor) - if not torch.equal(t.padding, torch.zeros_like(t.padding)): + if not torch.equal(t.margin, torch.zeros_like(t.margin)): raise NotImplementedError() shape = infer_shape(shape, t.numel()) diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py index e4a7fc87c..0092d2b47 100644 --- a/src/torchjd/sparse/_sparse_latticed_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -153,9 +153,9 @@ def decorator(func): return decorator @property - def start_padding(self) -> Tensor: + def start_margin(self) -> Tensor: """ - Returns the number of zeros of padding at the start of each virtual dimension. + Returns the margin at the start of each virtual dimension. Can be negative. The result is an int tensor of shape [virtual_ndim]. """ @@ -163,9 +163,9 @@ def start_padding(self) -> Tensor: return self.offset @property - def end_padding(self) -> Tensor: + def end_margin(self) -> Tensor: """ - Returns the number of zeros of padding at the end of each virtual dimension. + Returns the margin at the end of each virtual dimension. Can be negative. The result is an int tensor of shape [virtual_ndim]. """ @@ -173,14 +173,14 @@ def end_padding(self) -> Tensor: return self.shape_t - self.physical_image_size - self.offset @property - def padding(self) -> Tensor: + def margin(self) -> Tensor: """ - Returns the number of zeros of padding at the start and end of each virtual dimension. + Returns the margin at the start and end of each virtual dimension. Can be negative. The result is an int tensor of shape [virtual_ndim, 2]. 
""" - return torch.stack([self.start_padding, self.end_padding], dim=1) + return torch.stack([self.start_margin, self.end_margin], dim=1) @property def min_natural_virtual_indices(self) -> Tensor: From ac7e7c105deb6a3fa7b4c13addfa5eaec58f9d7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Wed, 3 Dec 2025 17:22:25 +0100 Subject: [PATCH 39/42] Use margin instead of offset and shape in SLT constructor --- src/torchjd/autogram/_engine.py | 2 +- .../_aten_function_overrides/backward.py | 12 +- .../sparse/_aten_function_overrides/einsum.py | 6 +- .../_aten_function_overrides/pointwise.py | 6 +- .../sparse/_aten_function_overrides/shape.py | 28 ++--- src/torchjd/sparse/_sparse_latticed_tensor.py | 119 +++++++----------- .../sparse/test_sparse_latticed_tensor.py | 25 ++-- 7 files changed, 76 insertions(+), 122 deletions(-) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 95172522c..6574bd498 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -177,7 +177,7 @@ def differentiation(_grad_output: Tensor) -> tuple[Tensor, ...]: output_dims = list(range(output.ndim)) identity = torch.eye(output.ndim, dtype=torch.int64) basis = torch.concatenate([identity, identity], dim=0) - jac_output = make_slt(torch.ones_like(output), basis, None, None) + jac_output = make_slt(torch.ones_like(output), basis, None) vmapped_diff = differentiation for _ in output_dims: diff --git a/src/torchjd/sparse/_aten_function_overrides/backward.py b/src/torchjd/sparse/_aten_function_overrides/backward.py index f1c3ceb38..beae4ef1c 100644 --- a/src/torchjd/sparse/_aten_function_overrides/backward.py +++ b/src/torchjd/sparse/_aten_function_overrides/backward.py @@ -10,9 +10,7 @@ def threshold_backward_default( ) -> SparseLatticedTensor: new_physical = aten.threshold_backward.default(grad_output.physical, self, threshold) - return SparseLatticedTensor( - new_physical, grad_output.basis, grad_output.offset, grad_output.shape_t - ) + return SparseLatticedTensor(new_physical, grad_output.basis, grad_output.margin) @impl(aten.hardtanh_backward.default) @@ -26,9 +24,7 @@ def hardtanh_backward_default( raise NotImplementedError() new_physical = aten.hardtanh_backward.default(grad_output.physical, self, min_val, max_val) - return SparseLatticedTensor( - new_physical, grad_output.basis, grad_output.offset, grad_output.shape_t - ) + return SparseLatticedTensor(new_physical, grad_output.basis, grad_output.margin) @impl(aten.hardswish_backward.default) @@ -37,6 +33,4 @@ def hardswish_backward_default(grad_output: SparseLatticedTensor, self: Tensor): raise NotImplementedError() new_physical = aten.hardswish_backward.default(grad_output.physical, self) - return SparseLatticedTensor( - new_physical, grad_output.basis, grad_output.offset, grad_output.shape_t - ) + return SparseLatticedTensor(new_physical, grad_output.basis, grad_output.margin) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index a9691faea..61693b52c 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -165,7 +165,7 @@ def mul_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: @impl(aten.div.Tensor) def div_Tensor(t1: Tensor | int | float, t2: Tensor | int | float) -> Tensor: t1_, t2_ = prepare_for_elementwise_op(t1, t2) - t2_ = SparseLatticedTensor(1.0 / t2_.physical, t2_.basis, t2_.offset, t2_.shape_t) + t2_ = 
SparseLatticedTensor(1.0 / t2_.physical, t2_.basis, t2_.margin) all_dims = list(range(t1_.ndim)) return einsum((t1_, all_dims), (t2_, all_dims), output=all_dims) @@ -177,7 +177,7 @@ def mul_Scalar(t: SparseLatticedTensor, scalar) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) new_physical = aten.mul.Scalar(t.physical, scalar) - return SparseLatticedTensor(new_physical, t.basis, t.offset, t.shape_t) + return SparseLatticedTensor(new_physical, t.basis, t.margin) @impl(aten.add.Tensor) @@ -192,7 +192,7 @@ def add_Tensor( and torch.equal(t1_.shape_t, t2_.shape_t) ): new_physical = t1_.physical + t2_.physical * alpha - return SparseLatticedTensor(new_physical, t1_.basis, t1_.offset, t1_.shape_t) + return SparseLatticedTensor(new_physical, t1_.basis, t1_.margin) else: raise NotImplementedError() diff --git a/src/torchjd/sparse/_aten_function_overrides/pointwise.py b/src/torchjd/sparse/_aten_function_overrides/pointwise.py index e9babf5de..44874f335 100644 --- a/src/torchjd/sparse/_aten_function_overrides/pointwise.py +++ b/src/torchjd/sparse/_aten_function_overrides/pointwise.py @@ -71,7 +71,7 @@ def _override_pointwise(op): @impl(op) def func_(t: SparseLatticedTensor) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) - return SparseLatticedTensor(op(t.physical), t.basis, t.offset, t.shape) + return SparseLatticedTensor(op(t.physical), t.basis, t.margin) return func_ @@ -100,7 +100,7 @@ def pow_Tensor_Scalar(t: SparseLatticedTensor, exponent: float) -> SparseLattice return aten.pow.Tensor_Scalar(t.to_dense(), exponent) new_physical = aten.pow.Tensor_Scalar(t.physical, exponent) - return SparseLatticedTensor(new_physical, t.basis, t.offset, t.shape) + return SparseLatticedTensor(new_physical, t.basis, t.margin) # Somehow there's no pow_.Tensor_Scalar and pow_.Scalar takes tensor and scalar. 
@@ -122,4 +122,4 @@ def div_Scalar(t: SparseLatticedTensor, divisor: float) -> SparseLatticedTensor: assert isinstance(t, SparseLatticedTensor) new_physical = aten.div.Scalar(t.physical, divisor) - return SparseLatticedTensor(new_physical, t.basis, t.offset, t.shape) + return SparseLatticedTensor(new_physical, t.basis, t.margin) diff --git a/src/torchjd/sparse/_aten_function_overrides/shape.py b/src/torchjd/sparse/_aten_function_overrides/shape.py index 03b1048b7..07e70f7b9 100644 --- a/src/torchjd/sparse/_aten_function_overrides/shape.py +++ b/src/torchjd/sparse/_aten_function_overrides/shape.py @@ -54,8 +54,8 @@ def view_default(t: SparseLatticedTensor, shape: list[int]) -> Tensor: c_prime = _reverse_cumulative_product(shape) new_basis = ((c @ S).unsqueeze(0) // c_prime.unsqueeze(1)) % tensor(shape).unsqueeze(1) - new_offset = torch.zeros(len(shape), dtype=torch.int64) - return to_most_efficient_tensor(t.physical, new_basis, new_offset, shape) + new_margin = torch.zeros([len(shape), 2], dtype=torch.int64) + return to_most_efficient_tensor(t.physical, new_basis, new_margin) def _reverse_cumulative_product(values: list[int]) -> Tensor: @@ -90,9 +90,8 @@ def unsqueeze_default(t: SparseLatticedTensor, dim: int) -> SparseLatticedTensor pdims = t.basis.shape[1] new_basis = cat([t.basis[:dim], torch.zeros(1, pdims, dtype=torch.int64), t.basis[dim:]]) - new_offset = cat([t.offset[:dim], torch.zeros(1, dtype=torch.int64), t.offset[dim:]]) - new_size = cat([t.shape_t[:dim], torch.ones(1, dtype=torch.int64), t.shape_t[dim:]]) - return SparseLatticedTensor(t.physical, new_basis, new_offset, new_size) + new_margin = cat([t.margin[:dim], torch.zeros([1, 2], dtype=torch.int64), t.margin[dim:]]) + return SparseLatticedTensor(t.physical, new_basis, new_margin) @impl(aten.squeeze.dims) @@ -109,17 +108,15 @@ def squeeze_dims(t: SparseLatticedTensor, dims: list[int] | int | None) -> Tenso is_row_kept = [i not in excluded for i in range(t.ndim)] new_basis = t.basis[is_row_kept] - new_offset = t.offset[is_row_kept] - new_size = t.shape_t[is_row_kept] - return to_most_efficient_tensor(t.physical, new_basis, new_offset, new_size) + new_margin = t.margin[is_row_kept] + return to_most_efficient_tensor(t.physical, new_basis, new_margin) @impl(aten.permute.default) def permute_default(t: SparseLatticedTensor, dims: list[int]) -> SparseLatticedTensor: new_basis = t.basis[dims] - new_offset = t.offset[dims] - new_size = t.shape_t[dims] - return SparseLatticedTensor(t.physical, new_basis, new_offset, new_size) + new_margin = t.margin[dims] + return SparseLatticedTensor(t.physical, new_basis, new_margin) @impl(aten.cat.default) @@ -148,7 +145,6 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT # Try to expand each dimension to its new size new_physical = t.physical new_basis = t.basis - new_sizes = t.shape_t for d, (v, orig_size, new_size) in enumerate(zip(t.basis, t.shape, sizes, strict=True)): if v.sum() > 0 and orig_size != new_size and new_size != -1: raise ValueError( @@ -166,9 +162,8 @@ def expand_default(t: SparseLatticedTensor, sizes: list[int]) -> SparseLatticedT new_basis_vector = torch.zeros(t.ndim, 1, dtype=torch.int64) new_basis_vector[d, 0] = 1 new_basis = torch.cat([new_basis, new_basis_vector], dim=1) - new_sizes[d] = new_size - return SparseLatticedTensor(new_physical, new_basis, t.offset, new_sizes) + return SparseLatticedTensor(new_physical, new_basis, t.margin) @impl(aten.broadcast_tensors.default) @@ -207,7 +202,6 @@ def transpose_int(t: SparseLatticedTensor, dim0: 
int, dim1: int) -> SparseLattic new_index[dim1] = dim0 new_basis = t.basis[new_index] - new_offset = t.offset[new_index] - new_shape = list(tensor(t.shape, dtype=torch.int64)[new_index]) + new_margin = t.margin[new_index] - return SparseLatticedTensor(t.physical, new_basis, new_offset, new_shape) + return SparseLatticedTensor(t.physical, new_basis, new_margin) diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py index 0092d2b47..8c311c9f2 100644 --- a/src/torchjd/sparse/_sparse_latticed_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -16,8 +16,7 @@ def __new__( cls, physical: Tensor, basis: Tensor, - offset: Tensor | None = None, - shape: list[int] | tuple[int, ...] | torch.Size | Tensor | None = None, + margin: Tensor | None = None, ): assert basis.dtype == torch.int64 @@ -31,9 +30,11 @@ def __new__( # (which is bad!) assert not physical.requires_grad or not torch.is_grad_enabled() - if shape is None: - pshape = tensor(physical.shape, dtype=torch.int64) - shape = basis @ (pshape - 1) + 1 + if margin is None: + margin = torch.zeros([basis.shape[0], 2], dtype=torch.int64) + + pshape_t = tensor(physical.shape, dtype=torch.int64) + shape = physical_image_size(basis, pshape_t) + margin.sum(dim=1) return Tensor._make_wrapper_subclass( cls, list(shape), dtype=physical.dtype, device=physical.device @@ -43,8 +44,7 @@ def __init__( self, physical: Tensor, basis: Tensor, - offset: Tensor | None, - shape: list[int] | tuple[int, ...] | torch.Size | Tensor | None, + margin: Tensor | None, ): """ This constructor is made for specifying physical and basis exactly. It should not modify @@ -56,17 +56,12 @@ def __init__( :param physical: The dense tensor holding the actual data. :param basis: Integer (int64) tensor of shape [virtual_ndim, physical_ndim], representing the linear transformation between an index in the physical tensor and the corresponding - index in the virtual tensor, i.e. v_index = basis @ p_index + offset. - :param offset: Offset for the virtual index, i.e. v_index = basis @ p_index + offset. - :param shape: Size of the sparse tensor. If not provided, the size will be inferred as the - minimum size big enough to hold all non-zero elements. - - # TODO: make a nicer interface where it's possible to provide lists or sizes instead of - always having to provide int tensors + index in the virtual tensor, i.e. v_index = basis @ p_index + margin[:, 0] + :param margin: Number of extra elements at the start and end of each virtual dimension. 
""" - if offset is None: - offset = torch.zeros(len(self.shape), dtype=torch.int64) + if margin is None: + margin = torch.zeros([basis.shape[0], 2], dtype=torch.int64) if any(s == 1 for s in physical.shape): raise ValueError( @@ -94,16 +89,8 @@ def __init__( self.physical = physical self.basis = basis - self.offset = offset - - if shape is None: - pshape = tensor(physical.shape, dtype=torch.int64) - shape = basis @ (pshape - 1) + 1 - if isinstance(shape, torch.Tensor): - self.shape_t = shape - else: - self.shape_t = tensor(shape, dtype=torch.int64) - + self.margin = margin + self.shape_t = tensor(self.shape, dtype=torch.int64) self.pshape_t = tensor(physical.shape, dtype=torch.int64) def to_dense( @@ -139,7 +126,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): return func(*unwrapped_args, **unwrapped_kwargs) def __repr__(self, *, tensor_contents=None) -> str: - return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis}, offset={self.offset}, size={self.shape_t})" + return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis}, margin={self.margin})" @classmethod def implements(cls, torch_function): @@ -153,69 +140,51 @@ def decorator(func): return decorator @property - def start_margin(self) -> Tensor: + def offset(self) -> Tensor: """ Returns the margin at the start of each virtual dimension. Can be negative. The result is an int tensor of shape [virtual_ndim]. """ - return self.offset - - @property - def end_margin(self) -> Tensor: - """ - Returns the margin at the end of each virtual dimension. Can be negative. - - The result is an int tensor of shape [virtual_ndim]. - """ - - return self.shape_t - self.physical_image_size - self.offset + return self.margin[:, 0] - @property - def margin(self) -> Tensor: - """ - Returns the margin at the start and end of each virtual dimension. Can be negative. - The result is an int tensor of shape [virtual_ndim, 2]. - """ +impl = SparseLatticedTensor.implements - return torch.stack([self.start_margin, self.end_margin], dim=1) - @property - def min_natural_virtual_indices(self) -> Tensor: - # Basis where each positive element is replaced by 0 - non_positive_basis = torch.min(self.basis, torch.zeros_like(self.basis)) - max_physical_index = self.pshape_t - 1 - return (non_positive_basis * max_physical_index.unsqueeze(0)).sum(dim=1) +def min_natural_virtual_indices(basis: Tensor, pshape: Tensor) -> Tensor: + # Basis where each positive element is replaced by 0 + non_positive_basis = torch.min(basis, torch.zeros_like(basis)) + max_physical_index = pshape - 1 + return (non_positive_basis * max_physical_index.unsqueeze(0)).sum(dim=1) - @property - def max_natural_virtual_indices(self) -> Tensor: - # Basis where each negative element is replaced by 0 - non_negative = torch.max(self.basis, torch.zeros_like(self.basis)) - max_physical_index = self.pshape_t - 1 - return (non_negative * max_physical_index.unsqueeze(0)).sum(dim=1) - @property - def physical_image_size(self) -> Tensor: - """ - Returns the shape of the image of the physical through the basis transform. +def max_natural_virtual_indices(basis: Tensor, pshape: Tensor) -> Tensor: + # Basis where each negative element is replaced by 0 + non_negative = torch.max(basis, torch.zeros_like(basis)) + max_physical_index = pshape - 1 + return (non_negative * max_physical_index.unsqueeze(0)).sum(dim=1) - The result is an int tensor of shape [virtual_ndim]. 
- """ - one = torch.ones(self.ndim, dtype=torch.int64) - return self.max_natural_virtual_indices - self.min_natural_virtual_indices + one +def physical_image_size(basis: Tensor, pshape: Tensor) -> Tensor: + """ + Returns the shape of the image of the physical through the basis transform. + The result is an int tensor of shape [virtual_ndim]. + """ -impl = SparseLatticedTensor.implements + one = torch.ones(basis.shape[0], dtype=torch.int64) + max_idx = max_natural_virtual_indices(basis, pshape) + min_idx = min_natural_virtual_indices(basis, pshape) + return max_idx - min_idx + one def print_fallback(func, args, kwargs) -> None: def tensor_to_str(t: Tensor) -> str: result = f"{t.__class__.__name__} - shape: {t.shape}" if isinstance(t, SparseLatticedTensor): - result += f" - pshape: {t.physical.shape} - basis: {t.basis} - offset: {t.offset}" + result += f" - pshape: {t.physical.shape} - basis: {t.basis} - margin: {t.margin}" return result @@ -258,14 +227,13 @@ def to_sparse_latticed_tensor(t: Tensor) -> SparseLatticedTensor: if isinstance(t, SparseLatticedTensor): return t else: - return make_slt(t, torch.eye(t.ndim, dtype=torch.int64), None, None) + return make_slt(t, torch.eye(t.ndim, dtype=torch.int64), None) def to_most_efficient_tensor( physical: Tensor, basis: Tensor, - offset: Tensor | None, - size: list[int] | tuple[int, ...] | torch.Size | Tensor | None, + margin: Tensor | None, ) -> Tensor: physical, basis = fix_dim_of_size_1(physical, basis) physical, basis = fix_ungrouped_dims(physical, basis) @@ -275,9 +243,9 @@ def to_most_efficient_tensor( # TODO: this condition is broken if basis is allowed to have negative values. It also only # works when size is the default and offset is 0. # TODO: this can be done more efficiently (without even creating the SLT) - return SparseLatticedTensor(physical, basis, offset, size).to_dense() + return SparseLatticedTensor(physical, basis, margin).to_dense() else: - return SparseLatticedTensor(physical, basis, offset, size) + return SparseLatticedTensor(physical, basis, margin) def unwrap_to_dense(t: Tensor): @@ -336,11 +304,10 @@ def fix_ungrouped_dims(physical: Tensor, basis: Tensor) -> tuple[Tensor, Tensor] def make_slt( physical: Tensor, basis: Tensor, - offset: Tensor | None, - size: list[int] | tuple[int, ...] 
| torch.Size | Tensor | None, + margin: Tensor | None, ) -> SparseLatticedTensor: """Fix physical and basis and create a SparseLatticedTensor with them.""" physical, basis = fix_dim_of_size_1(physical, basis) physical, basis = fix_ungrouped_dims(physical, basis) - return SparseLatticedTensor(physical, basis, offset, size) + return SparseLatticedTensor(physical, basis, margin) diff --git a/tests/unit/sparse/test_sparse_latticed_tensor.py b/tests/unit/sparse/test_sparse_latticed_tensor.py index c94bb9b09..d72dea425 100644 --- a/tests/unit/sparse/test_sparse_latticed_tensor.py +++ b/tests/unit/sparse/test_sparse_latticed_tensor.py @@ -23,7 +23,7 @@ def test_to_dense(): n = 2 m = 3 a = randn_([n, m]) - b = SparseLatticedTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]]), offset=None, shape=None) + b = SparseLatticedTensor(a, tensor([[1, 0], [0, 1], [0, 1], [1, 0]]), margin=None) c = b.to_dense() for i in range(n): @@ -33,7 +33,7 @@ def test_to_dense(): def test_to_dense2(): a = tensor_([1.0, 2.0, 3.0]) - b = SparseLatticedTensor(a, tensor([[4]]), offset=None, shape=None) + b = SparseLatticedTensor(a, tensor([[4]]), margin=None) c = b.to_dense() expected = tensor_([1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]) assert torch.all(torch.eq(c, expected)) @@ -80,8 +80,8 @@ def test_einsum( b_indices: list[int], output_indices: list[int], ): - a = SparseLatticedTensor(randn_(a_pshape), a_basis, offset=None, shape=None) - b = SparseLatticedTensor(randn_(b_pshape), b_basis, offset=None, shape=None) + a = SparseLatticedTensor(randn_(a_pshape), a_basis, margin=None) + b = SparseLatticedTensor(randn_(b_pshape), b_basis, margin=None) res = einsum((a, a_indices), (b, b_indices), output=output_indices) @@ -102,7 +102,7 @@ def test_einsum( ) def test_sparse_latticed_tensor_scalar(shape: list[int]): a = randn_(shape) - b = SparseLatticedTensor(a, torch.eye(len(shape), dtype=torch.int64), offset=None, shape=None) + b = SparseLatticedTensor(a, torch.eye(len(shape), dtype=torch.int64), margin=None) assert_close(a, b.to_dense()) @@ -110,7 +110,7 @@ def test_sparse_latticed_tensor_scalar(shape: list[int]): @mark.parametrize("dim", [2, 3, 4, 5, 10]) def test_diag_equivalence(dim: int): a = randn_([dim]) - b = SparseLatticedTensor(a, tensor([[1], [1]]), offset=None, shape=None) + b = SparseLatticedTensor(a, tensor([[1], [1]]), margin=None) diag_a = torch.diag(a) @@ -120,7 +120,7 @@ def test_diag_equivalence(dim: int): def test_three_virtual_single_physical(): dim = 10 a = randn_([dim]) - b = SparseLatticedTensor(a, tensor([[1], [1], [1]]), offset=None, shape=None) + b = SparseLatticedTensor(a, tensor([[1], [1], [1]]), margin=None) expected = zeros_([dim, dim, dim]) for i in range(dim): @@ -133,7 +133,7 @@ def test_three_virtual_single_physical(): def test_pointwise(func): dim = 10 a = randn_([dim]) - b = SparseLatticedTensor(a, tensor([[1], [1]]), offset=None, shape=None) + b = SparseLatticedTensor(a, tensor([[1], [1]]), margin=None) c = b.to_dense() res = func(b) assert isinstance(res, SparseLatticedTensor) @@ -145,7 +145,7 @@ def test_pointwise(func): def test_inplace_pointwise(func): dim = 10 a = randn_([dim]) - b = SparseLatticedTensor(a, tensor([[1], [1]]), offset=None, shape=None) + b = SparseLatticedTensor(a, tensor([[1], [1]]), margin=None) c = b.to_dense() func(b) assert isinstance(b, SparseLatticedTensor) @@ -157,7 +157,7 @@ def test_inplace_pointwise(func): def test_unary(func): dim = 10 a = randn_([dim]) - b = SparseLatticedTensor(a, tensor([[1], [1]]), offset=None, shape=None) + b = 
SparseLatticedTensor(a, tensor([[1], [1]]), margin=None) c = b.to_dense() res = func(b) @@ -254,7 +254,7 @@ def test_view( expected_basis: Tensor, ): a = randn_(tuple(physical_shape)) - t = SparseLatticedTensor(a, basis, offset=None, shape=None) + t = SparseLatticedTensor(a, basis, margin=None) result = aten.view.default(t, target_shape) expected = t.to_dense().reshape(target_shape) @@ -345,8 +345,7 @@ def test_concatenate( expected_densify: bool, ): tensors = [ - SparseLatticedTensor(randn_(pshape), basis, offset=None, shape=None) - for pshape, basis in slt_args + SparseLatticedTensor(randn_(pshape), basis, margin=None) for pshape, basis in slt_args ] res = aten.cat.default(tensors, dim) expected = aten.cat.default([t.to_dense() for t in tensors], dim) From 35270e430ce4bb4bdb06dc6e021e9d8c5f30ba39 Mon Sep 17 00:00:00 2001 From: Pierre Quinton Date: Fri, 5 Dec 2025 08:56:44 +0100 Subject: [PATCH 40/42] Remove solve_int --- src/torchjd/sparse/_linalg.py | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 56378e24e..5699b89b1 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -4,28 +4,6 @@ # TODO: Implement in C everything in this file. -def solve_int(A: Tensor, B: Tensor, tol=1e-9) -> Tensor | None: - """ - Solve A X = B where A, B and X have integer dtype. - Return X if such a matrix exists and otherwise None. - """ - - A_ = A.to(torch.float64) - B_ = B.to(torch.float64) - - try: - X = torch.linalg.solve(A_, B_) - except RuntimeError: - return None - - X_rounded = X.round() - if not torch.all(torch.isclose(X, X_rounded, atol=tol)): - return None - - # TODO: Verify that the round operation cannot fail - return X_rounded.to(torch.int64) - - def extended_gcd(a: int, b: int) -> tuple[int, int, int]: """ Extended Euclidean Algorithm (Python integers). From ea897c97dc509c76ad9e394d0813c43f84cc2d41 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Fri, 12 Dec 2025 20:47:10 +0100 Subject: [PATCH 41/42] Fix remaining mypy errors --- src/torchjd/sparse/_aten_function_overrides/einsum.py | 4 ++-- src/torchjd/sparse/_linalg.py | 8 +++++--- src/torchjd/sparse/_sparse_latticed_tensor.py | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/torchjd/sparse/_aten_function_overrides/einsum.py b/src/torchjd/sparse/_aten_function_overrides/einsum.py index 61693b52c..afbdf7a9a 100644 --- a/src/torchjd/sparse/_aten_function_overrides/einsum.py +++ b/src/torchjd/sparse/_aten_function_overrides/einsum.py @@ -138,12 +138,12 @@ def prepare_for_elementwise_op( assert isinstance(t1, SparseLatticedTensor) or isinstance(t2, SparseLatticedTensor) if isinstance(t1, int) or isinstance(t1, float): - t1_ = tensor(t1, device=t2.device) + t1_ = tensor(t1, device=t2.device) # type: ignore[union-attr] else: t1_ = t1 if isinstance(t2, int) or isinstance(t2, float): - t2_ = tensor(t2, device=t1.device) + t2_ = tensor(t2, device=t1.device) # type: ignore[union-attr] else: t2_ = t2 diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 5699b89b1..7efe2034b 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -1,3 +1,5 @@ +from typing import cast + import torch from torch import Tensor @@ -23,7 +25,7 @@ def _get_hermite_factor_rank(H: Tensor) -> int: Computes the rank of a hermit factor matrix. 
""" col_magnitudes = torch.sum(torch.abs(H), dim=0) - return torch.count_nonzero(col_magnitudes).item() + return cast(int, torch.count_nonzero(col_magnitudes).item()) def hnf_decomposition(A: Tensor, reduced: bool) -> tuple[Tensor, Tensor, Tensor]: @@ -73,8 +75,8 @@ def hnf_decomposition(A: Tensor, reduced: bool) -> tuple[Tensor, Tensor, Tensor] for j in range(col + 1, n): if H[row, j] != 0: - a_val = H[row, col].item() - b_val = H[row, j].item() + a_val = cast(int, H[row, col].item()) + b_val = cast(int, H[row, j].item()) g, x, y = extended_gcd(a_val, b_val) diff --git a/src/torchjd/sparse/_sparse_latticed_tensor.py b/src/torchjd/sparse/_sparse_latticed_tensor.py index 8c311c9f2..ddc1263b5 100644 --- a/src/torchjd/sparse/_sparse_latticed_tensor.py +++ b/src/torchjd/sparse/_sparse_latticed_tensor.py @@ -125,7 +125,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): unwrapped_kwargs = tree_map(unwrap_to_dense, kwargs) return func(*unwrapped_args, **unwrapped_kwargs) - def __repr__(self, *, tensor_contents=None) -> str: + def __repr__(self, *, tensor_contents=None) -> str: # type: ignore[override] return f"SparseLatticedTensor(physical={self.physical}, basis={self.basis}, margin={self.margin})" @classmethod From 5671500e339fffc2907d4d18b7e6d4726b3a9052 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= <31951177+ValerianRey@users.noreply.github.com> Date: Fri, 12 Dec 2025 20:50:34 +0100 Subject: [PATCH 42/42] Reorder lines in src/torchjd/sparse/_linalg.py Co-authored-by: Pierre Quinton --- src/torchjd/sparse/_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/sparse/_linalg.py b/src/torchjd/sparse/_linalg.py index 7efe2034b..761c580d5 100644 --- a/src/torchjd/sparse/_linalg.py +++ b/src/torchjd/sparse/_linalg.py @@ -190,7 +190,7 @@ def compute_lcm(S1: Tensor, S2: Tensor) -> tuple[Tensor, Tensor, Tensor]: H, U, _ = hnf_decomposition(B, False) rank = _get_hermite_factor_rank(H) - M2 = U[n1:, -rank:] M1 = U[:n1, -rank:] + M2 = U[n1:, -rank:] L = S1 @ M1 return L, M1, M2