diff --git a/linopy/dual.py b/linopy/dual.py index 0a787e68..15b51ccc 100644 --- a/linopy/dual.py +++ b/linopy/dual.py @@ -7,10 +7,9 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal import numpy as np -import pandas as pd import xarray as xr from linopy.expressions import LinearExpression @@ -21,7 +20,86 @@ logger = logging.getLogger(__name__) -def _var_lookup(m: Model) -> dict: +def _skip( + da: xr.DataArray, component_type: Literal["variable", "constraint"], name: str +) -> bool: + """ + Determine whether to skip processing a variable or constraint based on its labels. + + Parameters + ---------- + da : xr.DataArray + The labels DataArray of the variable or constraint. + component_type : Literal["variable", "constraint"] + The type of component being checked, used for logging. + name : str + The name of the variable or constraint, used for logging. + + Returns + ------- + bool + True if the component should be skipped (empty or fully masked), False otherwise. + """ + if da.size == 0: + logger.debug(f"Skipping empty {component_type} '{name}'.") + return True + + if (da == -1).all(): + logger.debug(f"{component_type} '{name}' is fully masked, skipping.") + return True + return False + + +def _lookup( + labels: xr.DataArray, name: str, component_type: Literal["variable", "constraint"] +) -> dict[int, tuple[str, dict]]: + """ + Create a lookup dictionary mapping labels to their corresponding names and coordinates. + + Parameters + ---------- + labels : xr.DataArray + Array of labels. + name : str + Name of the component. + component_type : Literal["variable", "constraint"] + Type of the component. + + Returns + ------- + dict[int, tuple[str, dict]] + Mapping from flat integer label to (name, coord_dict) tuple. + """ + lookup: dict[int, tuple[str, dict]] = {} + + vals = labels.values + if _skip(labels, component_type, name): + return lookup + + logger.debug( + f"Creating label lookup for {component_type} '{name}' with shape {labels.shape} and dims {labels.dims}." + ) + + if labels.ndim == 0: + lookup[int(vals.item())] = (name, {}) + return lookup + + coord_values = [labels.coords[d].values for d in labels.dims] + + # Choosing np.ndindex over np.argwhere or da.to_series for memory efficiency on large ND arrays + for idx in np.ndindex(vals.shape): + label = int(vals[idx]) + if label == -1: + continue + lookup[label] = ( + name, + {dim: coord_values[i][idx[i]] for i, dim in enumerate(labels.dims)}, + ) + + return lookup + + +def _var_lookup(m: Model) -> dict[int, tuple[str, dict]]: """ Build a flat label -> (var_name, coord_dict) lookup for all variables in m. @@ -43,40 +121,12 @@ def _var_lookup(m: Model) -> dict: var_lookup = {} logger.debug("Building variable label lookup.") for var_name, var in m.variables.items(): - labels = var.labels - flat_labels = labels.values.flatten() - - if len(flat_labels) == 0: - logger.debug(f"Skipping empty variable '{var_name}'.") - continue - if not (flat_labels != -1).any(): - logger.debug(f"Variable '{var_name}' is fully masked, skipping.") - continue - - logger.debug( - f"Creating label lookup for variable '{var_name}' with shape {labels.shape} and dims {labels.dims}." - ) - - coord_arrays = ( - np.meshgrid( - *[labels.coords[dim].values for dim in labels.dims], indexing="ij" - ) - if len(labels.dims) > 0 - else [] - ) - flat_coords = [arr.flatten() for arr in coord_arrays] - - for k, flat in enumerate(flat_labels): - if flat != -1: - var_lookup[int(flat)] = ( - var_name, - {dim: flat_coords[i][k] for i, dim in enumerate(labels.dims)}, - ) - + lookup = _lookup(var.labels, var_name, "variable") + var_lookup.update(lookup) return var_lookup -def _con_lookup(m: Model) -> dict: +def _con_lookup(m: Model) -> dict[int, tuple[str, dict]]: """ Build a flat label -> (con_name, coord_dict) lookup for all constraints in m. @@ -98,36 +148,8 @@ def _con_lookup(m: Model) -> dict: con_lookup = {} logger.debug("Building constraint label lookup.") for con_name, con in m.constraints.items(): - labels = con.labels - flat_labels = labels.values.flatten() - - if len(flat_labels) == 0: - logger.debug(f"Skipping empty constraint '{con_name}'.") - continue - if not (flat_labels != -1).any(): - logger.debug(f"Constraint '{con_name}' is fully masked, skipping.") - continue - - logger.debug( - f"Creating label lookup for constraint '{con_name}' with shape {labels.shape} and dims {labels.dims}." - ) - - coord_arrays = ( - np.meshgrid( - *[labels.coords[dim].values for dim in labels.dims], indexing="ij" - ) - if len(labels.dims) > 0 - else [] - ) - flat_coords = [arr.flatten() for arr in coord_arrays] - - for k, flat in enumerate(flat_labels): - if flat != -1: - con_lookup[int(flat)] = ( - con_name, - {dim: flat_coords[i][k] for i, dim in enumerate(labels.dims)}, - ) - + lookup = _lookup(con.labels, con_name, "constraint") + con_lookup.update(lookup) return con_lookup @@ -228,44 +250,35 @@ def _add_dual_variables(m: Model, m2: Model) -> dict: dual_vars = {} for name, con in m.constraints.items(): - sign_vals = con.sign.values.flatten() - - if len(sign_vals) == 0: - logger.warning(f"Constraint '{name}' has no sign values, skipping.") + if _skip(con.labels, "constraint", name): continue mask = con.labels != -1 - if not mask.any(): - logger.debug(f"Constraint '{name}' is fully masked, skipping.") - continue - - if sign_vals[0] == "=": - lower, upper = -np.inf, np.inf - var_type = "free" - elif sign_vals[0] == "<=": - lower, upper = (-np.inf, 0) if primal_is_min else (0, np.inf) - var_type = "non-positive" if primal_is_min else "non-negative" - elif sign_vals[0] == ">=": - lower, upper = (0, np.inf) if primal_is_min else (-np.inf, 0) - var_type = "non-negative" if primal_is_min else "non-positive" - else: - logger.warning( - f"Constraint '{name}' has unrecognized sign '{sign_vals[0]}', skipping." - ) - continue + sign = con.sign.isel({d: 0 for d in con.sign.dims}).item() + + match sign: + case "=": + lower, upper = -np.inf, np.inf + var_type = "free" + case "<=": + lower, upper = (-np.inf, 0) if primal_is_min else (0, np.inf) + var_type = "non-positive" if primal_is_min else "non-negative" + case ">=": + lower, upper = (0, np.inf) if primal_is_min else (-np.inf, 0) + var_type = "non-negative" if primal_is_min else "non-positive" + case _: + logger.warning( + f"Constraint '{name}' has unrecognized sign '{sign}', skipping." + ) + continue logger.debug( f"Adding {var_type} dual variable for constraint '{name}' with shape {con.shape} and dims {con.labels.dims}." ) - coords = ( - [con.labels.coords[dim] for dim in con.labels.dims] - if con.labels.dims - else None - ) dual_vars[name] = m2.add_variables( lower=lower, upper=upper, - coords=coords, + coords=con.labels.coords, name=name, mask=mask, ) @@ -409,9 +422,6 @@ def _add_dual_feasibility_constraints( # add dual feasibility constraints to m2 logger.debug("Adding dual feasibility constraints to model.") for var_name, var in m.variables.items(): - coords = [ - pd.Index(var.labels.coords[dim].values, name=dim) for dim in var.labels.dims - ] mask = var.labels != -1 c_vals = xr.DataArray( @@ -419,19 +429,12 @@ def _add_dual_feasibility_constraints( coords=var.labels.coords, ) - def rule( - m: Model, - *coord_vals: Any, - vname: str = var_name, - vdims: tuple = var.labels.dims, - ) -> LinearExpression | None: - coord_dict = dict(zip(vdims, coord_vals)) + def __rule(m: Model, *coord_vals: Any) -> LinearExpression | None: + coord_dict = dict(zip(var.labels.dims, coord_vals)) flat = var.labels.sel(**coord_dict).item() - if flat == -1: - return None - if flat not in dual_feas_terms[vname]: + if flat == -1 or flat not in (term_dict := dual_feas_terms[var_name]): return None - _, terms, _ = dual_feas_terms[vname][flat] + _, terms, _ = term_dict[flat] if not terms: return None return sum( @@ -439,7 +442,7 @@ def rule( for con_name, con_coords, coeff in terms ) - lhs = LinearExpression.from_rule(m2, rule, coords) + lhs = LinearExpression.from_rule(m2, __rule, var.labels.coords) m2.add_constraints(lhs == c_vals, name=var_name, mask=mask)