Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Release Notes

.. Upcoming Version

* Allow constant values in objective cost function
* Add support for SOS1 and SOS2 (Special Ordered Sets) constraints via ``Model.add_sos_constraints()`` and ``Model.remove_sos_constraints()``
* Add simplify method to LinearExpression to combine duplicate terms
* Add convenience function to create LinearExpression from constant
Expand Down
3 changes: 3 additions & 0 deletions linopy/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1097,6 +1097,9 @@ def empty(self) -> EmptyDeprecationWrapper:
"""
return EmptyDeprecationWrapper(not self.size)

def drop_constant(self: GenericExpression) -> GenericExpression:
return self - self.const # type: ignore

def densify_terms(self: GenericExpression) -> GenericExpression:
"""
Move all non-zero term entries to the front and cut off all-zero
Expand Down
88 changes: 86 additions & 2 deletions linopy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
import os
import re
from collections.abc import Callable, Mapping, Sequence
from functools import wraps
from pathlib import Path
from tempfile import NamedTemporaryFile, gettempdir
from typing import Any, Literal, overload
from typing import Any, Literal, ParamSpec, TypeVar, overload
from warnings import warn

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -77,6 +79,67 @@
logger = logging.getLogger(__name__)


P = ParamSpec("P")
R = TypeVar("R")


class ConstantInObjectiveWarning(UserWarning): ...


class ConstantObjectiveError(Exception): ...


def strip_and_replace_constant_objective(func: Callable[P, R]) -> Callable[P, R]:
"""
Decorates a Model instance method.

If the model objective contains a constant term, this decorator will:
- Remove the constant term from the model objective
- Call the decorated method
- Add the constant term back to the model objective
"""

@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
assert args, "Expected at least one argument (self)"
self = args[0]
assert isinstance(self, Model), (
f"First argument must be a Model instance, got {type(self)}"
)
model = self
if not self.objective.has_constant:
# Continue as normal if there is no constant term
return func(*args, **kwargs)

# The objective contains a constant term
if not model.allow_constant_objective:
raise ConstantObjectiveError(
"Objective function contains constant terms. Please use LinearExpression.drop_constants()/QuadraticExpression.drop_constants() or set Model.allow_constant_objective=True."
)

# Modify the model objective to drop the constant term
model = self
constant = float(self.objective.expression.const.values)
model.objective.expression = self.objective.expression.drop_constant()
args = (model, *args[1:]) # type: ignore

try:
result = func(*args, **kwargs)
except Exception as e:
# Even if there is an exception, make sure the model returns to it's original state
model.objective.expression = model.objective.expression + constant
raise e

# Re-add the constant term to return the model objective to the original expression
model.objective.expression = model.objective.expression + constant
if model.objective.value is not None:
model.objective.set_value(model.objective.value + constant)

return result

return wrapper


class Model:
"""
Linear optimization model.
Expand All @@ -103,6 +166,7 @@ class Model:
_dual: Dataset
_status: str
_termination_condition: str
_allow_constant_objective: bool
_xCounter: int
_cCounter: int
_varnameCounter: int
Expand All @@ -124,6 +188,7 @@ class Model:
# hidden attributes
"_status",
"_termination_condition",
"_allow_constant_objective",
# TODO: move counters to Variables and Constraints class
"_xCounter",
"_cCounter",
Expand Down Expand Up @@ -175,6 +240,7 @@ def __init__(

self._status: str = "initialized"
self._termination_condition: str = ""
self._allow_constant_objective: bool = False
self._xCounter: int = 0
self._cCounter: int = 0
self._varnameCounter: int = 0
Expand Down Expand Up @@ -727,6 +793,17 @@ def add_constraints(
self.constraints.add(constraint)
return constraint

@property
def allow_constant_objective(self) -> bool:
"""
Whether constant terms in the objective function are allowed.
"""
return self._allow_constant_objective

@allow_constant_objective.setter
def allow_constant_objective(self, allow: bool) -> None:
self._allow_constant_objective = allow

def add_objective(
self,
expr: Variable
Expand All @@ -748,7 +825,7 @@ def add_objective(

Returns
-------
linopy.LinearExpression
linopy.LinearExpression, linopy.QuadraticExpression
The objective function assigned to the model.
"""
if not overwrite:
Expand All @@ -758,8 +835,14 @@ def add_objective(
)
if isinstance(expr, Variable):
expr = 1 * expr

self.objective.expression = expr
self.objective.sense = sense
if not self.allow_constant_objective and self.objective.has_constant:
warn(
"Objective function contains constant terms but this is not allowed as Model.allow_constant_objective=False, running solve will result in an error. Please either remove constants from the expression with expr.drop_constants() or set Model.allow_constant_objective=True.",
ConstantInObjectiveWarning,
)

def remove_variables(self, name: str) -> None:
"""
Expand Down Expand Up @@ -1107,6 +1190,7 @@ def get_problem_file(
) as f:
return Path(f.name)

@strip_and_replace_constant_objective
def solve(
self,
solver_name: str | None = None,
Expand Down
10 changes: 7 additions & 3 deletions linopy/objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,11 +189,15 @@ def expression(
if len(expr.coord_dims):
expr = expr.sum()

if (expr.const != 0.0) and not np.isnan(expr.const):
raise ValueError("Constant values in objective function not supported.")

self._expression = expr

@property
def has_constant(self) -> bool:
"""
Returns whether the objective has a constant term.
"""
return bool(self.expression.has_constant)

@property
def model(self) -> Model:
"""
Expand Down
15 changes: 15 additions & 0 deletions test/test_linear_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,21 @@ def test_cumsum(m: Model, multiple: float) -> None:
cumsum.nterm == 2


def test_drop_constant(x: Variable) -> None:
"""Test that constants are removed"""
expr_a = 2 * x
expr_b = expr_a + [1, 2]
expr_c = expr_b + float("nan")
for expr in [expr_a, expr_b, expr_c]:
expr = 2 * x + 10
expr_2 = expr.drop_constant()

assert all(expr_2.const.values == 0.0), (
f"Expected constant 0.0, got {expr_2.const.values}"
)
assert not bool(expr_2.has_constant)


def test_simplify_basic(x: Variable) -> None:
"""Test basic simplification with duplicate terms."""
expr = 2 * x + 3 * x + 1 * x
Expand Down
5 changes: 2 additions & 3 deletions test/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,8 @@ def test_objective() -> None:
assert m.objectiverange.min() == 2
assert m.objectiverange.max() == 2

# test objective with constant which is not supported
with pytest.raises(ValueError):
m.objective = m.objective + 3
# test objective with constant which is supported
m.objective = m.objective + 3


def test_remove_variable() -> None:
Expand Down
3 changes: 1 addition & 2 deletions test/test_objective.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,5 +192,4 @@ def test_repr(linear_objective: Objective, quadratic_objective: Objective) -> No
def test_objective_constant() -> None:
m = Model()
linear_expr = LinearExpression(None, m) + 1
with pytest.raises(ValueError):
m.objective = Objective(linear_expr, m)
m.objective = Objective(linear_expr, m)
49 changes: 49 additions & 0 deletions test/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from linopy import GREATER_EQUAL, LESS_EQUAL, Model, solvers
from linopy.common import to_path
from linopy.expressions import LinearExpression
from linopy.model import ConstantInObjectiveWarning, ConstantObjectiveError
from linopy.solver_capabilities import (
SolverFeature,
get_available_solvers_with_feature,
Expand Down Expand Up @@ -955,6 +956,54 @@ def test_model_resolve(
assert np.isclose(model.objective.value or 0, 5.25)


def test_model_with_constant_in_objective_feasible(model: Model) -> None:
objective = model.objective.expression + 1

with pytest.warns(ConstantInObjectiveWarning):
model.add_objective(expr=objective, overwrite=True)

with pytest.raises(ConstantObjectiveError):
status, _ = model.solve(solver_name="highs")

model.allow_constant_objective = True
status, _ = model.solve(solver_name="highs")
assert status == "ok"
# x = -0.1, y = 1.7
assert model.objective.value == 4.3
assert model.objective.expression.const == 1
assert model.objective.expression.solution == 4.3


def test_model_with_constant_in_objective_infeasible(model: Model) -> None:
objective = model.objective.expression + 1
model.add_objective(expr=objective, overwrite=True)
model.add_constraints([(1, "x")], "<=", 0)
model.add_constraints([(1, "y")], "<=", 0)

model.allow_constant_objective = True
_, condition = model.solve(solver_name="highs")

assert condition == "infeasible"
# Even though the problem was not solved, the constant term should still be accessible
assert model.objective.expression.const == 1


def test_model_with_constant_in_objective_error(model: Model) -> None:
objective = model.objective.expression + 1
model.allow_constant_objective = True
model.add_objective(expr=objective, overwrite=True)
model.add_constraints([(1, "x")], "<=", 0)
model.add_constraints([(1, "y")], "<=", 0)

try:
_ = model.solve(solver_name="apples")
except AssertionError:
pass

# Even if something goes wrong, the model objective should return to the correct state
assert model.objective.expression.const == 1


@pytest.mark.parametrize(
"solver,io_api,explicit_coordinate_names", [p for p in params if "direct" not in p]
)
Expand Down
14 changes: 14 additions & 0 deletions test/test_quadratic_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,20 @@ def test_quadratic_expression_constant_to_polars() -> None:
assert all(arr.to_numpy() == df["const"].to_numpy())


def test_drop_constant(x: Variable) -> None:
"""Test that constants are removed"""
expr_a = 2 * x * x
expr_b = expr_a + 1
for expr in [expr_a, expr_b]:
expr = 2 * x + 10
expr_2 = expr.drop_constant()

assert all(expr_2.const.values == 0.0), (
f"Expected constant 0.0, got {expr_2.const.values}"
)
assert not bool(expr_2.has_constant)


def test_quadratic_expression_to_matrix(model: Model, x: Variable, y: Variable) -> None:
expr: QuadraticExpression = x * y + x + 5 # type: ignore

Expand Down