diff --git a/pyproject.toml b/pyproject.toml index e13772b3c..0efe70a3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,9 +113,13 @@ nash_mtl = [ cagrad = [ "cvxpy>=1.3.0", # No Clarabel solver before 1.3.0 ] +qpth = [ + "qpth>=0.0.15", +] full = [ "cvxpy>=1.3.0", # No Clarabel solver before 1.3.0 "ecos>=2.0.14", # Does not work before 2.0.14 + "qpth>=0.0.15", ] [tool.pytest.ini_options] diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index acb87d2fb..156122d75 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -22,7 +22,9 @@ class DualProjWeighting(GramianWeighting): numerical errors when computing the gramian, it might not exactly be positive definite. This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. + :param solver: The solver used to optimize the underlying optimization problem. Use + ``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the + device of the input tensors (requires the optional ``qpth`` package). """ def __init__( @@ -90,7 +92,9 @@ class DualProj(GramianWeightedAggregator): numerical errors when computing the gramian, it might not exactly be positive definite. This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. + :param solver: The solver used to optimize the underlying optimization problem. Use + ``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the + device of the input tensors (requires the optional ``qpth`` package). """ gramian_weighting: DualProjWeighting diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 686898297..bc61f6fe3 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -23,7 +23,9 @@ class UPGradWeighting(GramianWeighting): numerical errors when computing the gramian, it might not exactly be positive definite. This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. + :param solver: The solver used to optimize the underlying optimization problem. Use + ``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the + device of the input tensors (requires the optional ``qpth`` package). """ def __init__( @@ -93,7 +95,9 @@ class UPGrad(GramianWeightedAggregator): numerical errors when computing the gramian, it might not exactly be positive definite. This issue can make the optimization fail. Adding ``reg_eps`` to the diagonal of the gramian ensures that it is positive definite. - :param solver: The solver used to optimize the underlying optimization problem. + :param solver: The solver used to optimize the underlying optimization problem. Use + ``"quadprog"`` (default) for a CPU-based solver or ``"qpth"`` to solve natively on the + device of the input tensors (requires the optional ``qpth`` package). """ gramian_weighting: UPGradWeighting diff --git a/src/torchjd/aggregation/_utils/dual_cone.py b/src/torchjd/aggregation/_utils/dual_cone.py index b076366be..e3db40f1f 100644 --- a/src/torchjd/aggregation/_utils/dual_cone.py +++ b/src/torchjd/aggregation/_utils/dual_cone.py @@ -5,7 +5,7 @@ from qpsolvers import solve_qp from torch import Tensor -SUPPORTED_SOLVER: TypeAlias = Literal["quadprog"] +SUPPORTED_SOLVER: TypeAlias = Literal["quadprog", "qpth"] def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: @@ -15,10 +15,15 @@ def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: :param U: The tensor of weights corresponding to the vectors to project, of shape `[..., m]`. :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. - :param solver: The quadratic programming solver to use. + :param solver: The quadratic programming solver to use. ``"quadprog"`` converts tensors to + CPU numpy arrays and uses qpsolvers. ``"qpth"`` solves natively on the same device as + the input tensors (e.g. CUDA) using the ``qpth`` package (optional dependency). :return: A tensor of projection weights with the same shape as `U`. """ + if solver == "qpth": + return _project_weights_qpth(U, G) + G_ = _to_array(G) U_ = _to_array(U) @@ -27,6 +32,50 @@ def project_weights(U: Tensor, G: Tensor, solver: SUPPORTED_SOLVER) -> Tensor: return torch.as_tensor(W, device=G.device, dtype=G.dtype) +def _project_weights_qpth(U: Tensor, G: Tensor) -> Tensor: + r""" + Computes the tensor of projection weights using qpth, keeping computation on the device of + the input tensors and running without gradient tracking. + + :param U: The tensor of weights to project, of shape `[..., m]`. + :param G: The Gramian matrix of shape `[m, m]`. It must be symmetric and positive definite. + """ + from qpth.qp import QPFunction # lazy import: qpth is an optional dependency + + shape = U.shape + m = shape[-1] + batch_size = U.numel() // m + device = G.device + original_dtype = G.dtype + + # Use float64 for numerical precision, matching the quadprog solver's behavior. + U_flat = U.reshape(batch_size, m).double() + G_double = G.double() + + # QP formulation: minimize (1/2) v^T (2G) v + 0^T v subject to -I v <= -u (i.e., u <= v) + Q = (2.0 * G_double).unsqueeze(0).expand(batch_size, m, m).contiguous() + p = torch.zeros(batch_size, m, device=device, dtype=torch.float64) + G_ineq = ( + (-torch.eye(m, device=device, dtype=torch.float64)) + .unsqueeze(0) + .expand(batch_size, m, m) + .contiguous() + ) + h_ineq = -U_flat + A = torch.zeros(batch_size, 0, m, device=device, dtype=torch.float64) + b = torch.zeros(batch_size, 0, device=device, dtype=torch.float64) + + with torch.no_grad(): + W_flat = QPFunction(verbose=False, maxIter=10, check_Q_spd=False, notImprovedLim=1)( + Q, p, G_ineq, h_ineq, A, b + ) + + if torch.any(torch.isnan(W_flat)): + raise ValueError("Failed to solve the quadratic programming problem.") + + return W_flat.to(original_dtype).reshape(shape) + + def _project_weight_vector(u: np.ndarray, G: np.ndarray, solver: SUPPORTED_SOLVER) -> np.ndarray: r""" Computes the weights `w` of the projection of `J^T u` onto the dual cone of the rows of `J`, diff --git a/tests/unit/aggregation/test_dualproj.py b/tests/unit/aggregation/test_dualproj.py index 34fe8d462..5125a6aaf 100644 --- a/tests/unit/aggregation/test_dualproj.py +++ b/tests/unit/aggregation/test_dualproj.py @@ -1,5 +1,7 @@ +import importlib.util + import torch -from pytest import mark, raises +from pytest import mark, param, raises from torch import Tensor from utils.tensors import ones_ @@ -15,28 +17,44 @@ ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices +_has_qpth = importlib.util.find_spec("qpth") is not None +_skip_no_qpth = mark.skipif(not _has_qpth, reason="qpth not installed") + scaled_pairs = [(DualProj(), matrix) for matrix in scaled_matrices] typical_pairs = [(DualProj(), matrix) for matrix in typical_matrices] non_strong_pairs = [(DualProj(), matrix) for matrix in non_strong_matrices] requires_grad_pairs = [(DualProj(), ones_(3, 5, requires_grad=True))] +_qpth_typical_pairs = [ + param(DualProj(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in typical_matrices +] +_qpth_non_strong_pairs = [ + param(DualProj(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in non_strong_matrices +] +_qpth_scaled_pairs = [ + param(DualProj(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in scaled_matrices +] -@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) + +@mark.parametrize( + ["aggregator", "matrix"], + scaled_pairs + typical_pairs + _qpth_scaled_pairs + _qpth_typical_pairs, +) def test_expected_structure(aggregator: DualProj, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) -@mark.parametrize(["aggregator", "matrix"], typical_pairs) +@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs) def test_non_conflicting(aggregator: DualProj, matrix: Tensor) -> None: assert_non_conflicting(aggregator, matrix, atol=1e-04, rtol=1e-04) -@mark.parametrize(["aggregator", "matrix"], typical_pairs) +@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs) def test_permutation_invariant(aggregator: DualProj, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=2e-07, rtol=2e-07) -@mark.parametrize(["aggregator", "matrix"], non_strong_pairs) +@mark.parametrize(["aggregator", "matrix"], non_strong_pairs + _qpth_non_strong_pairs) def test_strongly_stationary(aggregator: DualProj, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix, threshold=3e-03) @@ -66,6 +84,13 @@ def test_representations() -> None: assert str(A) == "DualProj([1., 2., 3.])" +@_skip_no_qpth +def test_representations_qpth() -> None: + A = DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="qpth") + assert repr(A) == "DualProj(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='qpth')" + assert str(A) == "DualProj" + + def test_pref_vector_setter_updates_value() -> None: A = DualProj() new_pref = torch.tensor([1.0, 2.0, 3.0]) diff --git a/tests/unit/aggregation/test_upgrad.py b/tests/unit/aggregation/test_upgrad.py index 075680a02..407a93027 100644 --- a/tests/unit/aggregation/test_upgrad.py +++ b/tests/unit/aggregation/test_upgrad.py @@ -1,5 +1,7 @@ +import importlib.util + import torch -from pytest import mark, raises +from pytest import mark, param, raises from torch import Tensor from utils.tensors import ones_ @@ -16,33 +18,49 @@ ) from ._inputs import non_strong_matrices, scaled_matrices, typical_matrices +_has_qpth = importlib.util.find_spec("qpth") is not None +_skip_no_qpth = mark.skipif(not _has_qpth, reason="qpth not installed") + scaled_pairs = [(UPGrad(), matrix) for matrix in scaled_matrices] typical_pairs = [(UPGrad(), matrix) for matrix in typical_matrices] non_strong_pairs = [(UPGrad(), matrix) for matrix in non_strong_matrices] requires_grad_pairs = [(UPGrad(), ones_(3, 5, requires_grad=True))] +_qpth_typical_pairs = [ + param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in typical_matrices +] +_qpth_non_strong_pairs = [ + param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in non_strong_matrices +] +_qpth_scaled_pairs = [ + param(UPGrad(solver="qpth"), matrix, marks=_skip_no_qpth) for matrix in scaled_matrices +] -@mark.parametrize(["aggregator", "matrix"], scaled_pairs + typical_pairs) + +@mark.parametrize( + ["aggregator", "matrix"], + scaled_pairs + typical_pairs + _qpth_scaled_pairs + _qpth_typical_pairs, +) def test_expected_structure(aggregator: UPGrad, matrix: Tensor) -> None: assert_expected_structure(aggregator, matrix) -@mark.parametrize(["aggregator", "matrix"], typical_pairs) +@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs) def test_non_conflicting(aggregator: UPGrad, matrix: Tensor) -> None: assert_non_conflicting(aggregator, matrix, atol=4e-04, rtol=4e-04) -@mark.parametrize(["aggregator", "matrix"], typical_pairs) +@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs) def test_permutation_invariant(aggregator: UPGrad, matrix: Tensor) -> None: assert_permutation_invariant(aggregator, matrix, n_runs=5, atol=5e-07, rtol=5e-07) -@mark.parametrize(["aggregator", "matrix"], typical_pairs) +@mark.parametrize(["aggregator", "matrix"], typical_pairs + _qpth_typical_pairs) def test_linear_under_scaling(aggregator: UPGrad, matrix: Tensor) -> None: assert_linear_under_scaling(aggregator, matrix, n_runs=5, atol=6e-02, rtol=6e-02) -@mark.parametrize(["aggregator", "matrix"], non_strong_pairs) +@mark.parametrize(["aggregator", "matrix"], non_strong_pairs + _qpth_non_strong_pairs) def test_strongly_stationary(aggregator: UPGrad, matrix: Tensor) -> None: assert_strongly_stationary(aggregator, matrix, threshold=5e-03) @@ -70,6 +88,13 @@ def test_representations() -> None: assert str(A) == "UPGrad([1., 2., 3.])" +@_skip_no_qpth +def test_representations_qpth() -> None: + A = UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver="qpth") + assert repr(A) == "UPGrad(pref_vector=None, norm_eps=0.0001, reg_eps=0.0001, solver='qpth')" + assert str(A) == "UPGrad" + + def test_pref_vector_setter_updates_value() -> None: A = UPGrad() new_pref = torch.tensor([1.0, 2.0, 3.0])