diff --git a/include/pyoptinterface/knitro_model.hpp b/include/pyoptinterface/knitro_model.hpp index 48258e2..01dd976 100644 --- a/include/pyoptinterface/knitro_model.hpp +++ b/include/pyoptinterface/knitro_model.hpp @@ -142,6 +142,7 @@ struct CallbackEvaluator CppAD::sparse_rcv, std::vector> jac_; CppAD::sparse_jac_work jac_work_; CppAD::sparse_rc> hess_pattern_; + CppAD::sparse_rc> hess_pattern_symm_; CppAD::sparse_rcv, std::vector> hess_; CppAD::sparse_hes_work hess_work_; @@ -171,10 +172,19 @@ struct CallbackEvaluator select_rows[fun_rows[k]] = true; } fun.rev_hes_sparsity(select_rows, false, true, hess_pattern_); + for (size_t k = 0; k < hess_pattern_.nnz(); k++) + { + size_t row = hess_pattern_.row()[k]; + size_t col = hess_pattern_.col()[k]; + if (row <= col) + { + hess_pattern_symm_.push_back(row, col); + } + } x.resize(fun.Domain(), 0.0); w.resize(fun.Range(), 0.0); jac_ = CppAD::sparse_rcv, std::vector>(jac_pattern_); - hess_ = CppAD::sparse_rcv, std::vector>(hess_pattern_); + hess_ = CppAD::sparse_rcv, std::vector>(hess_pattern_symm_); } void eval_fun(const V *req_x, V *res_y, bool aggregate = false) @@ -204,7 +214,7 @@ struct CallbackEvaluator x[i] = req_x[indexVars[i]]; } fun.sparse_jac_rev(x, jac_, jac_pattern_, jac_coloring_, jac_work_); - auto jac = jac_.val(); + auto& jac = jac_.val(); for (size_t i = 0; i < jac_.nnz(); i++) { res_jac[i] = jac[i]; @@ -229,7 +239,7 @@ struct CallbackEvaluator } } fun.sparse_hes(x, w, hess_, hess_pattern_, hess_coloring_, hess_work_); - auto hess = hess_.val(); + auto& hess = hess_.val(); for (size_t i = 0; i < hess_.nnz(); i++) { res_hess[i] = hess[i]; @@ -241,8 +251,8 @@ struct CallbackEvaluator CallbackPattern pattern; pattern.indexCons = indexCons; - auto jac_rows = jac_pattern_.row(); - auto jac_cols = jac_pattern_.col(); + auto& jac_rows = jac_pattern_.row(); + auto& jac_cols = jac_pattern_.col(); if (indexCons.empty()) { for (size_t k = 0; k < jac_pattern_.nnz(); k++) @@ -259,9 +269,9 @@ struct CallbackEvaluator } } - auto hess_rows = hess_pattern_.row(); - auto hess_cols = hess_pattern_.col(); - for (size_t k = 0; k < hess_pattern_.nnz(); k++) + auto& hess_rows = hess_pattern_symm_.row(); + auto& hess_cols = hess_pattern_symm_.col(); + for (size_t k = 0; k < hess_pattern_symm_.nnz(); k++) { pattern.hessIndexVars1.push_back(indexVars[hess_rows[k]]); pattern.hessIndexVars2.push_back(indexVars[hess_cols[k]]);