Skip to content

Commit 15059e5

Browse files
committed
Add dirty check for ipopt
1 parent 6d61f37 commit 15059e5

File tree

5 files changed

+67
-3
lines changed

5 files changed

+67
-3
lines changed

include/pyoptinterface/ipopt_model.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ struct IpoptModel : public OnesideLinearConstraintMixin<IpoptModel>,
240240
Hashmap<std::string, std::string> m_options_str;
241241

242242
IpoptResult m_result;
243+
bool m_is_dirty = true;
243244
enum ApplicationReturnStatus m_status;
244245

245246
std::unique_ptr<IpoptProblemInfo, IpoptfreeproblemT> m_problem = nullptr;

lib/ipopt_model.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,8 @@ VariableIndex IpoptModel::add_variable(double lb, double ub, double start, const
8080
m_var_names.emplace(vi.index, name);
8181
}
8282

83+
m_is_dirty = true;
84+
8385
return vi;
8486
}
8587

@@ -216,6 +218,8 @@ ConstraintIndex IpoptModel::add_linear_constraint(const ScalarAffineFunction &f,
216218
m_linear_con_lb.push_back(lb);
217219
m_linear_con_ub.push_back(ub);
218220

221+
m_is_dirty = true;
222+
219223
return con;
220224
}
221225

@@ -257,6 +261,8 @@ ConstraintIndex IpoptModel::add_quadratic_constraint(const ScalarQuadraticFuncti
257261
m_quadratic_con_lb.push_back(lb);
258262
m_quadratic_con_ub.push_back(ub);
259263

264+
m_is_dirty = true;
265+
260266
return con;
261267
}
262268

@@ -285,6 +291,8 @@ void IpoptModel::set_objective(const ExprBuilder &expr, ObjectiveSense sense)
285291
{
286292
throw std::runtime_error("Only linear and quadratic objective is supported");
287293
}
294+
295+
m_is_dirty = true;
288296
}
289297

290298
void IpoptModel::_set_linear_objective(const ScalarAffineFunction &expr)
@@ -369,6 +377,8 @@ ConstraintIndex IpoptModel::add_single_nl_constraint(size_t graph_index,
369377
nl_constraint_graph_memberships.push_back(ConstraintGraphMembership{
370378
.graph = (int)graph_index, .rank = (int)graph.m_constraint_outputs.size() - 1});
371379

380+
m_is_dirty = true;
381+
372382
return ConstraintIndex(ConstraintType::IPOPT_NL, constraint_index);
373383
}
374384

@@ -636,8 +646,6 @@ void IpoptModel::analyze_structure()
636646

637647
nl_constraint_map_ext2int[i_nl_con] = index_base + i_graph_rank;
638648
}
639-
/*fmt::print("constraint_indices_offsets {}\n", m_nl_evaluator.constraint_indices_offsets);
640-
fmt::print("nl_constraint_map_ext2int {}\n", nl_constraint_map_ext2int);*/
641649

642650
// construct the lower bound and upper bound of the constraints
643651
auto n_constraints = m_linear_con_evaluator.n_constraints +
@@ -754,6 +762,7 @@ void IpoptModel::optimize()
754762
&m_result.obj_val, m_result.mult_g.data(),
755763
m_result.mult_x_L.data(), m_result.mult_x_U.data(), (void *)this);
756764
m_result.is_valid = true;
765+
m_is_dirty = false;
757766
}
758767

759768
void IpoptModel::load_current_solution()

lib/ipopt_model_ext.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ NB_MODULE(ipopt_model_ext, m)
4343
.def(nb::init<>())
4444
.def("close", &IpoptModel::close)
4545
.def_ro("m_status", &IpoptModel::m_status)
46+
.def_rw("m_is_dirty", &IpoptModel::m_is_dirty)
4647
.def("add_variable", &IpoptModel::add_variable, nb::arg("lb") = -INFINITY,
4748
nb::arg("ub") = INFINITY, nb::arg("start") = 0.0, nb::arg("name") = "")
4849
.def("get_variable_lb", &IpoptModel::get_variable_lb)

src/pyoptinterface/_src/ipopt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,9 @@ def get_rawstatusstring(model):
130130

131131

132132
def get_terminationstatus(model):
133+
is_dirty = model.m_is_dirty
134+
if is_dirty:
135+
return TerminationStatusCode.OPTIMIZE_NOT_CALLED
133136
status = model.m_status
134137
if (
135138
status == ApplicationReturnStatus.Solve_Succeeded
@@ -511,6 +514,8 @@ def add_nl_objective(self, expr):
511514
self.graph_instance_to_index[graph] = graph_index
512515
self.graph_instances.append(graph)
513516

517+
self.m_is_dirty = True
518+
514519
def optimize(self):
515520
self._find_similar_graphs()
516521
self._compile_evaluators()

tests/test_nlp.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@ def test_easy_nlp(nlp_model_ctor):
8080

8181

8282
def test_nlfunc_ifelse(nlp_model_ctor):
83-
if nlp_model_ctor is not ipopt.Model:
83+
model = nlp_model_ctor()
84+
if not isinstance(model, ipopt.Model):
8485
pytest.skip("ifelse is only supported in IPOPT")
8586

8687
for x_, fx in zip([0.2, 0.5, 1.0, 2.0, 3.0], [0.2, 0.5, 1.0, 4.0, 9.0]):
@@ -100,6 +101,53 @@ def test_nlfunc_ifelse(nlp_model_ctor):
100101
assert x_value == pytest.approx(x_)
101102

102103

104+
def test_ipopt_optimizer_not_called():
105+
model = ipopt.Model()
106+
107+
x = model.add_variable(lb=0.0, ub=10.0)
108+
termination_status = model.get_model_attribute(poi.ModelAttribute.TerminationStatus)
109+
assert termination_status == poi.TerminationStatusCode.OPTIMIZE_NOT_CALLED
110+
111+
model.set_objective(x**2)
112+
model.optimize()
113+
termination_status = model.get_model_attribute(poi.ModelAttribute.TerminationStatus)
114+
assert termination_status == poi.TerminationStatusCode.LOCALLY_SOLVED
115+
116+
model.add_linear_constraint(x >= 0.5)
117+
termination_status = model.get_model_attribute(poi.ModelAttribute.TerminationStatus)
118+
assert termination_status == poi.TerminationStatusCode.OPTIMIZE_NOT_CALLED
119+
120+
model.optimize()
121+
termination_status = model.get_model_attribute(poi.ModelAttribute.TerminationStatus)
122+
assert termination_status == poi.TerminationStatusCode.LOCALLY_SOLVED
123+
124+
model.add_quadratic_constraint(x**2 >= 0.36)
125+
termination_status = model.get_model_attribute(poi.ModelAttribute.TerminationStatus)
126+
assert termination_status == poi.TerminationStatusCode.OPTIMIZE_NOT_CALLED
127+
128+
model.optimize()
129+
termination_status = model.get_model_attribute(poi.ModelAttribute.TerminationStatus)
130+
assert termination_status == poi.TerminationStatusCode.LOCALLY_SOLVED
131+
132+
with nl.graph():
133+
model.add_nl_constraint(nl.exp(x) <= 100.0)
134+
termination_status = model.get_model_attribute(poi.ModelAttribute.TerminationStatus)
135+
assert termination_status == poi.TerminationStatusCode.OPTIMIZE_NOT_CALLED
136+
137+
model.optimize()
138+
termination_status = model.get_model_attribute(poi.ModelAttribute.TerminationStatus)
139+
assert termination_status == poi.TerminationStatusCode.LOCALLY_SOLVED
140+
141+
with nl.graph():
142+
model.add_nl_objective(nl.log(x))
143+
termination_status = model.get_model_attribute(poi.ModelAttribute.TerminationStatus)
144+
assert termination_status == poi.TerminationStatusCode.OPTIMIZE_NOT_CALLED
145+
146+
model.optimize()
147+
termination_status = model.get_model_attribute(poi.ModelAttribute.TerminationStatus)
148+
assert termination_status == poi.TerminationStatusCode.LOCALLY_SOLVED
149+
150+
103151
if __name__ == "__main__":
104152

105153
def c():

0 commit comments

Comments
 (0)