diff --git a/linopy/io.py b/linopy/io.py index 4dc4dc02..b0abe9fb 100644 --- a/linopy/io.py +++ b/linopy/io.py @@ -5,10 +5,12 @@ from __future__ import annotations +import copy as _copy import json import logging import shutil import time +import warnings from collections.abc import Callable, Iterable from io import BufferedWriter from pathlib import Path @@ -845,7 +847,29 @@ def to_netcdf(m: Model, *args: Any, **kwargs: Any) -> None: Arguments passed to ``xarray.Dataset.to_netcdf``. **kwargs : TYPE Keyword arguments passed to ``xarray.Dataset.to_netcdf``. + + Notes + ----- + The SOS reformulation lifecycle token lives only on the in-memory + Model and is not persisted. If the model has an active SOS + reformulation at serialization time, the netcdf contains the + reformulated MILP form (aux binaries and cardinality constraints) + and a :class:`UserWarning` is emitted to flag that the deserialized + copy will not be able to undo the reformulation. + + ``Model.solve(remote=...)`` invokes ``to_netcdf`` internally on the + reformulated model and suppresses this warning. """ + if m._sos_reformulation_state is not None: + warnings.warn( + "Serializing a model with an active SOS reformulation. The " + "netcdf will contain the reformulated MILP form; the " + "reformulation lifecycle token is not persisted, so a " + "reader cannot undo it. Call `model.undo_sos_reformulation()` " + "first if you want the original SOS form on disk.", + UserWarning, + stacklevel=2, + ) def with_prefix(ds: xr.Dataset, prefix: str) -> xr.Dataset: to_rename = set([*ds.dims, *ds.coords, *ds]) @@ -916,6 +940,13 @@ def read_netcdf(path: Path | str, **kwargs: Any) -> Model: Returns ------- m : linopy.Model + + Notes + ----- + The SOS reformulation lifecycle token is not persisted by + :func:`to_netcdf`. If the saved model was in reformulated form, + the deserialized Model is too, but + :meth:`Model.undo_sos_reformulation` is a no-op on it. """ from linopy.constraints import ( Constraint, @@ -1117,6 +1148,9 @@ def copy(m: Model, include_solution: bool = False, deep: bool = True) -> Model: if include_solution or attr not in SOLVE_STATE_ATTRS: setattr(new_model, attr, getattr(m, attr)) + if m._sos_reformulation_state is not None: + new_model._sos_reformulation_state = _copy.deepcopy(m._sos_reformulation_state) + return new_model diff --git a/linopy/model.py b/linopy/model.py index 03fd9479..250d65fe 100644 --- a/linopy/model.py +++ b/linopy/model.py @@ -84,13 +84,16 @@ from linopy.remote import OetcHandler except ImportError: OetcHandler = None # type: ignore +from linopy.solver_capabilities import solver_supports from linopy.solvers import ( IO_APIS, SolverFeature, available_solvers, ) from linopy.sos_reformulation import ( + SOSReformulationResult, reformulate_sos_constraints, + sos_reformulation_context, undo_sos_reformulation, ) from linopy.types import ( @@ -240,6 +243,7 @@ class Model: "_relaxed_registry", "_piecewise_formulations", "_solver", + "_sos_reformulation_state", "__weakref__", ) @@ -310,6 +314,7 @@ def __init__( gettempdir() if solver_dir is None else solver_dir ) self._solver: solvers.Solver | None = None + self._sos_reformulation_state: SOSReformulationResult | None = None @property def solver(self) -> solvers.Solver | None: @@ -1221,6 +1226,80 @@ def remove_sos_constraints(self, variable: Variable) -> None: reformulate_sos_constraints = reformulate_sos_constraints + def apply_sos_reformulation(self) -> None: + """ + Reformulate SOS constraints into binary + linear form, in place. + + The reformulation token is stored on the model so it can be reverted + with :meth:`undo_sos_reformulation`. This is the stateful counterpart + to :func:`linopy.sos_reformulation.reformulate_sos_constraints`, where + the caller owns the token. + + Raises + ------ + RuntimeError + If a reformulation has already been applied and not undone. + """ + if self._sos_reformulation_state is not None: + raise RuntimeError( + "SOS reformulation has already been applied to this model. " + "Call `undo_sos_reformulation()` before applying again." + ) + self._sos_reformulation_state = reformulate_sos_constraints(self) + + def undo_sos_reformulation(self) -> None: + """ + Revert a previously applied SOS reformulation. + + Raises + ------ + RuntimeError + If no reformulation is currently applied. + """ + if self._sos_reformulation_state is None: + raise RuntimeError( + "No SOS reformulation is currently applied to this model." + ) + state = self._sos_reformulation_state + self._sos_reformulation_state = None + undo_sos_reformulation(self, state) + + def _resolve_sos_reformulation( + self, + solver_name: str | None, + reformulate_sos: bool | Literal["auto"], + ) -> bool: + """ + Decide whether ``apply_sos_reformulation`` should run. + + Validates ``reformulate_sos`` and returns ``True`` iff the SOS + constraints on this model should be reformulated for the chosen + solver. ``solver_name`` is only consulted when + ``reformulate_sos == "auto"`` (to look up SOS support); for + ``True`` / ``False`` the decision is independent of the solver. + """ + if reformulate_sos not in (True, False, "auto"): + raise ValueError( + f"Invalid value for reformulate_sos: {reformulate_sos!r}. " + "Must be True, False, or 'auto'." + ) + if not self.variables.sos: + return False + + if reformulate_sos is False: + return False + elif reformulate_sos is True: + return True + elif solver_name is None: + raise ValueError( + "`reformulate_sos='auto'` on a model with SOS constraints " + "requires an explicit `solver_name` so we can check " + "whether the chosen solver supports SOS. Pass " + "`solver_name=...` or use `reformulate_sos=True`/`False` " + "to skip the lookup." + ) + return not solver_supports(solver_name, SolverFeature.SOS_CONSTRAINTS) + def _check_sos_unmasked(self) -> None: """ Reject the model if any SOS variable has masked entries. @@ -1642,12 +1721,6 @@ def solve( sanitize_zeros=sanitize_zeros, sanitize_infinities=sanitize_infinities ) - if self.objective.expression.empty: - raise ValueError( - "No objective has been set on the model. Use `m.add_objective(...)` " - "first (e.g. `m.add_objective(0 * x)` for a pure feasibility problem)." - ) - # check io_api if io_api is not None and io_api not in IO_APIS: raise ValueError( @@ -1655,9 +1728,22 @@ def solve( ) if remote is not None: + # The remote branch short-circuits before reaching Solver.solve(), + # which is where the empty-objective check normally fires. Replicate + # it here. This duplication becomes obsolete once OETC is folded + # into the Solver pipeline (see PyPSA/linopy#683). + if self.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use " + "`m.add_objective(...)` first (e.g. `m.add_objective(0 * x)` " + "for a pure feasibility problem)." + ) if isinstance(remote, OetcHandler): solved = remote.solve_on_oetc( - self, solver_name=solver_name, **solver_options + self, + solver_name=solver_name, + reformulate_sos=reformulate_sos, + **solver_options, ) else: solved = remote.solve_on_remote( @@ -1671,6 +1757,7 @@ def solve( warmstart_fn=warmstart_fn, keep_files=keep_files, sanitize_zeros=sanitize_zeros, + reformulate_sos=reformulate_sos, **solver_options, ) @@ -1720,95 +1807,82 @@ def solve( else: solution_fn = self.get_solution_file() - if sanitize_zeros: - self.constraints.sanitize_zeros() - - if sanitize_infinities: - self.constraints.sanitize_infinities() - - if self.is_quadratic and not solver_class.supports( - SolverFeature.QUADRATIC_OBJECTIVE - ): - raise ValueError( - f"Solver {solver_name} does not support quadratic problems." - ) - - if reformulate_sos not in (True, False, "auto"): - raise ValueError( - f"Invalid value for reformulate_sos: {reformulate_sos!r}. " - "Must be True, False, or 'auto'." - ) - - sos_reform_result = None - if self.variables.sos: - supports_sos = solver_class.supports(SolverFeature.SOS_CONSTRAINTS) - should_reformulate = reformulate_sos is True or ( - reformulate_sos == "auto" and not supports_sos - ) + with sos_reformulation_context(self, solver_name, reformulate_sos): + if sanitize_zeros: + self.constraints.sanitize_zeros() + if sanitize_infinities: + self.constraints.sanitize_infinities() - if should_reformulate: - logger.info(f"Reformulating SOS constraints for solver {solver_name}") - sos_reform_result = reformulate_sos_constraints(self) - elif reformulate_sos is False and not supports_sos: - raise ValueError( - f"Solver {solver_name} does not support SOS constraints. " - "Use reformulate_sos=True or 'auto', or a solver that supports SOS." + try: + self.solver = None # closes any previous solver + if io_api == "direct": + if set_names is None: + set_names = self.set_names_in_solver_io + build_kwargs: dict[str, Any] = { + "explicit_coordinate_names": explicit_coordinate_names, + "set_names": set_names, + "log_fn": to_path(log_fn), + } + if env is not None: + build_kwargs["env"] = env + else: + build_kwargs = { + "explicit_coordinate_names": explicit_coordinate_names, + "slice_size": slice_size, + "progress": progress, + "problem_fn": to_path(problem_fn), + } + self.solver = solver = solvers.Solver.from_name( + solver_name, + model=self, + io_api=io_api, + options=solver_options, + **build_kwargs, ) - - if self.variables.semi_continuous: - if not solver_class.supports(SolverFeature.SEMI_CONTINUOUS_VARIABLES): - raise ValueError( - f"Solver {solver_name} does not support semi-continuous variables. " - "Use a solver that supports them (gurobi, cplex, highs)." + if io_api != "direct": + problem_fn = solver._problem_fn + result = solver.solve( + solution_fn=to_path(solution_fn), + log_fn=to_path(log_fn), + warmstart_fn=to_path(warmstart_fn), + basis_fn=to_path(basis_fn), + env=env, ) + finally: + for fn in (problem_fn, solution_fn): + if fn is not None and (os.path.exists(fn) and not keep_files): + os.remove(fn) - try: - self.solver = None # closes any previous solver - if io_api == "direct": - if set_names is None: - set_names = self.set_names_in_solver_io - build_kwargs: dict[str, Any] = { - "explicit_coordinate_names": explicit_coordinate_names, - "set_names": set_names, - "log_fn": to_path(log_fn), - } - if env is not None: - build_kwargs["env"] = env - else: - build_kwargs = { - "explicit_coordinate_names": explicit_coordinate_names, - "slice_size": slice_size, - "progress": progress, - "problem_fn": to_path(problem_fn), - } - self.solver = solver = solvers.Solver.from_name( - solver_name, - model=self, - io_api=io_api, - options=solver_options, - **build_kwargs, - ) - if io_api != "direct": - problem_fn = solver._problem_fn - result = solver.solve( - solution_fn=to_path(solution_fn), - log_fn=to_path(log_fn), - warmstart_fn=to_path(warmstart_fn), - basis_fn=to_path(basis_fn), - env=env, - ) - finally: - for fn in (problem_fn, solution_fn): - if fn is not None and (os.path.exists(fn) and not keep_files): - os.remove(fn) - - try: return self.assign_result(result) - finally: - if sos_reform_result is not None: - undo_sos_reformulation(self, sos_reform_result) - def assign_result(self, result: Result) -> tuple[str, str]: + def assign_result( + self, + result: Result, + solver: solvers.Solver | None = None, + ) -> tuple[str, str]: + """ + Write a solver Result back onto the model. + + Copies primal / dual values onto variables / constraints, sets + :attr:`status`, :attr:`termination_condition`, and + :attr:`objective.value`. When ``solver`` is provided, also stores it on + ``self.solver`` so post-solve introspection (``model.solver_model``, + ``compute_infeasibilities()``) works. + + Parameters + ---------- + result : Result + The :class:`linopy.constants.Result` returned by + :meth:`linopy.solvers.Solver.solve`. + solver : Solver, optional + The solver instance that produced the result. Pass it on the + low-level ``Solver.from_name(...).solve()`` path to attach it as + ``self.solver`` for post-solve introspection. ``Model.solve()`` + attaches the solver itself and does not pass this argument. + """ + if solver is not None: + self.solver = solver + result.info() if result.solution is not None: diff --git a/linopy/remote/oetc.py b/linopy/remote/oetc.py index f451a43d..beef5873 100644 --- a/linopy/remote/oetc.py +++ b/linopy/remote/oetc.py @@ -10,7 +10,7 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal if TYPE_CHECKING: from linopy.model import Model @@ -26,6 +26,10 @@ _oetc_deps_available = False import linopy +from linopy.sos_reformulation import ( + sos_reformulation_context, + suppress_serialization_warning, +) logger = logging.getLogger(__name__) @@ -631,7 +635,12 @@ def _download_file_from_gcp(self, file_name: str) -> str: raise Exception(f"Failed to download file from GCP: {e}") def solve_on_oetc( - self, model: Model, solver_name: str | None = None, **solver_options: Any + self, + model: Model, + solver_name: str | None = None, + *, + reformulate_sos: bool | Literal["auto"] = False, + **solver_options: Any, ) -> Model: """ Solve a linopy model on the OET Cloud compute app. @@ -657,10 +666,14 @@ def solve_on_oetc( effective_solver = solver_name or self.settings.solver merged_solver_options = {**self.settings.solver_options, **solver_options} - with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: - fn.file.close() - model.to_netcdf(fn.name) - input_file_name = self._upload_file_to_gcp(fn.name) + with sos_reformulation_context( + model, effective_solver, reformulate_sos + ) as applied: + with tempfile.NamedTemporaryFile(prefix="linopy-", suffix=".nc") as fn: + fn.file.close() + with suppress_serialization_warning(active=applied): + model.to_netcdf(fn.name) + input_file_name = self._upload_file_to_gcp(fn.name) job_uuid = self._submit_job_to_compute_service( input_file_name, effective_solver, merged_solver_options diff --git a/linopy/remote/ssh.py b/linopy/remote/ssh.py index 426ed646..ea8fd19e 100644 --- a/linopy/remote/ssh.py +++ b/linopy/remote/ssh.py @@ -9,9 +9,13 @@ import tempfile from collections.abc import Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Literal, Union from linopy.io import read_netcdf +from linopy.sos_reformulation import ( + sos_reformulation_context, + suppress_serialization_warning, +) if TYPE_CHECKING: from linopy.model import Model @@ -200,43 +204,52 @@ def execute(self, cmd: str) -> None: if exit_status: raise OSError("Execution on remote raised an error, see above.") - def solve_on_remote(self, model: "Model", **kwargs: Any) -> "Model": + def solve_on_remote( + self, + model: "Model", + *, + reformulate_sos: bool | Literal["auto"] = False, + **kwargs: Any, + ) -> "Model": """ Solve a linopy model on the remote machine. - This function - - 1. saves the model to a file on the local machine. - 2. copies that file to the remote machine. - 3. loads, solves and writes out the model, all on the remote machine. - 4. copies the solved model to the local machine. - 5. loads and returns the solved model. + Reformulates SOS constraints locally before serialization when + requested, so the worker just solves a plain MILP and the SOS + lifecycle stays on the caller's model. Parameters ---------- model : linopy.model.Model + reformulate_sos : bool | "auto", optional + Forwarded to ``Model._resolve_sos_reformulation`` to decide + whether to apply SOS reformulation locally before transfer. **kwargs : - Keyword arguments passed to `linopy.model.Model.solve`. + Keyword arguments passed to `linopy.model.Model.solve` on the + remote worker. Returns ------- linopy.model.Model Solved model. """ - self.write_python_file_on_remote(**kwargs) - self.write_model_on_remote(model) + solver_name = kwargs.get("solver_name") + with sos_reformulation_context(model, solver_name, reformulate_sos) as applied: + self.write_python_file_on_remote(**kwargs) + with suppress_serialization_warning(active=applied): + self.write_model_on_remote(model) - command = f"{self.python_executable} {self.python_file}" + command = f"{self.python_executable} {self.python_file}" - logger.info("Solving model on remote.") - self.execute(command) + logger.info("Solving model on remote.") + self.execute(command) - logger.info("Retrieve solved model from remote.") - with tempfile.NamedTemporaryFile(prefix="linopy", suffix=".nc") as fn: - self.sftp_client.get(self.model_solved_file, fn.name) - solved = read_netcdf(fn.name) + logger.info("Retrieve solved model from remote.") + with tempfile.NamedTemporaryFile(prefix="linopy", suffix=".nc") as fn: + self.sftp_client.get(self.model_solved_file, fn.name) + solved = read_netcdf(fn.name) - self.sftp_client.remove(self.python_file) - self.sftp_client.remove(self.model_solved_file) + self.sftp_client.remove(self.python_file) + self.sftp_client.remove(self.model_solved_file) - return solved + return solved diff --git a/linopy/solvers.py b/linopy/solvers.py index 364e8ced..44db983f 100644 --- a/linopy/solvers.py +++ b/linopy/solvers.py @@ -505,15 +505,52 @@ def from_model( return instance def _build(self, **build_kwargs: Any) -> None: - """Dispatch to direct or file build based on ``io_api``.""" + """ + Dispatch to direct or file build based on ``io_api``. + + The Solver never mutates ``self.model``. Constraint sanitization + (``model.constraints.sanitize_zeros()`` / + ``.sanitize_infinities()``) and SOS reformulation + (``model.apply_sos_reformulation()``) are Model-level operations + the caller applies first; this builder consumes whatever shape it + is handed. + """ if self.model is None: raise RuntimeError("Solver has no model attached; cannot build.") + self._validate_model() self.model._check_sos_unmasked() if self.io_api == "direct": self._build_direct(**build_kwargs) else: self._build_file(**build_kwargs) + def _validate_model(self) -> None: + """Pre-build checks on whether this solver can handle ``self.model``.""" + model = self.model + assert model is not None + solver_name = self.solver_name.value + cls = type(self) + + if model.is_quadratic and not cls.supports(SolverFeature.QUADRATIC_OBJECTIVE): + raise ValueError( + f"Solver {solver_name} does not support quadratic problems." + ) + + if model.variables.semi_continuous and not cls.supports( + SolverFeature.SEMI_CONTINUOUS_VARIABLES + ): + raise ValueError( + f"Solver {solver_name} does not support semi-continuous variables. " + "Use a solver that supports them (gurobi, cplex, highs)." + ) + + if model.variables.sos and not cls.supports(SolverFeature.SOS_CONSTRAINTS): + raise ValueError( + f"Solver {solver_name} does not support SOS constraints. " + "Reformulate first via `Model.solve(reformulate_sos=True)` or " + "`model.apply_sos_reformulation()`, or use a solver that supports SOS." + ) + def _build_direct(self, **build_kwargs: Any) -> None: """Build the native solver model from ``self.model``. Override per-solver.""" raise NotImplementedError( @@ -554,7 +591,30 @@ def _build_file(self, **build_kwargs: Any) -> None: self._cache_model_sizes(model) def solve(self, **run_kwargs: Any) -> Result: - """Run the prepared solver and return a :class:`Result`.""" + """ + Run the prepared solver and return a :class:`Result`. + + The canonical low-level pattern is:: + + solver = Solver.from_name("gurobi", model, io_api="direct") + result = solver.solve() + model.assign_result(result, solver=solver) + + Passing ``solver=`` to :meth:`Model.assign_result` wires + ``model.solver`` so post-solve helpers like + :meth:`Model.compute_infeasibilities` keep working. + + Raises + ------ + ValueError + If the attached model has no objective set. Submit-time check + shared by both ``Model.solve()`` and direct-Solver callers. + """ + if self.model is not None and self.model.objective.expression.empty: + raise ValueError( + "No objective has been set on the model. Use `m.add_objective(...)` " + "first (e.g. `m.add_objective(0 * x)` for a pure feasibility problem)." + ) if self.io_api == "direct" or self.solver_model is not None: return self._run_direct(**run_kwargs) if self._problem_fn is not None: diff --git a/linopy/sos_reformulation.py b/linopy/sos_reformulation.py index 8ccb7613..1f17ee92 100644 --- a/linopy/sos_reformulation.py +++ b/linopy/sos_reformulation.py @@ -8,8 +8,11 @@ from __future__ import annotations import logging +import warnings +from collections.abc import Iterator +from contextlib import contextmanager from dataclasses import dataclass, field -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import numpy as np import pandas as pd @@ -233,8 +236,10 @@ def reformulate_sos_constraints( 1. If custom big_m was specified in add_sos_constraints(), use that 2. Otherwise, use the variable bounds (tightest valid Big-M) - Note: This permanently mutates the model. To solve with automatic - undo, use ``model.solve(reformulate_sos=True)`` instead. + Note: This permanently mutates the model and returns a token the caller + owns. For a stateful, reversible API use ``model.apply_sos_reformulation()`` + / ``model.undo_sos_reformulation()``; for automatic undo around a single + solve use ``model.solve(reformulate_sos=True)``. Parameters ---------- @@ -326,3 +331,41 @@ def undo_sos_reformulation(model: Model, result: SOSReformulationResult) -> None model.variables[var_name].attrs.update(attrs) model.objective._value = objective_value + + +@contextmanager +def sos_reformulation_context( + model: Model, + solver_name: str | None, + reformulate_sos: bool | Literal["auto"], +) -> Iterator[bool]: + """ + Apply SOS reformulation for the duration of the block, then undo. + + Yields whether the reformulation was actually applied, so callers can + branch on it (e.g. to scope a warning suppression). + """ + applied = model._resolve_sos_reformulation(solver_name, reformulate_sos) + if applied: + logger.info(f"Reformulating SOS constraints for solver {solver_name}") + model.apply_sos_reformulation() + try: + yield applied + finally: + if applied: + model.undo_sos_reformulation() + + +@contextmanager +def suppress_serialization_warning(active: bool) -> Iterator[None]: + """Silence the SOS-active-on-serialize UserWarning when ``active`` is True.""" + if not active: + yield + return + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="Serializing a model with an active SOS reformulation", + category=UserWarning, + ) + yield diff --git a/test/remote/test_ssh.py b/test/remote/test_ssh.py new file mode 100644 index 00000000..c6960c84 --- /dev/null +++ b/test/remote/test_ssh.py @@ -0,0 +1,157 @@ +"""Tests for ``linopy.remote.ssh.RemoteHandler.solve_on_remote``.""" + +from __future__ import annotations + +import warnings +from collections.abc import Iterator +from contextlib import contextmanager +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import numpy as np +import pandas as pd +import pytest + +pytest.importorskip("paramiko") + +from linopy import Model # noqa: E402 +from linopy.remote.ssh import RemoteHandler # noqa: E402 + + +class _FakeSFTPClient: + """In-memory SFTP stand-in: ``put`` / ``get`` round-trip file bytes.""" + + def __init__(self) -> None: + self.store: dict[str, bytes] = {} + + def open(self, path: str, mode: str) -> Any: + store = self.store + + @contextmanager + def _writer() -> Iterator[Any]: + class _Writer: + def write(self_inner, data: str | bytes) -> None: + store[path] = data.encode() if isinstance(data, str) else data + + yield _Writer() + + return _writer() + + def put(self, local_path: str, remote_path: str) -> None: + with open(local_path, "rb") as fh: + self.store[remote_path] = fh.read() + + def get(self, remote_path: str, local_path: str) -> None: + with open(local_path, "wb") as fh: + fh.write(self.store[remote_path]) + + def remove(self, path: str) -> None: + self.store.pop(path, None) + + +def _make_sos_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x * np.array([1.0, 2.0, 3.0]), sense="max") + return m + + +@pytest.fixture +def handler() -> RemoteHandler: + """``RemoteHandler`` wired to an in-memory SFTP and a no-op shell.""" + client = MagicMock() + client.invoke_shell.return_value.makefile.return_value = MagicMock() + sftp = _FakeSFTPClient() + client.open_sftp.return_value = sftp + + h = RemoteHandler(hostname="fake", client=client) + # The unsolved model gets put() into sftp.store under model_unsolved_file; + # serve it back as the "solved" model so read_netcdf has something valid. + h.sftp_client = sftp # type: ignore[assignment] + h.execute = MagicMock() # type: ignore[method-assign] + + original_put = sftp.put + + def put_and_mirror(local_path: str, remote_path: str) -> None: + original_put(local_path, remote_path) + if remote_path == h.model_unsolved_file: + sftp.store[h.model_solved_file] = sftp.store[remote_path] + + sftp.put = put_and_mirror # type: ignore[method-assign] + return h + + +class TestSolveOnRemoteSosBracket: + """``solve_on_remote`` must bracket SOS reformulation around transfer.""" + + def test_reformulates_before_transfer_and_restores_after( + self, handler: RemoteHandler + ) -> None: + m = _make_sos_model() + + observed: dict[str, bool] = {} + real_write = handler.write_model_on_remote + + def spy_write(model: Model) -> None: + observed["state_active"] = model._sos_reformulation_state is not None + observed["has_aux_var"] = "_sos_reform_x_y" in model.variables + real_write(model) + + handler.write_model_on_remote = spy_write # type: ignore[method-assign] + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + handler.solve_on_remote(m, reformulate_sos=True, solver_name="highs") + + assert observed["state_active"] is True + assert observed["has_aux_var"] is True + assert not any("active SOS reformulation" in str(w.message) for w in captured) + assert m._sos_reformulation_state is None + assert "_sos_reform_x_y" not in m.variables + assert list(m.variables.sos) == ["x"] + + def test_skips_bracket_when_reformulate_sos_false( + self, handler: RemoteHandler + ) -> None: + m = _make_sos_model() + + observed: dict[str, bool] = {} + real_write = handler.write_model_on_remote + + def spy_write(model: Model) -> None: + observed["state_active"] = model._sos_reformulation_state is not None + real_write(model) + + handler.write_model_on_remote = spy_write # type: ignore[method-assign] + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + handler.solve_on_remote(m, reformulate_sos=False) + + assert observed["state_active"] is False + assert not any("active SOS reformulation" in str(w.message) for w in captured) + assert m._sos_reformulation_state is None + + def test_auto_without_solver_name_raises_on_sos_model( + self, handler: RemoteHandler + ) -> None: + m = _make_sos_model() + with pytest.raises(ValueError, match="requires an explicit `solver_name`"): + handler.solve_on_remote(m, reformulate_sos="auto") + + def test_no_sos_model_passes_through_unchanged( + self, handler: RemoteHandler, tmp_path: Path + ) -> None: + m = Model() + x = m.add_variables(lower=0, upper=1, name="x") + m.add_objective(1.0 * x, sense="max") + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + handler.solve_on_remote(m, reformulate_sos="auto") + + assert m._sos_reformulation_state is None + assert not any("active SOS reformulation" in str(w.message) for w in captured) diff --git a/test/test_oetc_settings.py b/test/test_oetc_settings.py index 0a9cac7c..12deeb66 100644 --- a/test/test_oetc_settings.py +++ b/test/test_oetc_settings.py @@ -317,5 +317,5 @@ def test_model_solve_forwards_to_oetc() -> None: m.solve(solver_name="gurobi", remote=handler, TimeLimit=100) handler.solve_on_oetc.assert_called_once_with( - m, solver_name="gurobi", TimeLimit=100 + m, solver_name="gurobi", reformulate_sos=False, TimeLimit=100 ) diff --git a/test/test_solvers.py b/test/test_solvers.py index 86600dae..1109c4c0 100644 --- a/test/test_solvers.py +++ b/test/test_solvers.py @@ -23,6 +23,14 @@ from linopy.solvers import _installed_version_in +@pytest.fixture +def lp_only_solver() -> str: + for name in ("glpk", "cbc"): + if name in solvers.available_solvers: + return name + pytest.skip("Need an LP-only solver (glpk or cbc) installed") + + @pytest.fixture def simple_model() -> Model: m = Model(chunk=None) @@ -467,3 +475,95 @@ def test_xpress_gpu_feature_reflects_installed_version() -> None: assert solvers.Xpress.supports( SolverFeature.GPU_ACCELERATION ) == _installed_version_in("xpress", ">=9.8.0") + + +class TestValidateModelOnBuild: + """Solver._build() runs solver-feature checks regardless of entry point.""" + + def test_quadratic_without_qp_support_raises(self, lp_only_solver: str) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x * x, sense="min") + + with pytest.raises(ValueError, match="does not support quadratic"): + solvers.Solver.from_name(lp_only_solver, m, io_api="lp") + + def test_semi_continuous_without_support_raises(self, lp_only_solver: str) -> None: + m = Model() + x = m.add_variables(name="x", lower=1, upper=10, semi_continuous=True) + m.add_objective(x) + + with pytest.raises(ValueError, match="does not support semi-continuous"): + solvers.Solver.from_name(lp_only_solver, m, io_api="lp") + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_solve_without_objective_raises(self) -> None: + m = Model() + m.add_variables(name="x", lower=0, upper=10) + # No objective added — both entry points should raise the same error. + with pytest.raises(ValueError, match="No objective has been set"): + solvers.Solver.from_name("highs", m, io_api="lp").solve() + with pytest.raises(ValueError, match="No objective has been set"): + m.solve("highs") + + +class TestSolverDoesNotMutateModel: + """Solver.from_model() must not mutate model state (sanitize stays Model-level).""" + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_from_model_leaves_constraints_untouched(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + # Constraint with a near-zero coefficient — would be sanitized away if + # the Solver path were sanitizing on build. + m.add_constraints(1e-12 * x + x >= 0, name="c") + m.add_objective(x) + + before = m.constraints["c"].coeffs.values.copy() + solvers.Solver.from_name("highs", m, io_api="lp") + after = m.constraints["c"].coeffs.values + + assert np.allclose(before, after, equal_nan=True), ( + "Solver.from_model() must not mutate model constraints. " + "Sanitization is a Model-level primitive; call " + "model.constraints.sanitize_zeros() / .sanitize_infinities() " + "explicitly before building." + ) + + +class TestAssignResultWiring: + """assign_result(result, solver=...) populates model.solver.""" + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_assign_result_with_solver_wires_model_solver(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x, sense="min") + + assert m.solver is None + solver = solvers.Solver.from_name("highs", m, io_api="lp") + result = solver.solve() + m.assign_result(result, solver=solver) + + assert m.solver is solver + assert m.solver_model is solver.solver_model + + @pytest.mark.skipif( + "highs" not in solvers.available_solvers, reason="HiGHS not installed" + ) + def test_assign_result_without_solver_kwarg_leaves_solver_unset(self) -> None: + m = Model() + x = m.add_variables(name="x", lower=0, upper=10) + m.add_objective(x, sense="min") + + solver = solvers.Solver.from_name("highs", m, io_api="lp") + result = solver.solve() + m.assign_result(result) # no solver kwarg + + assert m.solver is None diff --git a/test/test_sos_constraints.py b/test/test_sos_constraints.py index 30b2d767..a9529dc0 100644 --- a/test/test_sos_constraints.py +++ b/test/test_sos_constraints.py @@ -316,7 +316,7 @@ def test_unsupported_solver_raises_error() -> None: m.solve(solver_name=solver) -def test_to_highspy_raises_not_implemented() -> None: +def test_to_highspy_raises_when_sos_present() -> None: pytest.importorskip("highspy") m = Model() @@ -324,8 +324,5 @@ def test_to_highspy_raises_not_implemented() -> None: build = m.add_variables(coords=[locations], name="build", binary=True) m.add_sos_constraints(build, sos_type=1, sos_dim="locations") - with pytest.raises( - NotImplementedError, - match="SOS constraints are not supported by the HiGHS direct API", - ): + with pytest.raises(ValueError, match="does not support SOS constraints"): m.to_highspy() diff --git a/test/test_sos_reformulation.py b/test/test_sos_reformulation.py index 24ba62b3..51ec1770 100644 --- a/test/test_sos_reformulation.py +++ b/test/test_sos_reformulation.py @@ -3,13 +3,19 @@ from __future__ import annotations import logging +import warnings +from collections.abc import Callable +from pathlib import Path +from typing import Literal, cast import numpy as np import pandas as pd import pytest +import xarray as xr from linopy import Model, Variable, available_solvers from linopy.constants import SOS_TYPE_ATTR +from linopy.remote import RemoteHandler from linopy.sos_reformulation import ( compute_big_m_values, reformulate_sos1, @@ -312,6 +318,163 @@ def test_reformulate_inplace(self) -> None: assert "_sos_reform_x_y" in m.variables +class TestApplyUndoSOSReformulation: + """Tests for Model.apply_sos_reformulation / undo_sos_reformulation.""" + + @staticmethod + def _build_sos1_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + return m + + def test_apply_stashes_state(self) -> None: + m = self._build_sos1_model() + assert m._sos_reformulation_state is None + + m.apply_sos_reformulation() + + assert m._sos_reformulation_state is not None + assert m._sos_reformulation_state.reformulated == ["x"] + assert len(list(m.variables.sos)) == 0 + assert "_sos_reform_x_y" in m.variables + + def test_undo_restores_and_clears_state(self) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + m.undo_sos_reformulation() + + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in m.variables + + def test_double_apply_raises(self) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + with pytest.raises(RuntimeError, match="already been applied"): + m.apply_sos_reformulation() + + def test_undo_without_apply_raises(self) -> None: + m = self._build_sos1_model() + + with pytest.raises(RuntimeError, match="No SOS reformulation"): + m.undo_sos_reformulation() + + @pytest.mark.parametrize( + "copy_fn", + [ + pytest.param(lambda m: m.copy(), id="model.copy()"), + pytest.param(lambda m: __import__("copy").copy(m), id="copy.copy(model)"), + pytest.param( + lambda m: __import__("copy").deepcopy(m), id="copy.deepcopy(model)" + ), + ], + ) + def test_copy_persists_state_and_undo_works_on_copy( + self, copy_fn: Callable[[Model], Model] + ) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + c = copy_fn(m) + + # State is carried over but is an independent object + assert c._sos_reformulation_state is not None + assert c._sos_reformulation_state is not m._sos_reformulation_state + # Aux vars/cons exist on the copy (they were copied as part of the + # reformulated model state) + assert "_sos_reform_x_y" in c.variables + assert "_sos_reform_x_upper" in c.constraints + assert "_sos_reform_x_card" in c.constraints + # SOS attrs are not on the copy's "x" yet (still in reformulated form) + assert "x" not in list(c.variables.sos) + + # Undo on the copy fully restores the original SOS form + c.undo_sos_reformulation() + assert c._sos_reformulation_state is None + assert list(c.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in c.variables + assert "_sos_reform_x_upper" not in c.constraints + assert "_sos_reform_x_card" not in c.constraints + + # Original is entirely unaffected + assert m._sos_reformulation_state is not None + assert "_sos_reform_x_y" in m.variables + assert len(list(m.variables.sos)) == 0 + + def test_to_netcdf_warns_when_state_active(self, tmp_path: Path) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + + with pytest.warns(UserWarning, match="active SOS reformulation"): + m.to_netcdf(tmp_path / "m.nc") + + # File written despite the warning — the netcdf carries the + # reformulated MILP form. + assert (tmp_path / "m.nc").exists() + + def test_to_netcdf_silent_after_undo(self, tmp_path: Path) -> None: + m = self._build_sos1_model() + m.apply_sos_reformulation() + m.undo_sos_reformulation() + + with warnings.catch_warnings(): + warnings.simplefilter("error") # any warning fails the test + m.to_netcdf(tmp_path / "m.nc") + + +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestSolverPathSOSCheck: + """Solver._build() must raise on SOS-bearing model with non-SOS solver.""" + + def test_solver_from_name_raises_without_reformulation(self) -> None: + from linopy import solvers + + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x.sum(), sense="max") + + with pytest.raises(ValueError, match="does not support SOS"): + solvers.Solver.from_name("highs", m, io_api="lp") + + +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestSolveAutoUndoOnFailure: + """Model.solve must auto-undo SOS reformulation when build/solve raises.""" + + def test_state_restored_when_build_raises( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + from linopy import solvers + + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x.sum(), sense="max") + + def boom(*args: object, **kwargs: object) -> None: + raise RuntimeError("simulated build failure") + + monkeypatch.setattr(solvers.Solver, "from_name", boom) + + with pytest.raises(RuntimeError, match="simulated build failure"): + m.solve(solver_name="highs", reformulate_sos=True) + + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in m.variables + + # A subsequent real solve must not hit "already applied" + monkeypatch.undo() + m.solve(solver_name="highs", reformulate_sos=True) + + @pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") class TestSolveWithReformulation: """Tests for solving with SOS reformulation.""" @@ -931,3 +1094,141 @@ def test_invalid_reformulate_sos_value(self) -> None: with pytest.raises(ValueError, match="Invalid value for reformulate_sos"): m.solve(solver_name="highs", reformulate_sos="invalid") # type: ignore[arg-type] + + +class TestResolveSOSReformulation: + """Helper contracts not already exercised end-to-end by ``m.solve(...)``.""" + + @staticmethod + def _sos_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + return m + + def test_no_sos_short_circuits(self) -> None: + # Fast path: no SOS variables means False regardless of args. + m = Model() + m.add_variables(name="x") + for v in (True, False, "auto"): + assert m._resolve_sos_reformulation(None, v) is False + + def test_true_does_not_consult_solver_name(self) -> None: + # reformulate_sos=True must not require solver_name — no lookup. + assert self._sos_model()._resolve_sos_reformulation(None, True) is True + + def test_auto_with_none_solver_raises(self) -> None: + with pytest.raises(ValueError, match="requires an explicit `solver_name`"): + self._sos_model()._resolve_sos_reformulation(None, "auto") + + +@pytest.mark.skipif("highs" not in available_solvers, reason="HiGHS not installed") +class TestRemoteBracket: + """ + Model.solve(remote=...) must bracket SOS reformulation around the remote + dispatch and suppress the to_netcdf warning that fires inside the helper. + """ + + @staticmethod + def _sos_model() -> Model: + m = Model() + idx = pd.Index([0, 1, 2], name="i") + x = m.add_variables(lower=0, upper=1, coords=[idx], name="x") + m.add_sos_constraints(x, sos_type=1, sos_dim="i") + m.add_objective(x * np.array([1.0, 2.0, 3.0]), sense="max") + return m + + def _fake_handler( + self, observed: dict[str, object], tmp_path: Path + ) -> RemoteHandler: + """ + Non-OetcHandler stand-in with the SSH-shaped `solve_on_remote`. + + Records whether the model arrives in reformulated form, then runs + `model.to_netcdf(...)` and `read_netcdf(...)` (naturally — no + warning recording here, so we can observe at the call-site whether + Model.solve's suppression worked). + """ + from linopy.io import read_netcdf + from linopy.sos_reformulation import ( + sos_reformulation_context, + suppress_serialization_warning, + ) + + class _Handler: + def solve_on_remote( + _self, + model: Model, + *, + reformulate_sos: bool | Literal["auto"] = False, + **kwargs: object, + ) -> Model: + solver_name = kwargs.get("solver_name") + assert solver_name is None or isinstance(solver_name, str) + with sos_reformulation_context( + model, solver_name, reformulate_sos + ) as applied: + observed["state_active"] = ( + model._sos_reformulation_state is not None + ) + observed["solver_name_arg"] = solver_name + with suppress_serialization_warning(active=applied): + model.to_netcdf(tmp_path / "sent.nc") + solved = read_netcdf(tmp_path / "sent.nc") + for _name, var in solved.variables.items(): + arr = np.zeros(var.labels.shape, dtype=float) + var.solution = xr.DataArray(arr, dims=var.labels.dims) + solved.objective.set_value(0.0) + solved.status = "ok" + solved.termination_condition = "optimal" + return solved + + return cast(RemoteHandler, _Handler()) + + def test_remote_brackets_and_suppresses_warning(self, tmp_path: Path) -> None: + m = self._sos_model() + observed: dict[str, object] = {} + handler = self._fake_handler(observed, tmp_path) + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + m.solve(solver_name="highs", remote=handler, reformulate_sos=True) + + # Reformulation was active when the handler ran (apply happened + # before the remote dispatch). + assert observed["state_active"] is True + assert observed["solver_name_arg"] == "highs" + + # No "active SOS reformulation" warning escaped Model.solve. + assert not any("active SOS reformulation" in str(w.message) for w in captured) + + # Lifecycle wound down: state cleared, original SOS variable restored. + assert m._sos_reformulation_state is None + assert list(m.variables.sos) == ["x"] + assert "_sos_reform_x_y" not in m.variables + + def test_remote_skips_bracket_when_reformulate_sos_false( + self, tmp_path: Path + ) -> None: + m = self._sos_model() + observed: dict[str, object] = {} + handler = self._fake_handler(observed, tmp_path) + + with warnings.catch_warnings(record=True) as captured: + warnings.simplefilter("always") + m.solve(solver_name="highs", remote=handler, reformulate_sos=False) + + # No reformulation happened — model still has the original SOS var + # when the handler sees it, and to_netcdf never warns. + assert observed["state_active"] is False + assert not any("active SOS reformulation" in str(w.message) for w in captured) + assert m._sos_reformulation_state is None + + def test_remote_auto_requires_solver_name_with_sos(self, tmp_path: Path) -> None: + m = self._sos_model() + observed: dict[str, object] = {} + handler = self._fake_handler(observed, tmp_path) + + with pytest.raises(ValueError, match="requires an explicit `solver_name`"): + m.solve(remote=handler, reformulate_sos="auto")