Skip to content

Commit dfd16f3

Browse files
committed
Allow quicksum for high dimensional numpy array directly without writing poi.quicksum(x.flat)
1 parent f6d5fa5 commit dfd16f3

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/pyoptinterface/_src/aml.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from collections.abc import Collection
55
from typing import Tuple, Union
66

7+
import numpy as np
8+
79

810
def make_variable_ndarray(
911
model,
@@ -105,6 +107,8 @@ def f(*args):
105107
def quicksum_(expr: ExprBuilder, terms, f=None):
106108
if isinstance(terms, dict):
107109
iter = terms.values()
110+
elif isinstance(terms, np.ndarray):
111+
iter = terms.flat
108112
else:
109113
iter = terms
110114
if f:

tests/test_matrix_api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,15 @@ def test_matrix_api(model_interface):
2727
model.optimize()
2828
obj_value = model.get_model_attribute(poi.ModelAttribute.ObjectiveValue)
2929
assert obj_value == approx(-N * ub)
30+
31+
32+
def test_quicksum_ndarray(model_interface):
33+
model = model_interface
34+
35+
N = 10
36+
x = model.add_m_variables((N, 2 * N), lb=1.0, ub=3.0)
37+
obj = poi.quicksum(x)
38+
model.set_objective(obj)
39+
model.optimize()
40+
obj_value = model.get_model_attribute(poi.ModelAttribute.ObjectiveValue)
41+
assert obj_value == approx(2 * N**2)

0 commit comments

Comments
 (0)