From 26d0d42109f1689d387da6b6ed4d0d8df37eb3c7 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 18 May 2026 16:03:35 +0200 Subject: [PATCH 1/9] refactor(sos): add Model.apply/undo_sos_reformulation methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduce a stateful pair of methods on Model that own the SOS reformulation lifecycle: - apply_sos_reformulation() stashes the reformulation token on the model (new _sos_reformulation_state attribute). Raises if already applied. - undo_sos_reformulation() reads the stashed token and restores the original SOS form. No-op if nothing is applied. Model.solve(reformulate_sos=...) now delegates to these methods rather than threading the token through local state. The Solver path (which was previously raising via Model.solve's pre-flight check) now gets a clean ValueError directly from Solver._build() when an SOS-bearing model is handed to a solver without native SOS support — making the low-level API safe to use independently of Model.solve. Persistence: - copy() (and copy.copy / copy.deepcopy) carry the reformulation token with a deepcopy, so the copy is independently undoable. - to_netcdf() raises if a reformulation is active; users must undo first to serialize a stable model state. Context: motivated by the same investigation as PyPSA/linopy#688 — while reviewing the new Solver.from_model() API surface introduced by #682, the SOS reformulation lifecycle stood out as load-bearing orchestration that the Solver path couldn't reproduce. Co-Authored-By: Claude Opus 4.7 (1M context) --- linopy/io.py | 17 +++++ linopy/model.py | 52 +++++++++++--- linopy/solvers.py | 8 +++ test/test_sos_constraints.py | 7 +- test/test_sos_reformulation.py | 124 +++++++++++++++++++++++++++++++++ 5 files changed, 194 insertions(+), 14 deletions(-) diff --git a/linopy/io.py b/linopy/io.py index 36d7abb3..ba0400b6 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -5,6 +5,7 @@ from __future__ import annotations +import copy as _copy import json import logging import shutil @@ -828,7 +829,20 @@ def to_netcdf(m: Model, *args: Any, **kwargs: Any) -> None: Arguments passed to ``xarray.Dataset.to_netcdf``. **kwargs : TYPE Keyword arguments passed to ``xarray.Dataset.to_netcdf``. + + Raises + ------ + RuntimeError + If the model has an active SOS reformulation. Call + :meth:`Model.undo_sos_reformulation` before serializing. """ + if m._sos_reformulation_state is not None: + raise RuntimeError( + "Cannot serialize a model with an active SOS reformulation. " + "Call `model.undo_sos_reformulation()` first to restore the " + "original SOS form, or save the model in its reformulated form " + "after explicitly clearing `model._sos_reformulation_state`." + ) def with_prefix(ds: xr.Dataset, prefix: str) -> xr.Dataset: to_rename = set([*ds.dims, *ds.coords, *ds]) @@ -1100,6 +1114,9 @@ def copy(m: Model, include_solution: bool = False, deep: bool = True) -> Model: if include_solution or attr not in SOLVE_STATE_ATTRS: setattr(new_model, attr, getattr(m, attr)) + if m._sos_reformulation_state is not None: + new_model._sos_reformulation_state = _copy.deepcopy(m._sos_reformulation_state) + return new_model diff --git a/linopy/model.py b/linopy/model.py index 4eb91fc6..be572279 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -89,6 +89,7 @@ available_solvers, ) from linopy.sos_reformulation import ( + SOSReformulationResult, reformulate_sos_constraints, undo_sos_reformulation, ) @@ -239,6 +240,7 @@ class Model: "_relaxed_registry", "_piecewise_formulations", "_solver", + "_sos_reformulation_state", "__weakref__", ) @@ -309,6 +311,7 @@ def __init__( gettempdir() if solver_dir is None else solver_dir ) self._solver: solvers.Solver | None = None + self._sos_reformulation_state: SOSReformulationResult | None = None @property def solver(self) -> solvers.Solver | None: @@ -1220,6 +1223,39 @@ def remove_sos_constraints(self, variable: Variable) -> None: reformulate_sos_constraints = reformulate_sos_constraints + def apply_sos_reformulation(self) -> None: + """ + Reformulate SOS constraints into binary + linear form, in place. + + The reformulation token is stored on the model so it can be reverted + with :meth:`undo_sos_reformulation`. This is the stateful counterpart + to :func:`linopy.sos_reformulation.reformulate_sos_constraints`, where + the caller owns the token. + + Raises + ------ + RuntimeError + If a reformulation has already been applied and not undone. + """ + if self._sos_reformulation_state is not None: + raise RuntimeError( + "SOS reformulation has already been applied to this model. " + "Call `undo_sos_reformulation()` before applying again." + ) + self._sos_reformulation_state = reformulate_sos_constraints(self) + + def undo_sos_reformulation(self) -> None: + """ + Revert a previously applied SOS reformulation. + + No-op if no reformulation is currently applied. + """ + if self._sos_reformulation_state is None: + return + state = self._sos_reformulation_state + self._sos_reformulation_state = None + undo_sos_reformulation(self, state) + def remove_objective(self) -> None: """ Remove the objective's linear expression from the model. @@ -1711,22 +1747,20 @@ def solve( "Must be True, False, or 'auto'." ) - sos_reform_result = None + applied_sos_reformulation_here = False if self.variables.sos: supports_sos = solver_class.supports(SolverFeature.SOS_CONSTRAINTS) if reformulate_sos in (True, "auto") and not supports_sos: logger.info(f"Reformulating SOS constraints for solver {solver_name}") - sos_reform_result = reformulate_sos_constraints(self) + self.apply_sos_reformulation() + applied_sos_reformulation_here = True elif reformulate_sos is True and supports_sos: logger.warning( f"Solver {solver_name} supports SOS natively; " "reformulate_sos=True is ignored." ) - elif reformulate_sos is False and not supports_sos: - raise ValueError( - f"Solver {solver_name} does not support SOS constraints. " - "Use reformulate_sos=True or 'auto', or a solver that supports SOS (gurobi, cplex)." - ) + # If SOS is present and the solver doesn't support it (and the user + # didn't ask for reformulation), Solver._build() will raise. if self.variables.semi_continuous: if not solver_class.supports(SolverFeature.SEMI_CONTINUOUS_VARIABLES): @@ -1778,8 +1812,8 @@ def solve( try: return self.assign_result(result) finally: - if sos_reform_result is not None: - undo_sos_reformulation(self, sos_reform_result) + if applied_sos_reformulation_here: + self.undo_sos_reformulation() def assign_result(self, result: Result) -> tuple[str, str]: result.info() diff --git a/linopy/solvers.py b/linopy/solvers.py index 548db835..0ca5d956 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -507,6 +507,14 @@ def _build(self, **build_kwargs: Any) -> None: """Dispatch to direct or file build based on ``io_api``.""" if self.model is None: raise RuntimeError("Solver has no model attached; cannot build.") + if self.model.variables.sos and not type(self).supports( + SolverFeature.SOS_CONSTRAINTS + ): + raise ValueError( + f"Solver {self.solver_name.value} does not support SOS constraints. " + "Call `model.apply_sos_reformulation()` first, or use a solver that " + "supports SOS." + ) if self.io_api == "direct": self._build_direct(**build_kwargs) else: diff --git a/test/test_sos_constraints.py b/test/test_sos_constraints.py index 5d94162e..3c3e79f6 100644 --- a/test/test_sos_constraints.py +++ b/test/test_sos_constraints.py @@ -150,7 +150,7 @@ def test_unsupported_solver_raises_error() -> None: m.solve(solver_name=solver) -def test_to_highspy_raises_not_implemented() -> None: +def test_to_highspy_raises_when_sos_present() -> None: pytest.importorskip("highspy") m = Model() @@ -158,8 +158,5 @@ def test_to_highspy_raises_not_implemented() -> None: build = m.add_variables(coords=[locations], name="build", binary=True) m.add_sos_constraints(build, sos_type=1, sos_dim="locations") - with pytest.raises( - NotImplementedError, - match="SOS constraints are not supported by the HiGHS direct API", - ): + with pytest.raises(ValueError, match="does not support SOS constraints"): m.to_highspy() diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index 24ba62b3..f2f48075 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -3,6 +3,8 @@ from __future__ import annotations import logging +from collections.abc import Callable +from pathlib import Path import numpy as np import pandas as pd @@ -312,6 +314,128 @@ def test_reformulate_inplace(self) -> None: assert "_sos_reform_x_y" in m.variables +class TestApplyUndoSOSReformulation: + """Tests for Model.apply_sos_reformulation / undo_sos_reformulation.""" + + @staticmethod + def _build_sos1_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + return m + + def test_apply_stashes_state(self) -> None: + m = self._build_sos1_model() + assert m._sos_reformulation_state is None + + m.apply_sos_reformulation() + + assert m._sos_reformulation_state is not None + assert m._sos_reformulation_state.reformulated == ["x"] + assert len(list(m.variables.sos)) == 0 + assert "_sos_reform_x_y" in m.variables + + def test_undo_restores_and_clears_state(self) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + m.undo_sos_reformulation() + + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in m.variables + + def test_double_apply_raises(self) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + with pytest.raises(RuntimeError, match="already been applied"): + m.apply_sos_reformulation() + + def test_undo_without_apply_is_noop(self) -> None: + m = self._build_sos1_model() + assert m._sos_reformulation_state is None + + m.undo_sos_reformulation() # should not raise + + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + + @pytest.mark.parametrize( + "copy_fn", + [ + pytest.param(lambda m: m.copy(), id="model.copy()"), + pytest.param(lambda m: __import__("copy").copy(m), id="copy.copy(model)"), + pytest.param( + lambda m: __import__("copy").deepcopy(m), id="copy.deepcopy(model)" + ), + ], + ) + def test_copy_persists_state_and_undo_works_on_copy( + self, copy_fn: Callable[[Model], Model] + ) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + c = copy_fn(m) + + # State is carried over but is an independent object + assert c._sos_reformulation_state is not None + assert c._sos_reformulation_state is not m._sos_reformulation_state + # Aux vars/cons exist on the copy (they were copied as part of the + # reformulated model state) + assert "_sos_reform_x_y" in c.variables + assert "_sos_reform_x_upper" in c.constraints + assert "_sos_reform_x_card" in c.constraints + # SOS attrs are not on the copy's "x" yet (still in reformulated form) + assert "x" not in list(c.variables.sos) + + # Undo on the copy fully restores the original SOS form + c.undo_sos_reformulation() + assert c._sos_reformulation_state is None + assert list(c.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in c.variables + assert "_sos_reform_x_upper" not in c.constraints + assert "_sos_reform_x_card" not in c.constraints + + # Original is entirely unaffected + assert m._sos_reformulation_state is not None + assert "_sos_reform_x_y" in m.variables + assert len(list(m.variables.sos)) == 0 + + def test_to_netcdf_raises_when_state_active(self, tmp_path: Path) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + with pytest.raises(RuntimeError, match="active SOS reformulation"): + m.to_netcdf(tmp_path / "m.nc") + + def test_to_netcdf_works_after_undo(self, tmp_path: Path) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + m.undo_sos_reformulation() + + m.to_netcdf(tmp_path / "m.nc") # should not raise + + +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestSolverPathSOSCheck: + """Solver._build() must raise on SOS-bearing model with non-SOS solver.""" + + def test_solver_from_name_raises_without_reformulation(self) -> None: + from linopy import solvers + + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x.sum(), sense="max") + + with pytest.raises(ValueError, match="does not support SOS"): + solvers.Solver.from_name("highs", m, io_api="lp") + + @pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") class TestSolveWithReformulation: """Tests for solving with SOS reformulation.""" From 68d75cd8fee240d85443d8165b366536a3b78f16 Mon Sep 17 00:00:00 2001 From: Felix <117816358+FBumann@users.noreply.github.com> Date: Mon, 18 May 2026 20:53:06 +0200 Subject: [PATCH 2/9] refactor(solver): validation, sanitize kwargs, and result wiring on Solver path (#691) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * refactor(solver): lift feature checks + sanitize/wiring to Solver path Make Solver.from_name(...).solve() a real first-class entry point that doesn't lose Model.solve()'s safety nets: - Lift solver-feature gates into Solver._build() via a new _validate_model() hook: quadratic models against LP-only solvers and semi-continuous variables against solvers that don't support them. Removed the duplicate checks from Model.solve(). - Add sanitize_zeros / sanitize_infinities kwargs to Solver.from_model() (default True). The kwargs are processed in _build() before dispatch, so both file and direct io_apis honor them. Model.solve() forwards the kwargs through instead of pre-mutating the constraints itself. - Extend Model.assign_result(result, solver=None) so the Solver-path canonical pattern works: solver = Solver.from_name(...); result = solver.solve(); model.assign_result(result, solver=solver). When the solver kwarg is provided, model.solver gets wired the same way Model.solve() wires it, so compute_infeasibilities() and friends keep working through the low-level API. The empty-objective check stays on Model.solve() — to_gurobipy() / to_highspy() and similar build-only converters legitimately work against objectiveless models (gurobi/highs default to a zero objective), so the check belongs at the actual submit point. Co-Authored-By: Claude Opus 4.7 (1M context) * move empty-objective check to Solver.solve() for entry-point parity The empty-objective UX guardrail was previously only on Model.solve(), leaving the lower-level Solver.from_name(...).solve() path with a silent gap. Move it to Solver.solve() — the actual submit primitive that both entry points go through — so the same check fires regardless of which API the user reaches for. Build-time translate-only paths (to_gurobipy(), to_highspy(), to_file()) are unaffected since they don't call solve(). The cost of catching the error after build instead of before is bounded and only hits a programming-error case. Co-Authored-By: Claude Opus 4.7 (1M context) * test: parametrize empty-objective check across both entry points Consolidate the Model.solve() and Solver.from_name(...).solve() tests into one parametrized case — same check, two callers, one assertion. Co-Authored-By: Claude Opus 4.7 (1M context) * test: collapse parametrize to a single test with two raises blocks Same property tested twice — no need for separate test IDs. Co-Authored-By: Claude Opus 4.7 (1M context) * preserve empty-objective check for remote-solve path in Model.solve() The remote-solve branch in Model.solve() short-circuits to a RemoteHandler before reaching Solver.solve(), so the check now in Solver.solve() doesn't cover it. Restore the early raise in Model.solve() so behavior is unchanged for all Model.solve() callers (mock, remote, local) while Solver.solve() still covers direct-Solver callers. Co-Authored-By: Claude Opus 4.7 (1M context) * move remote-path empty-objective check inside the remote branch The early-position check was a workaround: the remote branch short-circuits before Solver.solve() (where the canonical check now lives), so empty-objective with remote=... wouldn't raise. Moving it into the remote branch itself makes the intent local to where it's needed, with a comment pointing at #683 where this duplication disappears once OETC becomes a Solver subclass. Co-Authored-By: Claude Opus 4.7 (1M context) * keep sanitize on Model; Solver.from_model() stays mutation-free Remove the sanitize_zeros / sanitize_infinities kwargs from Solver.from_model(). The Solver builder now never mutates the model. Sanitization is exposed where it has always lived — model.constraints.sanitize_zeros() / .sanitize_infinities() — and Model.solve() calls them inline as part of its orchestration. Rationale: model-state transformations should be Model-level primitives (matches the SOS reformulation pattern from #690). The Solver's job is to translate the model and run; it should not silently change the caller's model on the way in. Users who go through the lower-level Solver path apply sanitize explicitly when they want it. Replaces TestSanitizeKwargs with TestSolverDoesNotMutateModel, pinning the mutation-free invariant: building a Solver against a model with a near-zero coefficient leaves model.constraints["c"].coeffs unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) * address review: SOS hint, lp_only_solver fixture, assign_result doc --------- Co-authored-by: Claude Opus 4.7 (1M context) Co-authored-by: Fabian --- linopy/model.py | 68 ++++++++++++++++++----------- linopy/solvers.py | 72 ++++++++++++++++++++++++++----- test/test_solvers.py | 100 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 204 insertions(+), 36 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index 4a11558a..12a46206 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1678,12 +1678,6 @@ def solve( sanitize_zeros=sanitize_zeros, sanitize_infinities=sanitize_infinities ) - if self.objective.expression.empty: - raise ValueError( - "No objective has been set on the model. Use `m.add_objective(...)` " - "first (e.g. `m.add_objective(0 * x)` for a pure feasibility problem)." - ) - # check io_api if io_api is not None and io_api not in IO_APIS: raise ValueError( @@ -1691,6 +1685,16 @@ def solve( ) if remote is not None: + # The remote branch short-circuits before reaching Solver.solve(), + # which is where the empty-objective check normally fires. Replicate + # it here. This duplication becomes obsolete once OETC is folded + # into the Solver pipeline (see PyPSA/linopy#683). + if self.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use " + "`m.add_objective(...)` first (e.g. `m.add_objective(0 * x)` " + "for a pure feasibility problem)." + ) if isinstance(remote, OetcHandler): solved = remote.solve_on_oetc( self, solver_name=solver_name, **solver_options @@ -1756,19 +1760,6 @@ def solve( else: solution_fn = self.get_solution_file() - if sanitize_zeros: - self.constraints.sanitize_zeros() - - if sanitize_infinities: - self.constraints.sanitize_infinities() - - if self.is_quadratic and not solver_class.supports( - SolverFeature.QUADRATIC_OBJECTIVE - ): - raise ValueError( - f"Solver {solver_name} does not support quadratic problems." - ) - if reformulate_sos not in (True, False, "auto"): raise ValueError( f"Invalid value for reformulate_sos: {reformulate_sos!r}. " @@ -1789,12 +1780,10 @@ def solve( # If SOS is present and the solver doesn't support it (and the user # didn't ask for reformulation), Solver._build() will raise. - if self.variables.semi_continuous: - if not solver_class.supports(SolverFeature.SEMI_CONTINUOUS_VARIABLES): - raise ValueError( - f"Solver {solver_name} does not support semi-continuous variables. " - "Use a solver that supports them (gurobi, cplex, highs)." - ) + if sanitize_zeros: + self.constraints.sanitize_zeros() + if sanitize_infinities: + self.constraints.sanitize_infinities() try: self.solver = None # closes any previous solver @@ -1842,7 +1831,34 @@ def solve( if applied_sos_reformulation_here: self.undo_sos_reformulation() - def assign_result(self, result: Result) -> tuple[str, str]: + def assign_result( + self, + result: Result, + solver: solvers.Solver | None = None, + ) -> tuple[str, str]: + """ + Write a solver Result back onto the model. + + Copies primal / dual values onto variables / constraints, sets + :attr:`status`, :attr:`termination_condition`, and + :attr:`objective.value`. When ``solver`` is provided, also stores it on + ``self.solver`` so post-solve introspection (``model.solver_model``, + ``compute_infeasibilities()``) works. + + Parameters + ---------- + result : Result + The :class:`linopy.constants.Result` returned by + :meth:`linopy.solvers.Solver.solve`. + solver : Solver, optional + The solver instance that produced the result. Pass it on the + low-level ``Solver.from_name(...).solve()`` path to attach it as + ``self.solver`` for post-solve introspection. ``Model.solve()`` + attaches the solver itself and does not pass this argument. + """ + if solver is not None: + self.solver = solver + result.info() if result.solution is not None: diff --git a/linopy/solvers.py b/linopy/solvers.py index 60669b9c..0fbd2c12 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -504,23 +504,52 @@ def from_model( return instance def _build(self, **build_kwargs: Any) -> None: - """Dispatch to direct or file build based on ``io_api``.""" + """ + Dispatch to direct or file build based on ``io_api``. + + The Solver never mutates ``self.model``. Constraint sanitization + (``model.constraints.sanitize_zeros()`` / + ``.sanitize_infinities()``) and SOS reformulation + (``model.apply_sos_reformulation()``) are Model-level operations + the caller applies first; this builder consumes whatever shape it + is handed. + """ if self.model is None: raise RuntimeError("Solver has no model attached; cannot build.") self.model._check_sos_unmasked() - if self.model.variables.sos and not type(self).supports( - SolverFeature.SOS_CONSTRAINTS - ): - raise ValueError( - f"Solver {self.solver_name.value} does not support SOS constraints. " - "Call `model.apply_sos_reformulation()` first, or use a solver that " - "supports SOS." - ) + self._validate_model() if self.io_api == "direct": self._build_direct(**build_kwargs) else: self._build_file(**build_kwargs) + def _validate_model(self) -> None: + """Pre-build checks on whether this solver can handle ``self.model``.""" + model = self.model + assert model is not None + solver_name = self.solver_name.value + cls = type(self) + + if model.is_quadratic and not cls.supports(SolverFeature.QUADRATIC_OBJECTIVE): + raise ValueError( + f"Solver {solver_name} does not support quadratic problems." + ) + + if model.variables.semi_continuous and not cls.supports( + SolverFeature.SEMI_CONTINUOUS_VARIABLES + ): + raise ValueError( + f"Solver {solver_name} does not support semi-continuous variables. " + "Use a solver that supports them (gurobi, cplex, highs)." + ) + + if model.variables.sos and not cls.supports(SolverFeature.SOS_CONSTRAINTS): + raise ValueError( + f"Solver {solver_name} does not support SOS constraints. " + "Reformulate first via `Model.solve(reformulate_sos=True)` or " + "`model.apply_sos_reformulation()`, or use a solver that supports SOS." + ) + def _build_direct(self, **build_kwargs: Any) -> None: """Build the native solver model from ``self.model``. Override per-solver.""" raise NotImplementedError( @@ -561,7 +590,30 @@ def _build_file(self, **build_kwargs: Any) -> None: self._cache_model_sizes(model) def solve(self, **run_kwargs: Any) -> Result: - """Run the prepared solver and return a :class:`Result`.""" + """ + Run the prepared solver and return a :class:`Result`. + + The canonical low-level pattern is:: + + solver = Solver.from_name("gurobi", model, io_api="direct") + result = solver.solve() + model.assign_result(result, solver=solver) + + Passing ``solver=`` to :meth:`Model.assign_result` wires + ``model.solver`` so post-solve helpers like + :meth:`Model.compute_infeasibilities` keep working. + + Raises + ------ + ValueError + If the attached model has no objective set. Submit-time check + shared by both ``Model.solve()`` and direct-Solver callers. + """ + if self.model is not None and self.model.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use `m.add_objective(...)` " + "first (e.g. `m.add_objective(0 * x)` for a pure feasibility problem)." + ) if self.io_api == "direct" or self.solver_model is not None: return self._run_direct(**run_kwargs) if self._problem_fn is not None: diff --git a/test/test_solvers.py b/test/test_solvers.py index db894137..1b6bd9a9 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -23,6 +23,14 @@ from linopy.solvers import _installed_version_in +@pytest.fixture +def lp_only_solver() -> str: + for name in ("glpk", "cbc"): + if name in solvers.available_solvers: + return name + pytest.skip("Need an LP-only solver (glpk or cbc) installed") + + @pytest.fixture def simple_model() -> Model: m = Model(chunk=None) @@ -464,3 +472,95 @@ def test_xpress_gpu_feature_reflects_installed_version() -> None: assert solvers.Xpress.supports( SolverFeature.GPU_ACCELERATION ) == _installed_version_in("xpress", ">=9.8.0") + + +class TestValidateModelOnBuild: + """Solver._build() runs solver-feature checks regardless of entry point.""" + + def test_quadratic_without_qp_support_raises(self, lp_only_solver: str) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x * x, sense="min") + + with pytest.raises(ValueError, match="does not support quadratic"): + solvers.Solver.from_name(lp_only_solver, m, io_api="lp") + + def test_semi_continuous_without_support_raises(self, lp_only_solver: str) -> None: + m = Model() + x = m.add_variables(name="x", lower=1, upper=10, semi_continuous=True) + m.add_objective(x) + + with pytest.raises(ValueError, match="does not support semi-continuous"): + solvers.Solver.from_name(lp_only_solver, m, io_api="lp") + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_solve_without_objective_raises(self) -> None: + m = Model() + m.add_variables(name="x", lower=0, upper=10) + # No objective added — both entry points should raise the same error. + with pytest.raises(ValueError, match="No objective has been set"): + solvers.Solver.from_name("highs", m, io_api="lp").solve() + with pytest.raises(ValueError, match="No objective has been set"): + m.solve("highs") + + +class TestSolverDoesNotMutateModel: + """Solver.from_model() must not mutate model state (sanitize stays Model-level).""" + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_from_model_leaves_constraints_untouched(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + # Constraint with a near-zero coefficient — would be sanitized away if + # the Solver path were sanitizing on build. + m.add_constraints(1e-12 * x + x >= 0, name="c") + m.add_objective(x) + + before = m.constraints["c"].coeffs.values.copy() + solvers.Solver.from_name("highs", m, io_api="lp") + after = m.constraints["c"].coeffs.values + + assert np.allclose(before, after, equal_nan=True), ( + "Solver.from_model() must not mutate model constraints. " + "Sanitization is a Model-level primitive; call " + "model.constraints.sanitize_zeros() / .sanitize_infinities() " + "explicitly before building." + ) + + +class TestAssignResultWiring: + """assign_result(result, solver=...) populates model.solver.""" + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_assign_result_with_solver_wires_model_solver(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x, sense="min") + + assert m.solver is None + solver = solvers.Solver.from_name("highs", m, io_api="lp") + result = solver.solve() + m.assign_result(result, solver=solver) + + assert m.solver is solver + assert m.solver_model is solver.solver_model + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_assign_result_without_solver_kwarg_leaves_solver_unset(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x, sense="min") + + solver = solvers.Solver.from_name("highs", m, io_api="lp") + result = solver.solve() + m.assign_result(result) # no solver kwarg + + assert m.solver is None From e08733cb39bbc2d05777555709e715a384b715c2 Mon Sep 17 00:00:00 2001 From: Fabian Date: Mon, 18 May 2026 21:16:03 +0200 Subject: [PATCH 3/9] refactor(sos): tighten undo semantics and error hints - undo_sos_reformulation() now raises if no state is applied (fail-fast) - to_netcdf error no longer suggests poking the private state slot - Solver._build runs _validate_model before _check_sos_unmasked so SOS on an LP-only solver surfaces the reformulate-first hint - reformulate_sos_constraints docstring points at the stateful API --- linopy/io.py | 3 +-- linopy/model.py | 9 +++++++-- linopy/solvers.py | 2 +- linopy/sos_reformulation.py | 6 ++++-- test/test_sos_reformulation.py | 9 +++------ 5 files changed, 16 insertions(+), 13 deletions(-) diff --git a/linopy/io.py b/linopy/io.py index 064716b9..657d2c19 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -857,8 +857,7 @@ def to_netcdf(m: Model, *args: Any, **kwargs: Any) -> None: raise RuntimeError( "Cannot serialize a model with an active SOS reformulation. " "Call `model.undo_sos_reformulation()` first to restore the " - "original SOS form, or save the model in its reformulated form " - "after explicitly clearing `model._sos_reformulation_state`." + "original SOS form before saving." ) def with_prefix(ds: xr.Dataset, prefix: str) -> xr.Dataset: diff --git a/linopy/model.py b/linopy/model.py index 12a46206..65e8093f 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1249,10 +1249,15 @@ def undo_sos_reformulation(self) -> None: """ Revert a previously applied SOS reformulation. - No-op if no reformulation is currently applied. + Raises + ------ + RuntimeError + If no reformulation is currently applied. """ if self._sos_reformulation_state is None: - return + raise RuntimeError( + "No SOS reformulation is currently applied to this model." + ) state = self._sos_reformulation_state self._sos_reformulation_state = None undo_sos_reformulation(self, state) diff --git a/linopy/solvers.py b/linopy/solvers.py index 0fbd2c12..d6cc50e6 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -516,8 +516,8 @@ def _build(self, **build_kwargs: Any) -> None: """ if self.model is None: raise RuntimeError("Solver has no model attached; cannot build.") - self.model._check_sos_unmasked() self._validate_model() + self.model._check_sos_unmasked() if self.io_api == "direct": self._build_direct(**build_kwargs) else: diff --git a/linopy/sos_reformulation.py b/linopy/sos_reformulation.py index 8ccb7613..0c677216 100644 --- a/linopy/sos_reformulation.py +++ b/linopy/sos_reformulation.py @@ -233,8 +233,10 @@ def reformulate_sos_constraints( 1. If custom big_m was specified in add_sos_constraints(), use that 2. Otherwise, use the variable bounds (tightest valid Big-M) - Note: This permanently mutates the model. To solve with automatic - undo, use ``model.solve(reformulate_sos=True)`` instead. + Note: This permanently mutates the model and returns a token the caller + owns. For a stateful, reversible API use ``model.apply_sos_reformulation()`` + / ``model.undo_sos_reformulation()``; for automatic undo around a single + solve use ``model.solve(reformulate_sos=True)``. Parameters ---------- diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index f2f48075..6b42918a 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -353,14 +353,11 @@ def test_double_apply_raises(self) -> None: with pytest.raises(RuntimeError, match="already been applied"): m.apply_sos_reformulation() - def test_undo_without_apply_is_noop(self) -> None: + def test_undo_without_apply_raises(self) -> None: m = self._build_sos1_model() - assert m._sos_reformulation_state is None - - m.undo_sos_reformulation() # should not raise - assert m._sos_reformulation_state is None - assert list(m.variables.sos) == ["x"] + with pytest.raises(RuntimeError, match="No SOS reformulation"): + m.undo_sos_reformulation() @pytest.mark.parametrize( "copy_fn", From 9c38ea6d36f02cd13c77c3db11c59f655a60c0fb Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Mon, 18 May 2026 21:23:16 +0200 Subject: [PATCH 4/9] fix(sos): auto-undo SOS reformulation when build/solve raises `Model.solve(reformulate_sos=...)` left `_sos_reformulation_state` set if `Solver.from_name`, `solver.solve`, or the file-cleanup `finally` raised, since the undo lived in a second `try` around `assign_result` that those failures never reached. The next solve then hit `RuntimeError: SOS reformulation has already been applied`. Wrap sanitize, build/solve, file cleanup, and assign_result in a single outer try/finally so the undo always runs. Co-Authored-By: Claude Opus 4.7 (1M context) --- linopy/model.py | 88 +++++++++++++++++----------------- test/test_sos_reformulation.py | 32 +++++++++++++ 2 files changed, 76 insertions(+), 44 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index 65e8093f..03450d62 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -1785,52 +1785,52 @@ def solve( # If SOS is present and the solver doesn't support it (and the user # didn't ask for reformulation), Solver._build() will raise. - if sanitize_zeros: - self.constraints.sanitize_zeros() - if sanitize_infinities: - self.constraints.sanitize_infinities() - try: - self.solver = None # closes any previous solver - if io_api == "direct": - if set_names is None: - set_names = self.set_names_in_solver_io - build_kwargs: dict[str, Any] = { - "explicit_coordinate_names": explicit_coordinate_names, - "set_names": set_names, - "log_fn": to_path(log_fn), - } - if env is not None: - build_kwargs["env"] = env - else: - build_kwargs = { - "explicit_coordinate_names": explicit_coordinate_names, - "slice_size": slice_size, - "progress": progress, - "problem_fn": to_path(problem_fn), - } - self.solver = solver = solvers.Solver.from_name( - solver_name, - model=self, - io_api=io_api, - options=solver_options, - **build_kwargs, - ) - if io_api != "direct": - problem_fn = solver._problem_fn - result = solver.solve( - solution_fn=to_path(solution_fn), - log_fn=to_path(log_fn), - warmstart_fn=to_path(warmstart_fn), - basis_fn=to_path(basis_fn), - env=env, - ) - finally: - for fn in (problem_fn, solution_fn): - if fn is not None and (os.path.exists(fn) and not keep_files): - os.remove(fn) + if sanitize_zeros: + self.constraints.sanitize_zeros() + if sanitize_infinities: + self.constraints.sanitize_infinities() + + try: + self.solver = None # closes any previous solver + if io_api == "direct": + if set_names is None: + set_names = self.set_names_in_solver_io + build_kwargs: dict[str, Any] = { + "explicit_coordinate_names": explicit_coordinate_names, + "set_names": set_names, + "log_fn": to_path(log_fn), + } + if env is not None: + build_kwargs["env"] = env + else: + build_kwargs = { + "explicit_coordinate_names": explicit_coordinate_names, + "slice_size": slice_size, + "progress": progress, + "problem_fn": to_path(problem_fn), + } + self.solver = solver = solvers.Solver.from_name( + solver_name, + model=self, + io_api=io_api, + options=solver_options, + **build_kwargs, + ) + if io_api != "direct": + problem_fn = solver._problem_fn + result = solver.solve( + solution_fn=to_path(solution_fn), + log_fn=to_path(log_fn), + warmstart_fn=to_path(warmstart_fn), + basis_fn=to_path(basis_fn), + env=env, + ) + finally: + for fn in (problem_fn, solution_fn): + if fn is not None and (os.path.exists(fn) and not keep_files): + os.remove(fn) - try: return self.assign_result(result) finally: if applied_sos_reformulation_here: diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index 6b42918a..b244d9b6 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -433,6 +433,38 @@ def test_solver_from_name_raises_without_reformulation(self) -> None: solvers.Solver.from_name("highs", m, io_api="lp") +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestSolveAutoUndoOnFailure: + """Model.solve must auto-undo SOS reformulation when build/solve raises.""" + + def test_state_restored_when_build_raises( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + from linopy import solvers + + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x.sum(), sense="max") + + def boom(*args: object, **kwargs: object) -> None: + raise RuntimeError("simulated build failure") + + monkeypatch.setattr(solvers.Solver, "from_name", boom) + + with pytest.raises(RuntimeError, match="simulated build failure"): + m.solve(solver_name="highs", reformulate_sos=True) + + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in m.variables + + # A subsequent real solve must not hit "already applied" + monkeypatch.undo() + m.solve(solver_name="highs", reformulate_sos=True) + + @pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") class TestSolveWithReformulation: """Tests for solving with SOS reformulation.""" From f7b3a11c0fc9cf9a630805919c9653fd4f54faa6 Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 19 May 2026 09:40:43 +0200 Subject: [PATCH 5/9] fix(sos): support reformulation through remote/oetc netcdf path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `to_netcdf` previously raised when the model had an active SOS reformulation, blocking `Model.solve(remote=...)` for users who passed `reformulate_sos`. Beyond the raise, the remote branch was silently ignoring `reformulate_sos` entirely. Changes: - `to_netcdf` warns (UserWarning) instead of raising; the reformulated MILP form is what gets serialized. The `_sos_reformulation_state` token is not persisted — it lives only on the in-memory caller's Model, where the apply/undo bracket keeps its lifecycle intact. - `Model.solve(remote=...)` now brackets the remote dispatch with `apply_sos_reformulation` / `undo_sos_reformulation`, exactly like the local path. The `to_netcdf` warning emitted inside the remote helper is suppressed via `warnings.catch_warnings`. - New `Model._resolve_sos_reformulation(solver_name, reformulate_sos)` helper deduplicates the should-reformulate decision between the local and remote branches and uses `solver_supports(...)` instead of the ad-hoc `getattr(solvers, SolverName(...).name)` pattern. - `solver_name=None` with `reformulate_sos="auto"` now raises a sharp error pointing users at either passing `solver_name=...` or using `True`/`False` to skip the lookup. The local path is unaffected because its existing default (`solver_name = available_solvers[0]`) runs before the helper sees None. Addresses the open thread on #690 from FabianHofmann. Co-Authored-By: Claude Opus 4.7 (1M context) --- linopy/io.py | 36 ++++++-- linopy/model.py | 145 ++++++++++++++++++++++----------- test/test_sos_reformulation.py | 134 +++++++++++++++++++++++++++++- 3 files changed, 256 insertions(+), 59 deletions(-) diff --git a/linopy/io.py b/linopy/io.py index 657d2c19..b0abe9fb 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -10,6 +10,7 @@ import logging import shutil import time +import warnings from collections.abc import Callable, Iterable from io import BufferedWriter from pathlib import Path @@ -847,17 +848,27 @@ def to_netcdf(m: Model, *args: Any, **kwargs: Any) -> None: **kwargs : TYPE Keyword arguments passed to ``xarray.Dataset.to_netcdf``. - Raises - ------ - RuntimeError - If the model has an active SOS reformulation. Call - :meth:`Model.undo_sos_reformulation` before serializing. + Notes + ----- + The SOS reformulation lifecycle token lives only on the in-memory + Model and is not persisted. If the model has an active SOS + reformulation at serialization time, the netcdf contains the + reformulated MILP form (aux binaries and cardinality constraints) + and a :class:`UserWarning` is emitted to flag that the deserialized + copy will not be able to undo the reformulation. + + ``Model.solve(remote=...)`` invokes ``to_netcdf`` internally on the + reformulated model and suppresses this warning. """ if m._sos_reformulation_state is not None: - raise RuntimeError( - "Cannot serialize a model with an active SOS reformulation. " - "Call `model.undo_sos_reformulation()` first to restore the " - "original SOS form before saving." + warnings.warn( + "Serializing a model with an active SOS reformulation. The " + "netcdf will contain the reformulated MILP form; the " + "reformulation lifecycle token is not persisted, so a " + "reader cannot undo it. Call `model.undo_sos_reformulation()` " + "first if you want the original SOS form on disk.", + UserWarning, + stacklevel=2, ) def with_prefix(ds: xr.Dataset, prefix: str) -> xr.Dataset: @@ -929,6 +940,13 @@ def read_netcdf(path: Path | str, **kwargs: Any) -> Model: Returns ------- m : linopy.Model + + Notes + ----- + The SOS reformulation lifecycle token is not persisted by + :func:`to_netcdf`. If the saved model was in reformulated form, + the deserialized Model is too, but + :meth:`Model.undo_sos_reformulation` is a no-op on it. """ from linopy.constraints import ( Constraint, diff --git a/linopy/model.py b/linopy/model.py index 03450d62..bba74867 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -84,6 +84,7 @@ from linopy.remote import OetcHandler except ImportError: OetcHandler = None # type: ignore +from linopy.solver_capabilities import solver_supports from linopy.solvers import ( IO_APIS, SolverFeature, @@ -1262,6 +1263,42 @@ def undo_sos_reformulation(self) -> None: self._sos_reformulation_state = None undo_sos_reformulation(self, state) + def _resolve_sos_reformulation( + self, + solver_name: str | None, + reformulate_sos: bool | Literal["auto"], + ) -> bool: + """ + Decide whether ``apply_sos_reformulation`` should run. + + Validates ``reformulate_sos`` and returns ``True`` iff the SOS + constraints on this model should be reformulated for the chosen + solver. ``solver_name`` is only consulted when + ``reformulate_sos == "auto"`` (to look up SOS support); for + ``True`` / ``False`` the decision is independent of the solver. + """ + if reformulate_sos not in (True, False, "auto"): + raise ValueError( + f"Invalid value for reformulate_sos: {reformulate_sos!r}. " + "Must be True, False, or 'auto'." + ) + if not self.variables.sos: + return False + + if reformulate_sos is False: + return False + elif reformulate_sos is True: + return True + elif solver_name is None: + raise ValueError( + "`reformulate_sos='auto'` on a model with SOS constraints " + "requires an explicit `solver_name` so we can check " + "whether the chosen solver supports SOS. Pass " + "`solver_name=...` or use `reformulate_sos=True`/`False` " + "to skip the lookup." + ) + return not solver_supports(solver_name, SolverFeature.SOS_CONSTRAINTS) + def _check_sos_unmasked(self) -> None: """ Reject the model if any SOS variable has masked entries. @@ -1700,35 +1737,62 @@ def solve( "`m.add_objective(...)` first (e.g. `m.add_objective(0 * x)` " "for a pure feasibility problem)." ) - if isinstance(remote, OetcHandler): - solved = remote.solve_on_oetc( - self, solver_name=solver_name, **solver_options - ) - else: - solved = remote.solve_on_remote( - self, - solver_name=solver_name, - io_api=io_api, - problem_fn=problem_fn, - solution_fn=solution_fn, - log_fn=log_fn, - basis_fn=basis_fn, - warmstart_fn=warmstart_fn, - keep_files=keep_files, - sanitize_zeros=sanitize_zeros, - **solver_options, + # Reformulate before the model is serialized to netcdf inside the + # remote helper, so the worker just solves a plain MILP and the + # SOS lifecycle stays on this (caller's) Model. + applied_sos_reformulation_here = self._resolve_sos_reformulation( + solver_name, reformulate_sos + ) + if applied_sos_reformulation_here: + logger.info( + f"Reformulating SOS constraints for remote solver {solver_name}" ) + self.apply_sos_reformulation() - if solved.objective.value is not None: - self.objective.set_value(float(solved.objective.value)) - self.status = solved.status - self.termination_condition = solved.termination_condition - for k, v in self.variables.items(): - v.solution = solved.variables[k].solution - for k, c in self.constraints.items(): - if "dual" in solved.constraints[k]: - c.dual = solved.constraints[k].dual - return self.status, self.termination_condition + try: + with warnings.catch_warnings(): + # The remote helpers call self.to_netcdf(...) internally; + # serializing the reformulated form here is intentional, so + # silence the UserWarning that to_netcdf emits in that case. + if applied_sos_reformulation_here: + warnings.filterwarnings( + "ignore", + message="Serializing a model with an active SOS " + "reformulation", + category=UserWarning, + ) + if isinstance(remote, OetcHandler): + solved = remote.solve_on_oetc( + self, solver_name=solver_name, **solver_options + ) + else: + solved = remote.solve_on_remote( + self, + solver_name=solver_name, + io_api=io_api, + problem_fn=problem_fn, + solution_fn=solution_fn, + log_fn=log_fn, + basis_fn=basis_fn, + warmstart_fn=warmstart_fn, + keep_files=keep_files, + sanitize_zeros=sanitize_zeros, + **solver_options, + ) + + if solved.objective.value is not None: + self.objective.set_value(float(solved.objective.value)) + self.status = solved.status + self.termination_condition = solved.termination_condition + for k, v in self.variables.items(): + v.solution = solved.variables[k].solution + for k, c in self.constraints.items(): + if "dual" in solved.constraints[k]: + c.dual = solved.constraints[k].dual + return self.status, self.termination_condition + finally: + if applied_sos_reformulation_here: + self.undo_sos_reformulation() if len(available_solvers) == 0: raise RuntimeError("No solver installed.") @@ -1765,25 +1829,14 @@ def solve( else: solution_fn = self.get_solution_file() - if reformulate_sos not in (True, False, "auto"): - raise ValueError( - f"Invalid value for reformulate_sos: {reformulate_sos!r}. " - "Must be True, False, or 'auto'." - ) - - applied_sos_reformulation_here = False - if self.variables.sos: - supports_sos = solver_class.supports(SolverFeature.SOS_CONSTRAINTS) - should_reformulate = reformulate_sos is True or ( - reformulate_sos == "auto" and not supports_sos - ) - - if should_reformulate: - logger.info(f"Reformulating SOS constraints for solver {solver_name}") - self.apply_sos_reformulation() - applied_sos_reformulation_here = True - # If SOS is present and the solver doesn't support it (and the user - # didn't ask for reformulation), Solver._build() will raise. + applied_sos_reformulation_here = self._resolve_sos_reformulation( + solver_name, reformulate_sos + ) + if applied_sos_reformulation_here: + logger.info(f"Reformulating SOS constraints for solver {solver_name}") + self.apply_sos_reformulation() + # If SOS is present and the solver doesn't support it (and the user + # didn't ask for reformulation), Solver._build() will raise. try: if sanitize_zeros: diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index b244d9b6..0eb91c39 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -3,12 +3,14 @@ from __future__ import annotations import logging +import warnings from collections.abc import Callable from pathlib import Path import numpy as np import pandas as pd import pytest +import xarray as xr from linopy import Model, Variable, available_solvers from linopy.constants import SOS_TYPE_ATTR @@ -401,19 +403,25 @@ def test_copy_persists_state_and_undo_works_on_copy( assert "_sos_reform_x_y" in m.variables assert len(list(m.variables.sos)) == 0 - def test_to_netcdf_raises_when_state_active(self, tmp_path: Path) -> None: + def test_to_netcdf_warns_when_state_active(self, tmp_path: Path) -> None: m = self._build_sos1_model() m.apply_sos_reformulation() - with pytest.raises(RuntimeError, match="active SOS reformulation"): + with pytest.warns(UserWarning, match="active SOS reformulation"): m.to_netcdf(tmp_path / "m.nc") - def test_to_netcdf_works_after_undo(self, tmp_path: Path) -> None: + # File written despite the warning — the netcdf carries the + # reformulated MILP form. + assert (tmp_path / "m.nc").exists() + + def test_to_netcdf_silent_after_undo(self, tmp_path: Path) -> None: m = self._build_sos1_model() m.apply_sos_reformulation() m.undo_sos_reformulation() - m.to_netcdf(tmp_path / "m.nc") # should not raise + with warnings.catch_warnings(): + warnings.simplefilter("error") # any warning fails the test + m.to_netcdf(tmp_path / "m.nc") @pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") @@ -1084,3 +1092,121 @@ def test_invalid_reformulate_sos_value(self) -> None: with pytest.raises(ValueError, match="Invalid value for reformulate_sos"): m.solve(solver_name="highs", reformulate_sos="invalid") # type: ignore[arg-type] + + +class TestResolveSOSReformulation: + """Helper contracts not already exercised end-to-end by ``m.solve(...)``.""" + + @staticmethod + def _sos_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + return m + + def test_no_sos_short_circuits(self) -> None: + # Fast path: no SOS variables means False regardless of args. + m = Model() + m.add_variables(name="x") + for v in (True, False, "auto"): + assert m._resolve_sos_reformulation(None, v) is False # type: ignore[arg-type] + + def test_true_does_not_consult_solver_name(self) -> None: + # reformulate_sos=True must not require solver_name — no lookup. + assert self._sos_model()._resolve_sos_reformulation(None, True) is True + + def test_auto_with_none_solver_raises(self) -> None: + with pytest.raises(ValueError, match="requires an explicit `solver_name`"): + self._sos_model()._resolve_sos_reformulation(None, "auto") + + +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestRemoteBracket: + """ + Model.solve(remote=...) must bracket SOS reformulation around the remote + dispatch and suppress the to_netcdf warning that fires inside the helper. + """ + + @staticmethod + def _sos_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x * np.array([1.0, 2.0, 3.0]), sense="max") + return m + + def _fake_handler(self, observed: dict[str, object], tmp_path: Path) -> object: + """ + Non-OetcHandler stand-in with the SSH-shaped `solve_on_remote`. + + Records whether the model arrives in reformulated form, then runs + `model.to_netcdf(...)` and `read_netcdf(...)` (naturally — no + warning recording here, so we can observe at the call-site whether + Model.solve's suppression worked). + """ + from linopy.io import read_netcdf + + class _Handler: + def solve_on_remote(_self, model: Model, **kwargs: object) -> Model: + observed["state_active"] = model._sos_reformulation_state is not None + observed["solver_name_arg"] = kwargs.get("solver_name") + model.to_netcdf(tmp_path / "sent.nc") + solved = read_netcdf(tmp_path / "sent.nc") + for _name, var in solved.variables.items(): + arr = np.zeros(var.labels.shape, dtype=float) + var.solution = xr.DataArray(arr, dims=var.labels.dims) + solved.objective.set_value(0.0) + solved.status = "ok" + solved.termination_condition = "optimal" + return solved + + return _Handler() + + def test_remote_brackets_and_suppresses_warning(self, tmp_path: Path) -> None: + m = self._sos_model() + observed: dict[str, object] = {} + handler = self._fake_handler(observed, tmp_path) + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + m.solve(solver_name="highs", remote=handler, reformulate_sos=True) + + # Reformulation was active when the handler ran (apply happened + # before the remote dispatch). + assert observed["state_active"] is True + assert observed["solver_name_arg"] == "highs" + + # No "active SOS reformulation" warning escaped Model.solve. + assert not any("active SOS reformulation" in str(w.message) for w in captured) + + # Lifecycle wound down: state cleared, original SOS variable restored. + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in m.variables + + def test_remote_skips_bracket_when_reformulate_sos_false( + self, tmp_path: Path + ) -> None: + m = self._sos_model() + observed: dict[str, object] = {} + handler = self._fake_handler(observed, tmp_path) + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + m.solve(solver_name="highs", remote=handler, reformulate_sos=False) + + # No reformulation happened — model still has the original SOS var + # when the handler sees it, and to_netcdf never warns. + assert observed["state_active"] is False + assert not any("active SOS reformulation" in str(w.message) for w in captured) + assert m._sos_reformulation_state is None + + def test_remote_auto_requires_solver_name_with_sos(self, tmp_path: Path) -> None: + m = self._sos_model() + observed: dict[str, object] = {} + handler = self._fake_handler(observed, tmp_path) + + with pytest.raises(ValueError, match="requires an explicit `solver_name`"): + m.solve(remote=handler, reformulate_sos="auto") From 201b8726334c7aa710ed7f8fd474ada7d87f847a Mon Sep 17 00:00:00 2001 From: FBumann <117816358+FBumann@users.noreply.github.com> Date: Tue, 19 May 2026 10:54:16 +0200 Subject: [PATCH 6/9] test(sos): fix mypy errors on remote-bracket and resolve tests - Drop now-unused type: ignore on _resolve_sos_reformulation call where mypy correctly narrows (True, False, "auto") to bool | Literal["auto"]. - Type _fake_handler as RemoteHandler via cast so the three Model.solve(remote=handler, ...) calls satisfy the RemoteHandler | OetcHandler | None signature. Co-Authored-By: Claude Opus 4.7 (1M context) --- test/test_sos_reformulation.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index 0eb91c39..b6049c76 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -6,6 +6,7 @@ import warnings from collections.abc import Callable from pathlib import Path +from typing import cast import numpy as np import pandas as pd @@ -14,6 +15,7 @@ from linopy import Model, Variable, available_solvers from linopy.constants import SOS_TYPE_ATTR +from linopy.remote import RemoteHandler from linopy.sos_reformulation import ( compute_big_m_values, reformulate_sos1, @@ -1110,7 +1112,7 @@ def test_no_sos_short_circuits(self) -> None: m = Model() m.add_variables(name="x") for v in (True, False, "auto"): - assert m._resolve_sos_reformulation(None, v) is False # type: ignore[arg-type] + assert m._resolve_sos_reformulation(None, v) is False def test_true_does_not_consult_solver_name(self) -> None: # reformulate_sos=True must not require solver_name — no lookup. @@ -1137,7 +1139,9 @@ def _sos_model() -> Model: m.add_objective(x * np.array([1.0, 2.0, 3.0]), sense="max") return m - def _fake_handler(self, observed: dict[str, object], tmp_path: Path) -> object: + def _fake_handler( + self, observed: dict[str, object], tmp_path: Path + ) -> RemoteHandler: """ Non-OetcHandler stand-in with the SSH-shaped `solve_on_remote`. @@ -1162,7 +1166,7 @@ def solve_on_remote(_self, model: Model, **kwargs: object) -> Model: solved.termination_condition = "optimal" return solved - return _Handler() + return cast(RemoteHandler, _Handler()) def test_remote_brackets_and_suppresses_warning(self, tmp_path: Path) -> None: m = self._sos_model() From 79b6a077a4fb863efcc1a644646254ec675be9b1 Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 19 May 2026 11:13:29 +0200 Subject: [PATCH 7/9] refactor(sos): move reformulation lifecycle into remote handlers --- linopy/model.py | 100 +++++++++++---------------------- linopy/remote/oetc.py | 23 ++++++-- linopy/remote/ssh.py | 55 +++++++++++------- linopy/sos_reformulation.py | 43 +++++++++++++- test/test_oetc_settings.py | 2 +- test/test_sos_reformulation.py | 27 +++++++-- 6 files changed, 150 insertions(+), 100 deletions(-) diff --git a/linopy/model.py b/linopy/model.py index bba74867..250d65fe 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -93,6 +93,7 @@ from linopy.sos_reformulation import ( SOSReformulationResult, reformulate_sos_constraints, + sos_reformulation_context, undo_sos_reformulation, ) from linopy.types import ( @@ -1737,62 +1738,39 @@ def solve( "`m.add_objective(...)` first (e.g. `m.add_objective(0 * x)` " "for a pure feasibility problem)." ) - # Reformulate before the model is serialized to netcdf inside the - # remote helper, so the worker just solves a plain MILP and the - # SOS lifecycle stays on this (caller's) Model. - applied_sos_reformulation_here = self._resolve_sos_reformulation( - solver_name, reformulate_sos - ) - if applied_sos_reformulation_here: - logger.info( - f"Reformulating SOS constraints for remote solver {solver_name}" + if isinstance(remote, OetcHandler): + solved = remote.solve_on_oetc( + self, + solver_name=solver_name, + reformulate_sos=reformulate_sos, + **solver_options, + ) + else: + solved = remote.solve_on_remote( + self, + solver_name=solver_name, + io_api=io_api, + problem_fn=problem_fn, + solution_fn=solution_fn, + log_fn=log_fn, + basis_fn=basis_fn, + warmstart_fn=warmstart_fn, + keep_files=keep_files, + sanitize_zeros=sanitize_zeros, + reformulate_sos=reformulate_sos, + **solver_options, ) - self.apply_sos_reformulation() - try: - with warnings.catch_warnings(): - # The remote helpers call self.to_netcdf(...) internally; - # serializing the reformulated form here is intentional, so - # silence the UserWarning that to_netcdf emits in that case. - if applied_sos_reformulation_here: - warnings.filterwarnings( - "ignore", - message="Serializing a model with an active SOS " - "reformulation", - category=UserWarning, - ) - if isinstance(remote, OetcHandler): - solved = remote.solve_on_oetc( - self, solver_name=solver_name, **solver_options - ) - else: - solved = remote.solve_on_remote( - self, - solver_name=solver_name, - io_api=io_api, - problem_fn=problem_fn, - solution_fn=solution_fn, - log_fn=log_fn, - basis_fn=basis_fn, - warmstart_fn=warmstart_fn, - keep_files=keep_files, - sanitize_zeros=sanitize_zeros, - **solver_options, - ) - - if solved.objective.value is not None: - self.objective.set_value(float(solved.objective.value)) - self.status = solved.status - self.termination_condition = solved.termination_condition - for k, v in self.variables.items(): - v.solution = solved.variables[k].solution - for k, c in self.constraints.items(): - if "dual" in solved.constraints[k]: - c.dual = solved.constraints[k].dual - return self.status, self.termination_condition - finally: - if applied_sos_reformulation_here: - self.undo_sos_reformulation() + if solved.objective.value is not None: + self.objective.set_value(float(solved.objective.value)) + self.status = solved.status + self.termination_condition = solved.termination_condition + for k, v in self.variables.items(): + v.solution = solved.variables[k].solution + for k, c in self.constraints.items(): + if "dual" in solved.constraints[k]: + c.dual = solved.constraints[k].dual + return self.status, self.termination_condition if len(available_solvers) == 0: raise RuntimeError("No solver installed.") @@ -1829,16 +1807,7 @@ def solve( else: solution_fn = self.get_solution_file() - applied_sos_reformulation_here = self._resolve_sos_reformulation( - solver_name, reformulate_sos - ) - if applied_sos_reformulation_here: - logger.info(f"Reformulating SOS constraints for solver {solver_name}") - self.apply_sos_reformulation() - # If SOS is present and the solver doesn't support it (and the user - # didn't ask for reformulation), Solver._build() will raise. - - try: + with sos_reformulation_context(self, solver_name, reformulate_sos): if sanitize_zeros: self.constraints.sanitize_zeros() if sanitize_infinities: @@ -1885,9 +1854,6 @@ def solve( os.remove(fn) return self.assign_result(result) - finally: - if applied_sos_reformulation_here: - self.undo_sos_reformulation() def assign_result( self, diff --git a/linopy/remote/oetc.py b/linopy/remote/oetc.py index f451a43d..beb6245b 100644 --- a/linopy/remote/oetc.py +++ b/linopy/remote/oetc.py @@ -26,6 +26,10 @@ _oetc_deps_available = False import linopy +from linopy.sos_reformulation import ( + sos_reformulation_context, + suppress_serialization_warning, +) logger = logging.getLogger(__name__) @@ -631,7 +635,12 @@ def _download_file_from_gcp(self, file_name: str) -> str: raise Exception(f"Failed to download file from GCP: {e}") def solve_on_oetc( - self, model: Model, solver_name: str | None = None, **solver_options: Any + self, + model: Model, + solver_name: str | None = None, + *, + reformulate_sos: bool | str = False, + **solver_options: Any, ) -> Model: """ Solve a linopy model on the OET Cloud compute app. @@ -657,10 +666,14 @@ def solve_on_oetc( effective_solver = solver_name or self.settings.solver merged_solver_options = {**self.settings.solver_options, **solver_options} - with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: - fn.file.close() - model.to_netcdf(fn.name) - input_file_name = self._upload_file_to_gcp(fn.name) + with sos_reformulation_context( + model, effective_solver, reformulate_sos + ) as applied: + with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: + fn.file.close() + with suppress_serialization_warning(active=applied): + model.to_netcdf(fn.name) + input_file_name = self._upload_file_to_gcp(fn.name) job_uuid = self._submit_job_to_compute_service( input_file_name, effective_solver, merged_solver_options diff --git a/linopy/remote/ssh.py b/linopy/remote/ssh.py index 426ed646..ba893412 100644 --- a/linopy/remote/ssh.py +++ b/linopy/remote/ssh.py @@ -12,6 +12,10 @@ from typing import TYPE_CHECKING, Any, Union from linopy.io import read_netcdf +from linopy.sos_reformulation import ( + sos_reformulation_context, + suppress_serialization_warning, +) if TYPE_CHECKING: from linopy.model import Model @@ -200,43 +204,52 @@ def execute(self, cmd: str) -> None: if exit_status: raise OSError("Execution on remote raised an error, see above.") - def solve_on_remote(self, model: "Model", **kwargs: Any) -> "Model": + def solve_on_remote( + self, + model: "Model", + *, + reformulate_sos: bool | str = False, + **kwargs: Any, + ) -> "Model": """ Solve a linopy model on the remote machine. - This function - - 1. saves the model to a file on the local machine. - 2. copies that file to the remote machine. - 3. loads, solves and writes out the model, all on the remote machine. - 4. copies the solved model to the local machine. - 5. loads and returns the solved model. + Reformulates SOS constraints locally before serialization when + requested, so the worker just solves a plain MILP and the SOS + lifecycle stays on the caller's model. Parameters ---------- model : linopy.model.Model + reformulate_sos : bool | "auto", optional + Forwarded to ``Model._resolve_sos_reformulation`` to decide + whether to apply SOS reformulation locally before transfer. **kwargs : - Keyword arguments passed to `linopy.model.Model.solve`. + Keyword arguments passed to `linopy.model.Model.solve` on the + remote worker. Returns ------- linopy.model.Model Solved model. """ - self.write_python_file_on_remote(**kwargs) - self.write_model_on_remote(model) + solver_name = kwargs.get("solver_name") + with sos_reformulation_context(model, solver_name, reformulate_sos) as applied: + self.write_python_file_on_remote(**kwargs) + with suppress_serialization_warning(active=applied): + self.write_model_on_remote(model) - command = f"{self.python_executable} {self.python_file}" + command = f"{self.python_executable} {self.python_file}" - logger.info("Solving model on remote.") - self.execute(command) + logger.info("Solving model on remote.") + self.execute(command) - logger.info("Retrieve solved model from remote.") - with tempfile.NamedTemporaryFile(prefix="linopy", suffix=".nc") as fn: - self.sftp_client.get(self.model_solved_file, fn.name) - solved = read_netcdf(fn.name) + logger.info("Retrieve solved model from remote.") + with tempfile.NamedTemporaryFile(prefix="linopy", suffix=".nc") as fn: + self.sftp_client.get(self.model_solved_file, fn.name) + solved = read_netcdf(fn.name) - self.sftp_client.remove(self.python_file) - self.sftp_client.remove(self.model_solved_file) + self.sftp_client.remove(self.python_file) + self.sftp_client.remove(self.model_solved_file) - return solved + return solved diff --git a/linopy/sos_reformulation.py b/linopy/sos_reformulation.py index 0c677216..1f17ee92 100644 --- a/linopy/sos_reformulation.py +++ b/linopy/sos_reformulation.py @@ -8,8 +8,11 @@ from __future__ import annotations import logging +import warnings +from collections.abc import Iterator +from contextlib import contextmanager from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import numpy as np import pandas as pd @@ -328,3 +331,41 @@ def undo_sos_reformulation(model: Model, result: SOSReformulationResult) -> None model.variables[var_name].attrs.update(attrs) model.objective._value = objective_value + + +@contextmanager +def sos_reformulation_context( + model: Model, + solver_name: str | None, + reformulate_sos: bool | Literal["auto"], +) -> Iterator[bool]: + """ + Apply SOS reformulation for the duration of the block, then undo. + + Yields whether the reformulation was actually applied, so callers can + branch on it (e.g. to scope a warning suppression). + """ + applied = model._resolve_sos_reformulation(solver_name, reformulate_sos) + if applied: + logger.info(f"Reformulating SOS constraints for solver {solver_name}") + model.apply_sos_reformulation() + try: + yield applied + finally: + if applied: + model.undo_sos_reformulation() + + +@contextmanager +def suppress_serialization_warning(active: bool) -> Iterator[None]: + """Silence the SOS-active-on-serialize UserWarning when ``active`` is True.""" + if not active: + yield + return + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Serializing a model with an active SOS reformulation", + category=UserWarning, + ) + yield diff --git a/test/test_oetc_settings.py b/test/test_oetc_settings.py index 0a9cac7c..12deeb66 100644 --- a/test/test_oetc_settings.py +++ b/test/test_oetc_settings.py @@ -317,5 +317,5 @@ def test_model_solve_forwards_to_oetc() -> None: m.solve(solver_name="gurobi", remote=handler, TimeLimit=100) handler.solve_on_oetc.assert_called_once_with( - m, solver_name="gurobi", TimeLimit=100 + m, solver_name="gurobi", reformulate_sos=False, TimeLimit=100 ) diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index b6049c76..e8ed8c7b 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -1151,13 +1151,30 @@ def _fake_handler( Model.solve's suppression worked). """ from linopy.io import read_netcdf + from linopy.sos_reformulation import ( + sos_reformulation_context, + suppress_serialization_warning, + ) class _Handler: - def solve_on_remote(_self, model: Model, **kwargs: object) -> Model: - observed["state_active"] = model._sos_reformulation_state is not None - observed["solver_name_arg"] = kwargs.get("solver_name") - model.to_netcdf(tmp_path / "sent.nc") - solved = read_netcdf(tmp_path / "sent.nc") + def solve_on_remote( + _self, + model: Model, + *, + reformulate_sos: bool | str = False, + **kwargs: object, + ) -> Model: + solver_name = kwargs.get("solver_name") + with sos_reformulation_context( + model, solver_name, reformulate_sos + ) as applied: + observed["state_active"] = ( + model._sos_reformulation_state is not None + ) + observed["solver_name_arg"] = solver_name + with suppress_serialization_warning(active=applied): + model.to_netcdf(tmp_path / "sent.nc") + solved = read_netcdf(tmp_path / "sent.nc") for _name, var in solved.variables.items(): arr = np.zeros(var.labels.shape, dtype=float) var.solution = xr.DataArray(arr, dims=var.labels.dims) From 67c939e7c654f63e39e1490994e9c570d92538d2 Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 19 May 2026 11:30:12 +0200 Subject: [PATCH 8/9] fix(types): tighten reformulate_sos to bool | Literal["auto"] --- linopy/remote/oetc.py | 4 ++-- linopy/remote/ssh.py | 4 ++-- test/test_sos_reformulation.py | 5 +++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/linopy/remote/oetc.py b/linopy/remote/oetc.py index beb6245b..beef5873 100644 --- a/linopy/remote/oetc.py +++ b/linopy/remote/oetc.py @@ -10,7 +10,7 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal if TYPE_CHECKING: from linopy.model import Model @@ -639,7 +639,7 @@ def solve_on_oetc( model: Model, solver_name: str | None = None, *, - reformulate_sos: bool | str = False, + reformulate_sos: bool | Literal["auto"] = False, **solver_options: Any, ) -> Model: """ diff --git a/linopy/remote/ssh.py b/linopy/remote/ssh.py index ba893412..ea8fd19e 100644 --- a/linopy/remote/ssh.py +++ b/linopy/remote/ssh.py @@ -9,7 +9,7 @@ import tempfile from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Literal, Union from linopy.io import read_netcdf from linopy.sos_reformulation import ( @@ -208,7 +208,7 @@ def solve_on_remote( self, model: "Model", *, - reformulate_sos: bool | str = False, + reformulate_sos: bool | Literal["auto"] = False, **kwargs: Any, ) -> "Model": """ diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index e8ed8c7b..51ec1770 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -6,7 +6,7 @@ import warnings from collections.abc import Callable from pathlib import Path -from typing import cast +from typing import Literal, cast import numpy as np import pandas as pd @@ -1161,10 +1161,11 @@ def solve_on_remote( _self, model: Model, *, - reformulate_sos: bool | str = False, + reformulate_sos: bool | Literal["auto"] = False, **kwargs: object, ) -> Model: solver_name = kwargs.get("solver_name") + assert solver_name is None or isinstance(solver_name, str) with sos_reformulation_context( model, solver_name, reformulate_sos ) as applied: From 4497392f4c9833bedc6d9a0f93f28e6a0789af8e Mon Sep 17 00:00:00 2001 From: Fabian Date: Tue, 19 May 2026 11:52:36 +0200 Subject: [PATCH 9/9] test(ssh): cover SOS bracket in RemoteHandler.solve_on_remote --- test/remote/test_ssh.py | 157 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) create mode 100644 test/remote/test_ssh.py diff --git a/test/remote/test_ssh.py b/test/remote/test_ssh.py new file mode 100644 index 00000000..c6960c84 --- /dev/null +++ b/test/remote/test_ssh.py @@ -0,0 +1,157 @@ +"""Tests for ``linopy.remote.ssh.RemoteHandler.solve_on_remote``.""" + +from __future__ import annotations + +import warnings +from collections.abc import Iterator +from contextlib import contextmanager +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import numpy as np +import pandas as pd +import pytest + +pytest.importorskip("paramiko") + +from linopy import Model # noqa: E402 +from linopy.remote.ssh import RemoteHandler # noqa: E402 + + +class _FakeSFTPClient: + """In-memory SFTP stand-in: ``put`` / ``get`` round-trip file bytes.""" + + def __init__(self) -> None: + self.store: dict[str, bytes] = {} + + def open(self, path: str, mode: str) -> Any: + store = self.store + + @contextmanager + def _writer() -> Iterator[Any]: + class _Writer: + def write(self_inner, data: str | bytes) -> None: + store[path] = data.encode() if isinstance(data, str) else data + + yield _Writer() + + return _writer() + + def put(self, local_path: str, remote_path: str) -> None: + with open(local_path, "rb") as fh: + self.store[remote_path] = fh.read() + + def get(self, remote_path: str, local_path: str) -> None: + with open(local_path, "wb") as fh: + fh.write(self.store[remote_path]) + + def remove(self, path: str) -> None: + self.store.pop(path, None) + + +def _make_sos_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x * np.array([1.0, 2.0, 3.0]), sense="max") + return m + + +@pytest.fixture +def handler() -> RemoteHandler: + """``RemoteHandler`` wired to an in-memory SFTP and a no-op shell.""" + client = MagicMock() + client.invoke_shell.return_value.makefile.return_value = MagicMock() + sftp = _FakeSFTPClient() + client.open_sftp.return_value = sftp + + h = RemoteHandler(hostname="fake", client=client) + # The unsolved model gets put() into sftp.store under model_unsolved_file; + # serve it back as the "solved" model so read_netcdf has something valid. + h.sftp_client = sftp # type: ignore[assignment] + h.execute = MagicMock() # type: ignore[method-assign] + + original_put = sftp.put + + def put_and_mirror(local_path: str, remote_path: str) -> None: + original_put(local_path, remote_path) + if remote_path == h.model_unsolved_file: + sftp.store[h.model_solved_file] = sftp.store[remote_path] + + sftp.put = put_and_mirror # type: ignore[method-assign] + return h + + +class TestSolveOnRemoteSosBracket: + """``solve_on_remote`` must bracket SOS reformulation around transfer.""" + + def test_reformulates_before_transfer_and_restores_after( + self, handler: RemoteHandler + ) -> None: + m = _make_sos_model() + + observed: dict[str, bool] = {} + real_write = handler.write_model_on_remote + + def spy_write(model: Model) -> None: + observed["state_active"] = model._sos_reformulation_state is not None + observed["has_aux_var"] = "_sos_reform_x_y" in model.variables + real_write(model) + + handler.write_model_on_remote = spy_write # type: ignore[method-assign] + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + handler.solve_on_remote(m, reformulate_sos=True, solver_name="highs") + + assert observed["state_active"] is True + assert observed["has_aux_var"] is True + assert not any("active SOS reformulation" in str(w.message) for w in captured) + assert m._sos_reformulation_state is None + assert "_sos_reform_x_y" not in m.variables + assert list(m.variables.sos) == ["x"] + + def test_skips_bracket_when_reformulate_sos_false( + self, handler: RemoteHandler + ) -> None: + m = _make_sos_model() + + observed: dict[str, bool] = {} + real_write = handler.write_model_on_remote + + def spy_write(model: Model) -> None: + observed["state_active"] = model._sos_reformulation_state is not None + real_write(model) + + handler.write_model_on_remote = spy_write # type: ignore[method-assign] + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + handler.solve_on_remote(m, reformulate_sos=False) + + assert observed["state_active"] is False + assert not any("active SOS reformulation" in str(w.message) for w in captured) + assert m._sos_reformulation_state is None + + def test_auto_without_solver_name_raises_on_sos_model( + self, handler: RemoteHandler + ) -> None: + m = _make_sos_model() + with pytest.raises(ValueError, match="requires an explicit `solver_name`"): + handler.solve_on_remote(m, reformulate_sos="auto") + + def test_no_sos_model_passes_through_unchanged( + self, handler: RemoteHandler, tmp_path: Path + ) -> None: + m = Model() + x = m.add_variables(lower=0, upper=1, name="x") + m.add_objective(1.0 * x, sense="max") + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + handler.solve_on_remote(m, reformulate_sos="auto") + + assert m._sos_reformulation_state is None + assert not any("active SOS reformulation" in str(w.message) for w in captured)