Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 110 additions & 107 deletions linopy/dual.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal

import numpy as np
import pandas as pd
import xarray as xr

from linopy.expressions import LinearExpression
Expand All @@ -21,7 +20,86 @@
logger = logging.getLogger(__name__)


def _var_lookup(m: Model) -> dict:
def _skip(
    da: xr.DataArray, component_type: Literal["variable", "constraint"], name: str
) -> bool:
    """
    Decide whether a variable or constraint can be ignored entirely.

    A component is skippable when its labels array is empty, or when every
    label equals -1 (i.e. the component is fully masked out).

    Parameters
    ----------
    da : xr.DataArray
        The labels DataArray of the variable or constraint.
    component_type : Literal["variable", "constraint"]
        The type of component being checked, used for logging.
    name : str
        The name of the variable or constraint, used for logging.

    Returns
    -------
    bool
        True if the component should be skipped (empty or fully masked), False otherwise.
    """
    # Empty labels: nothing to process at all.
    is_empty = da.size == 0
    if is_empty:
        logger.debug(f"Skipping empty {component_type} '{name}'.")
        return True

    # -1 is the sentinel for masked entries; all -1 means nothing is active.
    fully_masked = bool((da == -1).all())
    if fully_masked:
        logger.debug(f"{component_type} '{name}' is fully masked, skipping.")
        return True

    return False


def _lookup(
    labels: xr.DataArray, name: str, component_type: Literal["variable", "constraint"]
) -> dict[int, tuple[str, dict]]:
    """
    Create a lookup dictionary mapping labels to their corresponding names and coordinates.

    Parameters
    ----------
    labels : xr.DataArray
        Array of labels.
    name : str
        Name of the component.
    component_type : Literal["variable", "constraint"]
        Type of the component.

    Returns
    -------
    dict[int, tuple[str, dict]]
        Mapping from flat integer label to (name, coord_dict) tuple.
    """
    result: dict[int, tuple[str, dict]] = {}

    # Empty or fully masked components contribute nothing to the lookup.
    if _skip(labels, component_type, name):
        return result

    logger.debug(
        f"Creating label lookup for {component_type} '{name}' with shape {labels.shape} and dims {labels.dims}."
    )

    vals = labels.values

    # Scalar (dimensionless) component: single entry, no coordinates.
    if labels.ndim == 0:
        result[int(vals.item())] = (name, {})
        return result

    dims = labels.dims
    axes = [labels.coords[d].values for d in dims]

    # Choosing np.ndindex over np.argwhere or da.to_series for memory efficiency on large ND arrays
    for multi_idx in np.ndindex(vals.shape):
        flat = int(vals[multi_idx])
        if flat == -1:  # masked entry, leave it out of the mapping
            continue
        coord_dict = {d: axes[pos][multi_idx[pos]] for pos, d in enumerate(dims)}
        result[flat] = (name, coord_dict)

    return result


def _var_lookup(m: Model) -> dict[int, tuple[str, dict]]:
"""
Build a flat label -> (var_name, coord_dict) lookup for all variables in m.

Expand All @@ -43,40 +121,12 @@ def _var_lookup(m: Model) -> dict:
var_lookup = {}
logger.debug("Building variable label lookup.")
for var_name, var in m.variables.items():
labels = var.labels
flat_labels = labels.values.flatten()

if len(flat_labels) == 0:
logger.debug(f"Skipping empty variable '{var_name}'.")
continue
if not (flat_labels != -1).any():
logger.debug(f"Variable '{var_name}' is fully masked, skipping.")
continue

logger.debug(
f"Creating label lookup for variable '{var_name}' with shape {labels.shape} and dims {labels.dims}."
)

coord_arrays = (
np.meshgrid(
*[labels.coords[dim].values for dim in labels.dims], indexing="ij"
)
if len(labels.dims) > 0
else []
)
flat_coords = [arr.flatten() for arr in coord_arrays]

for k, flat in enumerate(flat_labels):
if flat != -1:
var_lookup[int(flat)] = (
var_name,
{dim: flat_coords[i][k] for i, dim in enumerate(labels.dims)},
)

lookup = _lookup(var.labels, var_name, "variable")
var_lookup.update(lookup)
return var_lookup


def _con_lookup(m: Model) -> dict:
def _con_lookup(m: Model) -> dict[int, tuple[str, dict]]:
"""
Build a flat label -> (con_name, coord_dict) lookup for all constraints in m.

Expand All @@ -98,36 +148,8 @@ def _con_lookup(m: Model) -> dict:
con_lookup = {}
logger.debug("Building constraint label lookup.")
for con_name, con in m.constraints.items():
labels = con.labels
flat_labels = labels.values.flatten()

if len(flat_labels) == 0:
logger.debug(f"Skipping empty constraint '{con_name}'.")
continue
if not (flat_labels != -1).any():
logger.debug(f"Constraint '{con_name}' is fully masked, skipping.")
continue

logger.debug(
f"Creating label lookup for constraint '{con_name}' with shape {labels.shape} and dims {labels.dims}."
)

coord_arrays = (
np.meshgrid(
*[labels.coords[dim].values for dim in labels.dims], indexing="ij"
)
if len(labels.dims) > 0
else []
)
flat_coords = [arr.flatten() for arr in coord_arrays]

for k, flat in enumerate(flat_labels):
if flat != -1:
con_lookup[int(flat)] = (
con_name,
{dim: flat_coords[i][k] for i, dim in enumerate(labels.dims)},
)

lookup = _lookup(con.labels, con_name, "constraint")
con_lookup.update(lookup)
return con_lookup


Expand Down Expand Up @@ -228,44 +250,35 @@ def _add_dual_variables(m: Model, m2: Model) -> dict:

dual_vars = {}
for name, con in m.constraints.items():
sign_vals = con.sign.values.flatten()

if len(sign_vals) == 0:
logger.warning(f"Constraint '{name}' has no sign values, skipping.")
if _skip(con.labels, "constraint", name):
continue

mask = con.labels != -1
if not mask.any():
logger.debug(f"Constraint '{name}' is fully masked, skipping.")
continue

if sign_vals[0] == "=":
lower, upper = -np.inf, np.inf
var_type = "free"
elif sign_vals[0] == "<=":
lower, upper = (-np.inf, 0) if primal_is_min else (0, np.inf)
var_type = "non-positive" if primal_is_min else "non-negative"
elif sign_vals[0] == ">=":
lower, upper = (0, np.inf) if primal_is_min else (-np.inf, 0)
var_type = "non-negative" if primal_is_min else "non-positive"
else:
logger.warning(
f"Constraint '{name}' has unrecognized sign '{sign_vals[0]}', skipping."
)
continue
sign = con.sign.isel({d: 0 for d in con.sign.dims}).item()

match sign:
case "=":
lower, upper = -np.inf, np.inf
var_type = "free"
case "<=":
lower, upper = (-np.inf, 0) if primal_is_min else (0, np.inf)
var_type = "non-positive" if primal_is_min else "non-negative"
case ">=":
lower, upper = (0, np.inf) if primal_is_min else (-np.inf, 0)
var_type = "non-negative" if primal_is_min else "non-positive"
case _:
logger.warning(
f"Constraint '{name}' has unrecognized sign '{sign}', skipping."
)
continue

logger.debug(
f"Adding {var_type} dual variable for constraint '{name}' with shape {con.shape} and dims {con.labels.dims}."
)
coords = (
[con.labels.coords[dim] for dim in con.labels.dims]
if con.labels.dims
else None
)
dual_vars[name] = m2.add_variables(
lower=lower,
upper=upper,
coords=coords,
coords=con.labels.coords,
name=name,
mask=mask,
)
Expand Down Expand Up @@ -409,37 +422,27 @@ def _add_dual_feasibility_constraints(
# add dual feasibility constraints to m2
logger.debug("Adding dual feasibility constraints to model.")
for var_name, var in m.variables.items():
coords = [
pd.Index(var.labels.coords[dim].values, name=dim) for dim in var.labels.dims
]
mask = var.labels != -1

c_vals = xr.DataArray(
np.vectorize(lambda flat: c_by_label.get(flat, 0.0))(var.labels.values),
coords=var.labels.coords,
)

def rule(
m: Model,
*coord_vals: Any,
vname: str = var_name,
vdims: tuple = var.labels.dims,
) -> LinearExpression | None:
coord_dict = dict(zip(vdims, coord_vals))
def __rule(m: Model, *coord_vals: Any) -> LinearExpression | None:
coord_dict = dict(zip(var.labels.dims, coord_vals))
flat = var.labels.sel(**coord_dict).item()
if flat == -1:
return None
if flat not in dual_feas_terms[vname]:
if flat == -1 or flat not in (term_dict := dual_feas_terms[var_name]):
return None
_, terms, _ = dual_feas_terms[vname][flat]
_, terms, _ = term_dict[flat]
if not terms:
return None
return sum(
coeff * dual_vars[con_name].at[tuple(con_coords.values())]
for con_name, con_coords, coeff in terms
)

lhs = LinearExpression.from_rule(m2, rule, coords)
lhs = LinearExpression.from_rule(m2, __rule, var.labels.coords)
m2.add_constraints(lhs == c_vals, name=var_name, mask=mask)


Expand Down
Loading