diff --git a/include/pyoptinterface/knitro_model.hpp b/include/pyoptinterface/knitro_model.hpp index 6947e28..48258e2 100644 --- a/include/pyoptinterface/knitro_model.hpp +++ b/include/pyoptinterface/knitro_model.hpp @@ -56,6 +56,7 @@ B(KN_add_obj_quadratic_struct); \ B(KN_del_obj_quadratic_struct); \ B(KN_del_obj_quadratic_struct_all); \ + B(KN_chg_obj_linear_term); \ B(KN_add_con_constant); \ B(KN_add_con_linear_struct); \ B(KN_add_con_linear_term); \ @@ -400,6 +401,7 @@ class KNITROModel : public OnesideLinearConstraintMixin, double get_obj_value() const; void set_obj_sense(ObjectiveSense sense); ObjectiveSense get_obj_sense() const; + void set_objective_coefficient(const VariableIndex &variable, double coefficient); // Solve functions void optimize(); diff --git a/lib/knitro_model.cpp b/lib/knitro_model.cpp index 18b3347..ab822c8 100644 --- a/lib/knitro_model.cpp +++ b/lib/knitro_model.cpp @@ -646,6 +646,17 @@ void KNITROModel::set_objective(const ExprBuilder &expr, ObjectiveSense sense) } } +void KNITROModel::set_objective_coefficient(const VariableIndex &variable, double coefficient) +{ + KNINT indexVar = _variable_index(variable); + // NOTE: To make sure the coefficient is updated correctly, + // we need to call KN_update before changing the linear term + _update(); + int error = knitro::KN_chg_obj_linear_term(m_kc.get(), indexVar, coefficient); + _check_error(error); + m_is_dirty = true; +} + void KNITROModel::add_single_nl_objective(ExpressionGraph &graph, const ExpressionHandle &result) { _add_graph(graph); diff --git a/lib/knitro_model_ext.cpp b/lib/knitro_model_ext.cpp index aca59f6..16cb35d 100644 --- a/lib/knitro_model_ext.cpp +++ b/lib/knitro_model_ext.cpp @@ -154,6 +154,8 @@ NB_MODULE(knitro_model_ext, m) nb::arg("expr"), nb::arg("sense") = ObjectiveSense::Minimize) .def("_add_single_nl_objective", &KNITROModel::add_single_nl_objective, nb::arg("graph"), nb::arg("result")) + .def("set_objective_coefficient", &KNITROModel::set_objective_coefficient, nb::arg("variable"), + nb::arg("coefficient")) // clang-format off BIND_F(get_obj_value) diff --git a/tests/test_update.py b/tests/test_update.py index 4947a02..657979e 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -37,3 +37,12 @@ def test_update(model_interface): assert model.get_value(x[0]) == approx(2.0) assert model.get_value(x[2]) == approx(1.0) + + model.set_variable_attribute(x[0], poi.VariableAttribute.LowerBound, 1.5) + model.set_variable_attribute(x[2], poi.VariableAttribute.LowerBound, 0.5) + model.set_objective_coefficient(x[0], -2.0) + model.set_obj_sense(poi.ObjectiveSense.Maximize) + model.optimize() + + assert model.get_value(x[0]) == approx(1.5) + assert model.get_value(x[2]) == approx(0.75)