From 5559dfde3f99bd7f8c86ce55b619032040d1bfe6 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Mon, 12 Jan 2026 18:35:00 -0500 Subject: [PATCH 1/6] Add index atom for array indexing and slicing Implements efficient indexing operator that handles both CVXPY's `index` (slice-based) and `special_index` (array/boolean) atoms. - Forward: O(n_selected) gather operation - Jacobian: Pre-computed row mapping for fast memcpy per row - Hessian: Pre-allocated scatter buffer with accumulation for repeated indices (equivalent to np.add.at) Enables NLP tests like test_hs071 and test_rosenbrock that use x[0], x[1] style indexing. Co-Authored-By: Claude Opus 4.5 --- include/affine.h | 2 + include/subexpr.h | 13 +++ python/atoms/index.h | 52 +++++++++ python/bindings.c | 2 + src/affine/index.c | 187 ++++++++++++++++++++++++++++++ src/dnlp_diff_engine/__init__.py | 28 +++++ tests/all_tests.c | 11 ++ tests/jacobian_tests/test_index.h | 140 ++++++++++++++++++++++ tests/wsum_hess/test_index.h | 116 ++++++++++++++++++ 9 files changed, 551 insertions(+) create mode 100644 python/atoms/index.h create mode 100644 src/affine/index.c create mode 100644 tests/jacobian_tests/test_index.h create mode 100644 tests/wsum_hess/test_index.h diff --git a/include/affine.h b/include/affine.h index 654557b..f7943eb 100644 --- a/include/affine.h +++ b/include/affine.h @@ -18,4 +18,6 @@ expr *new_trace(expr *child); expr *new_constant(int d1, int d2, int n_vars, const double *values); expr *new_variable(int d1, int d2, int var_id, int n_vars); +expr *new_index(expr *child, const int *indices, int n_selected); + #endif /* AFFINE_H */ diff --git a/include/subexpr.h b/include/subexpr.h index 8239a7d..8f1c081 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -109,4 +109,17 @@ typedef struct const_vector_mult_expr double *a; /* length equals node->size */ } const_vector_mult_expr; +/* Index/slicing: y = child[indices] where indices is a list of flattened positions */ +typedef struct index_expr +{ + expr base; + int *indices; /* Flattened indices to select (owned, copied) */ + int n_selected; /* Number of selected elements */ + /* Pre-computed for fast jacobian eval: */ + int *jac_row_starts; /* Position in jacobian->x where each selected row starts */ + int *jac_row_lengths; /* Number of nonzeros in each selected row */ + /* Pre-allocated for wsum_hess: */ + double *parent_w; /* Scatter buffer (size = child->size), zeroed each eval */ +} index_expr; + #endif /* SUBEXPR_H */ diff --git a/python/atoms/index.h b/python/atoms/index.h new file mode 100644 index 0000000..a4d1dfd --- /dev/null +++ b/python/atoms/index.h @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 + +#ifndef ATOM_INDEX_H +#define ATOM_INDEX_H + +#include "affine.h" +#include "common.h" + +/* Index/slicing: y = child[indices] where indices is a list of flattened positions */ +static PyObject *py_make_index(PyObject *self, PyObject *args) +{ + PyObject *child_capsule; + PyObject *indices_obj; + + if (!PyArg_ParseTuple(args, "OO", &child_capsule, &indices_obj)) + { + return NULL; + } + + expr *child = (expr *)PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + if (!child) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + /* Convert indices array to int32 */ + PyArrayObject *indices_array = + (PyArrayObject *)PyArray_FROM_OTF(indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY); + + if (!indices_array) + { + return NULL; + } + + int n_selected = (int)PyArray_SIZE(indices_array); + int *indices_data = (int *)PyArray_DATA(indices_array); + + expr *node = new_index(child, indices_data, n_selected); + + Py_DECREF(indices_array); + + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create index node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_INDEX_H */ diff --git a/python/bindings.c b/python/bindings.c index 3045e2b..af06f57 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -12,6 +12,7 @@ #include "atoms/cos.h" #include "atoms/entr.h" #include "atoms/exp.h" +#include "atoms/index.h" #include "atoms/left_matmul.h" #include "atoms/linear.h" #include "atoms/log.h" @@ -55,6 +56,7 @@ static PyMethodDef DNLPMethods[] = { {"make_linear", py_make_linear, METH_VARARGS, "Create linear op node"}, {"make_log", py_make_log, METH_VARARGS, "Create log node"}, {"make_exp", py_make_exp, METH_VARARGS, "Create exp node"}, + {"make_index", py_make_index, METH_VARARGS, "Create index node"}, {"make_add", py_make_add, METH_VARARGS, "Create add node"}, {"make_sum", py_make_sum, METH_VARARGS, "Create sum node"}, {"make_neg", py_make_neg, METH_VARARGS, "Create neg node"}, diff --git a/src/affine/index.c b/src/affine/index.c new file mode 100644 index 0000000..f909ab1 --- /dev/null +++ b/src/affine/index.c @@ -0,0 +1,187 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include "affine.h" +#include "subexpr.h" +#include +#include + +/* Index/slicing: y = child[indices] where indices is a list of flattened positions */ + +static void forward(expr *node, const double *u) +{ + expr *child = node->left; + index_expr *idx = (index_expr *)node; + + /* child's forward pass */ + child->forward(child, u); + + /* gather selected elements */ + for (int i = 0; i < idx->n_selected; i++) + { + node->value[i] = child->value[idx->indices[i]]; + } +} + +static void jacobian_init(expr *node) +{ + expr *child = node->left; + index_expr *idx = (index_expr *)node; + + /* initialize child's jacobian */ + child->jacobian_init(child); + CSR_Matrix *J_child = child->jacobian; + + /* allocate mapping arrays */ + idx->jac_row_starts = (int *)malloc(idx->n_selected * sizeof(int)); + idx->jac_row_lengths = (int *)malloc(idx->n_selected * sizeof(int)); + + /* count nnz and pre-compute row mapping */ + int nnz = 0; + for (int i = 0; i < idx->n_selected; i++) + { + int row = idx->indices[i]; + int row_len = J_child->p[row + 1] - J_child->p[row]; + idx->jac_row_starts[i] = nnz; + idx->jac_row_lengths[i] = row_len; + nnz += row_len; + } + + /* allocate jacobian with computed sparsity */ + node->jacobian = new_csr_matrix(idx->n_selected, node->n_vars, nnz); + + /* fill sparsity pattern (p and i arrays) */ + node->jacobian->p[0] = 0; + for (int i = 0; i < idx->n_selected; i++) + { + int row = idx->indices[i]; + int row_len = idx->jac_row_lengths[i]; + /* copy column indices from child row */ + if (row_len > 0) + { + memcpy(&node->jacobian->i[idx->jac_row_starts[i]], + &J_child->i[J_child->p[row]], row_len * sizeof(int)); + } + node->jacobian->p[i + 1] = idx->jac_row_starts[i] + row_len; + } +} + +static void eval_jacobian(expr *node) +{ + expr *child = node->left; + index_expr *idx = (index_expr *)node; + CSR_Matrix *J_child = child->jacobian; + + /* evaluate child's jacobian */ + child->eval_jacobian(child); + + /* fast copy using pre-computed mapping */ + for (int i = 0; i < idx->n_selected; i++) + { + int row = idx->indices[i]; + int row_len = idx->jac_row_lengths[i]; + if (row_len > 0) + { + memcpy(&node->jacobian->x[idx->jac_row_starts[i]], + &J_child->x[J_child->p[row]], + row_len * sizeof(double)); + } + } +} + +static void wsum_hess_init(expr *node) +{ + expr *child = node->left; + index_expr *idx = (index_expr *)node; + + /* initialize child's wsum_hess */ + child->wsum_hess_init(child); + + /* allocate scatter buffer (zeroed) */ + idx->parent_w = (double *)calloc(child->size, sizeof(double)); + + /* wsum_hess inherits from child (affine has no local Hessian) */ + /* We need to allocate our own to avoid aliasing issues */ + CSR_Matrix *child_hess = child->wsum_hess; + node->wsum_hess = new_csr_matrix(child_hess->m, child_hess->n, child_hess->nnz); + memcpy(node->wsum_hess->p, child_hess->p, (child_hess->m + 1) * sizeof(int)); + memcpy(node->wsum_hess->i, child_hess->i, child_hess->nnz * sizeof(int)); +} + +static void eval_wsum_hess(expr *node, const double *w) +{ + expr *child = node->left; + index_expr *idx = (index_expr *)node; + + /* zero the scatter buffer */ + memset(idx->parent_w, 0, child->size * sizeof(double)); + + /* scatter w to child size with accumulation (handles repeated indices) */ + for (int i = 0; i < idx->n_selected; i++) + { + idx->parent_w[idx->indices[i]] += w[i]; + } + + /* delegate to child */ + child->eval_wsum_hess(child, idx->parent_w); + + /* copy values from child */ + memcpy(node->wsum_hess->x, child->wsum_hess->x, child->wsum_hess->nnz * sizeof(double)); +} + +static bool is_affine(const expr *node) +{ + return node->left->is_affine(node->left); +} + +static void free_type_data(expr *node) +{ + index_expr *idx = (index_expr *)node; + if (idx->indices) + { + free(idx->indices); + idx->indices = NULL; + } + if (idx->jac_row_starts) + { + free(idx->jac_row_starts); + idx->jac_row_starts = NULL; + } + if (idx->jac_row_lengths) + { + free(idx->jac_row_lengths); + idx->jac_row_lengths = NULL; + } + if (idx->parent_w) + { + free(idx->parent_w); + idx->parent_w = NULL; + } +} + +expr *new_index(expr *child, const int *indices, int n_selected) +{ + /* allocate type-specific struct */ + index_expr *idx = (index_expr *)calloc(1, sizeof(index_expr)); + expr *node = &idx->base; + + /* output shape is (n_selected, 1) - flattened */ + init_expr(node, n_selected, 1, child->n_vars, forward, jacobian_init, + eval_jacobian, is_affine, free_type_data); + + node->wsum_hess_init = wsum_hess_init; + node->eval_wsum_hess = eval_wsum_hess; + node->left = child; + expr_retain(child); + + /* copy indices */ + idx->indices = (int *)malloc(n_selected * sizeof(int)); + memcpy(idx->indices, indices, n_selected * sizeof(int)); + idx->n_selected = n_selected; + + /* mapping arrays allocated lazily in jacobian_init */ + idx->jac_row_starts = NULL; + idx->jac_row_lengths = NULL; + idx->parent_w = NULL; + + return node; +} diff --git a/src/dnlp_diff_engine/__init__.py b/src/dnlp_diff_engine/__init__.py index 3a10c67..8d9a032 100644 --- a/src/dnlp_diff_engine/__init__.py +++ b/src/dnlp_diff_engine/__init__.py @@ -100,6 +100,27 @@ def _convert_multiply(expr, children): return _diffengine.make_multiply(children[0], children[1]) +def _extract_flat_indices_from_index(expr): + """Extract flattened indices from CVXPY index expression.""" + parent_shape = expr.args[0].shape + indices_per_dim = [np.arange(s.start, s.stop, s.step) for s in expr.key] + + if len(indices_per_dim) == 1: + return indices_per_dim[0].astype(np.int32) + elif len(indices_per_dim) == 2: + # Fortran order: idx = row + col * n_rows + return np.add.outer( + indices_per_dim[0], indices_per_dim[1] * parent_shape[0] + ).flatten(order="F").astype(np.int32) + else: + raise NotImplementedError("index with >2 dimensions not supported") + + +def _extract_flat_indices_from_special_index(expr): + """Extract flattened indices from CVXPY special_index expression.""" + return np.reshape(expr._select_mat, expr._select_mat.size, order="F").astype(np.int32) + + def _convert_quad_form(expr, children): """Convert quadratic form x.T @ P @ x.""" @@ -162,6 +183,13 @@ def _convert_quad_form(expr, children): "entr": lambda _expr, children: _diffengine.make_entr(children[0]), "logistic": lambda _expr, children: _diffengine.make_logistic(children[0]), "xexp": lambda _expr, children: _diffengine.make_xexp(children[0]), + # Indexing/slicing + "index": lambda expr, children: _diffengine.make_index( + children[0], _extract_flat_indices_from_index(expr) + ), + "special_index": lambda expr, children: _diffengine.make_index( + children[0], _extract_flat_indices_from_special_index(expr) + ), } diff --git a/tests/all_tests.c b/tests/all_tests.c index 5d65653..9c7692d 100644 --- a/tests/all_tests.c +++ b/tests/all_tests.c @@ -18,6 +18,7 @@ #include "jacobian_tests/test_const_vector_mult.h" #include "jacobian_tests/test_elementwise_mult.h" #include "jacobian_tests/test_hstack.h" +#include "jacobian_tests/test_index.h" #include "jacobian_tests/test_left_matmul.h" #include "jacobian_tests/test_log.h" #include "jacobian_tests/test_neg.h" @@ -43,6 +44,7 @@ #include "wsum_hess/test_const_scalar_mult.h" #include "wsum_hess/test_const_vector_mult.h" #include "wsum_hess/test_hstack.h" +#include "wsum_hess/test_index.h" #include "wsum_hess/test_left_matmul.h" #include "wsum_hess/test_multiply.h" #include "wsum_hess/test_prod.h" @@ -111,6 +113,12 @@ int main(void) mu_run_test(test_jacobian_sum_log_axis_1, tests_run); mu_run_test(test_jacobian_hstack_vectors, tests_run); mu_run_test(test_jacobian_hstack_matrix, tests_run); + mu_run_test(test_index_forward_simple, tests_run); + mu_run_test(test_index_forward_repeated, tests_run); + mu_run_test(test_index_jacobian_of_variable, tests_run); + mu_run_test(test_index_jacobian_of_log, tests_run); + mu_run_test(test_index_jacobian_repeated, tests_run); + mu_run_test(test_sum_of_index, tests_run); mu_run_test(test_promote_scalar_jacobian, tests_run); mu_run_test(test_promote_scalar_to_matrix_jacobian, tests_run); mu_run_test(test_wsum_hess_multiply_1, tests_run); @@ -150,6 +158,9 @@ int main(void) mu_run_test(test_wsum_hess_rel_entr_matrix, tests_run); mu_run_test(test_wsum_hess_hstack, tests_run); mu_run_test(test_wsum_hess_hstack_matrix, tests_run); + mu_run_test(test_wsum_hess_index_log, tests_run); + mu_run_test(test_wsum_hess_index_repeated, tests_run); + mu_run_test(test_wsum_hess_sum_index_log, tests_run); mu_run_test(test_wsum_hess_quad_over_lin_xy, tests_run); mu_run_test(test_wsum_hess_quad_over_lin_yx, tests_run); mu_run_test(test_wsum_hess_quad_form, tests_run); diff --git a/tests/jacobian_tests/test_index.h b/tests/jacobian_tests/test_index.h new file mode 100644 index 0000000..5bf73bb --- /dev/null +++ b/tests/jacobian_tests/test_index.h @@ -0,0 +1,140 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "affine.h" +#include "elementwise_univariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_index_forward_simple(void) +{ + /* x = [1, 2, 3], select indices [0, 2] -> [1, 3] */ + double u[3] = {1.0, 2.0, 3.0}; + int indices[2] = {0, 2}; + expr *var = new_variable(3, 1, 0, 3); + expr *idx = new_index(var, indices, 2); + idx->forward(idx, u); + + double expected[2] = {1.0, 3.0}; + mu_assert("index forward simple", cmp_double_array(idx->value, expected, 2)); + + free_expr(idx); + return 0; +} + +const char *test_index_forward_repeated(void) +{ + /* x = [1, 2, 3], select [0, 0, 2] -> [1, 1, 3] */ + double u[3] = {1.0, 2.0, 3.0}; + int indices[3] = {0, 0, 2}; + expr *var = new_variable(3, 1, 0, 3); + expr *idx = new_index(var, indices, 3); + idx->forward(idx, u); + + double expected[3] = {1.0, 1.0, 3.0}; + mu_assert("index forward repeated", cmp_double_array(idx->value, expected, 3)); + + free_expr(idx); + return 0; +} + +const char *test_index_jacobian_of_variable(void) +{ + /* x[0, 2] where x is variable of size 3 + * Jacobian should select rows 0 and 2 of identity */ + double u[3] = {1.0, 2.0, 3.0}; + int indices[2] = {0, 2}; + expr *var = new_variable(3, 1, 0, 3); + expr *idx = new_index(var, indices, 2); + idx->forward(idx, u); + idx->jacobian_init(idx); + idx->eval_jacobian(idx); + + /* Jacobian is 2x3 with pattern: row 0 selects col 0, row 1 selects col 2 */ + double expected_x[2] = {1.0, 1.0}; + int expected_p[3] = {0, 1, 2}; /* CSR row ptrs */ + int expected_i[2] = {0, 2}; /* column indices */ + + mu_assert("index jac vals", cmp_double_array(idx->jacobian->x, expected_x, 2)); + mu_assert("index jac p", cmp_int_array(idx->jacobian->p, expected_p, 3)); + mu_assert("index jac i", cmp_int_array(idx->jacobian->i, expected_i, 2)); + + free_expr(idx); + return 0; +} + +const char *test_index_jacobian_of_log(void) +{ + /* log(x)[0, 2] - jacobian should be selected rows of diag(1/x) */ + double u[3] = {1.0, 2.0, 4.0}; + int indices[2] = {0, 2}; + expr *var = new_variable(3, 1, 0, 3); + expr *log_node = new_log(var); + expr *idx = new_index(log_node, indices, 2); + idx->forward(idx, u); + idx->jacobian_init(idx); + idx->eval_jacobian(idx); + + /* d/dx log(x) = diag(1/x), then select rows 0 and 2 + * Row 0: 1/1 = 1.0 at col 0 + * Row 1: 1/4 = 0.25 at col 2 */ + double expected_x[2] = {1.0, 0.25}; + int expected_i[2] = {0, 2}; + + mu_assert("index of log jac vals", cmp_double_array(idx->jacobian->x, expected_x, 2)); + mu_assert("index of log jac cols", cmp_int_array(idx->jacobian->i, expected_i, 2)); + + free_expr(idx); + return 0; +} + +const char *test_index_jacobian_repeated(void) +{ + /* x[0, 0] where x is size 3 - both outputs depend on x[0] */ + double u[3] = {1.0, 2.0, 3.0}; + int indices[2] = {0, 0}; + expr *var = new_variable(3, 1, 0, 3); + expr *idx = new_index(var, indices, 2); + idx->forward(idx, u); + idx->jacobian_init(idx); + idx->eval_jacobian(idx); + + /* Both rows select column 0 */ + double expected_x[2] = {1.0, 1.0}; + int expected_p[3] = {0, 1, 2}; + int expected_i[2] = {0, 0}; /* Both reference col 0 */ + + mu_assert("index repeated jac vals", cmp_double_array(idx->jacobian->x, expected_x, 2)); + mu_assert("index repeated jac i", cmp_int_array(idx->jacobian->i, expected_i, 2)); + + free_expr(idx); + return 0; +} + +const char *test_sum_of_index(void) +{ + /* sum(x[0, 2]) = x[0] + x[2] + * Gradient should be sparse with entries at cols 0 and 2 */ + double u[3] = {1.0, 2.0, 3.0}; + int indices[2] = {0, 2}; + + expr *var = new_variable(3, 1, 0, 3); + expr *idx = new_index(var, indices, 2); + expr *s = new_sum(idx, -1); /* sum all */ + + s->forward(s, u); + s->jacobian_init(s); + s->eval_jacobian(s); + + /* Gradient: [1, 0, 1] in sparse form */ + double expected_x[2] = {1.0, 1.0}; + int expected_i[2] = {0, 2}; + + mu_assert("sum of index vals", cmp_double_array(s->jacobian->x, expected_x, 2)); + mu_assert("sum of index cols", cmp_int_array(s->jacobian->i, expected_i, 2)); + + free_expr(s); + return 0; +} diff --git a/tests/wsum_hess/test_index.h b/tests/wsum_hess/test_index.h new file mode 100644 index 0000000..0a2fb9c --- /dev/null +++ b/tests/wsum_hess/test_index.h @@ -0,0 +1,116 @@ +// SPDX-License-Identifier: Apache-2.0 + +#include +#include +#include +#include + +#include "affine.h" +#include "elementwise_univariate.h" +#include "expr.h" +#include "minunit.h" +#include "test_helpers.h" + +const char *test_wsum_hess_index_log(void) +{ + /* log(x)[0, 2] with w = [1, 2] + * Hessian of log is diag(-1/x^2) + * After scatter: parent_w = [1, 0, 2] + * Hessian: diag(-1/x^2) weighted by [1, 0, 2] = diag([-1, 0, -0.125]) */ + double u[3] = {1.0, 2.0, 4.0}; + int indices[2] = {0, 2}; + double w[2] = {1.0, 2.0}; + + expr *var = new_variable(3, 1, 0, 3); + expr *log_node = new_log(var); + expr *idx = new_index(log_node, indices, 2); + + idx->forward(idx, u); + idx->jacobian_init(idx); + idx->wsum_hess_init(idx); + idx->eval_wsum_hess(idx, w); + + /* Expected diagonal values: + * H[0,0] = -1 * 1/1^2 = -1.0 + * H[1,1] = 0 (no weight) + * H[2,2] = -2 * 1/16 = -0.125 */ + double expected_x[3] = {-1.0, 0.0, -0.125}; + int expected_p[4] = {0, 1, 2, 3}; + int expected_i[3] = {0, 1, 2}; + + mu_assert("index log hess vals", cmp_double_array(idx->wsum_hess->x, expected_x, 3)); + mu_assert("index log hess p", cmp_int_array(idx->wsum_hess->p, expected_p, 4)); + mu_assert("index log hess i", cmp_int_array(idx->wsum_hess->i, expected_i, 3)); + + free_expr(idx); + return 0; +} + +const char *test_wsum_hess_index_repeated(void) +{ + /* log(x)[0, 0] with w = [1, 2] - weights should accumulate + * parent_w = [3, 0, 0] (1+2 at position 0) */ + double u[3] = {2.0, 3.0, 4.0}; + int indices[2] = {0, 0}; + double w[2] = {1.0, 2.0}; + + expr *var = new_variable(3, 1, 0, 3); + expr *log_node = new_log(var); + expr *idx = new_index(log_node, indices, 2); + + idx->forward(idx, u); + idx->jacobian_init(idx); + idx->wsum_hess_init(idx); + idx->eval_wsum_hess(idx, w); + + /* Hessian of log at x=2 is -1/4 + * weighted by 3 (accumulated) -> -3/4 = -0.75 + * Other positions have 0 weight */ + double expected_x[3] = {-0.75, 0.0, 0.0}; + int expected_p[4] = {0, 1, 2, 3}; + int expected_i[3] = {0, 1, 2}; + + mu_assert("index repeated hess vals", cmp_double_array(idx->wsum_hess->x, expected_x, 3)); + mu_assert("index repeated hess p", cmp_int_array(idx->wsum_hess->p, expected_p, 4)); + mu_assert("index repeated hess i", cmp_int_array(idx->wsum_hess->i, expected_i, 3)); + + free_expr(idx); + return 0; +} + +const char *test_wsum_hess_sum_index_log(void) +{ + /* sum(log(x)[0, 2]) with w = 1.0 + * This tests the chain: sum -> index -> log -> variable + * Gradient: [1/x[0], 0, 1/x[2]] + * Hessian: diag([-1/x[0]^2, 0, -1/x[2]^2]) */ + double u[3] = {1.0, 2.0, 4.0}; + int indices[2] = {0, 2}; + double w = 1.0; + + expr *var = new_variable(3, 1, 0, 3); + expr *log_node = new_log(var); + expr *idx = new_index(log_node, indices, 2); + expr *sum_node = new_sum(idx, -1); + + sum_node->forward(sum_node, u); + sum_node->jacobian_init(sum_node); + sum_node->wsum_hess_init(sum_node); + sum_node->eval_wsum_hess(sum_node, &w); + + /* Expected diagonal values: + * H[0,0] = -1 * 1/1^2 = -1.0 + * H[1,1] = 0 (index 1 not selected) + * H[2,2] = -1 * 1/16 = -0.0625 */ + double expected_x[3] = {-1.0, 0.0, -0.0625}; + int expected_p[4] = {0, 1, 2, 3}; + int expected_i[3] = {0, 1, 2}; + + mu_assert("sum index log hess vals", + cmp_double_array(sum_node->wsum_hess->x, expected_x, 3)); + mu_assert("sum index log hess p", cmp_int_array(sum_node->wsum_hess->p, expected_p, 4)); + mu_assert("sum index log hess i", cmp_int_array(sum_node->wsum_hess->i, expected_i, 3)); + + free_expr(sum_node); + return 0; +} From c69620a67a70f6e74537c76f318c0df5a9e334aa Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Mon, 12 Jan 2026 18:44:52 -0500 Subject: [PATCH 2/6] Add reshape converter (Fortran order only) Reshape with order='F' is a pass-through since the underlying data layout is unchanged - only shape interpretation differs. Note: Only Fortran order is supported. C order would require data permutation which is not yet implemented. Co-Authored-By: Claude Opus 4.5 --- src/dnlp_diff_engine/__init__.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/dnlp_diff_engine/__init__.py b/src/dnlp_diff_engine/__init__.py index 8d9a032..f4e2dec 100644 --- a/src/dnlp_diff_engine/__init__.py +++ b/src/dnlp_diff_engine/__init__.py @@ -121,6 +121,24 @@ def _extract_flat_indices_from_special_index(expr): return np.reshape(expr._select_mat, expr._select_mat.size, order="F").astype(np.int32) +def _convert_reshape(expr, children): + """Convert reshape - only Fortran order is supported. + + For Fortran order, reshape is a no-op since the underlying data layout + is unchanged. We just pass through the child expression. + + Note: Only order='F' (Fortran/column-major) is supported. This is the + default and most common case in CVXPY. C order would require permutation. + """ + if expr.order != 'F': + raise NotImplementedError( + f"reshape with order='{expr.order}' not supported. " + "Only order='F' (Fortran) is currently supported." + ) + # Pass through - data layout is unchanged with Fortran order + return children[0] + + def _convert_quad_form(expr, children): """Convert quadratic form x.T @ P @ x.""" @@ -190,6 +208,8 @@ def _convert_quad_form(expr, children): "special_index": lambda expr, children: _diffengine.make_index( children[0], _extract_flat_indices_from_special_index(expr) ), + # Reshape (Fortran order only - pass-through since data layout unchanged) + "reshape": _convert_reshape, } From 0b7e41a69bc5583234066da6ce54c941409b3363 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Mon, 12 Jan 2026 18:55:40 -0500 Subject: [PATCH 3/6] Add quad_over_lin binding and converter Enables test_socp and test_portfolio_socp to pass. Co-Authored-By: Claude Opus 4.5 --- python/atoms/quad_over_lin.h | 36 ++++++++++++++++++++++++++++++++ python/bindings.c | 3 +++ src/dnlp_diff_engine/__init__.py | 3 +++ 3 files changed, 42 insertions(+) create mode 100644 python/atoms/quad_over_lin.h diff --git a/python/atoms/quad_over_lin.h b/python/atoms/quad_over_lin.h new file mode 100644 index 0000000..4cb3b91 --- /dev/null +++ b/python/atoms/quad_over_lin.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 + +#ifndef ATOM_QUAD_OVER_LIN_H +#define ATOM_QUAD_OVER_LIN_H + +#include "bivariate.h" +#include "common.h" + +/* quad_over_lin: y = sum(x^2) / z where x is left, z is right (scalar) */ +static PyObject *py_make_quad_over_lin(PyObject *self, PyObject *args) +{ + (void)self; + PyObject *left_capsule, *right_capsule; + if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule)) + { + return NULL; + } + expr *left = (expr *)PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME); + expr *right = (expr *)PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME); + if (!left || !right) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_quad_over_lin(left, right); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create quad_over_lin node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_QUAD_OVER_LIN_H */ diff --git a/python/bindings.c b/python/bindings.c index af06f57..fda34d7 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -22,6 +22,7 @@ #include "atoms/power.h" #include "atoms/promote.h" #include "atoms/quad_form.h" +#include "atoms/quad_over_lin.h" #include "atoms/right_matmul.h" #include "atoms/sin.h" #include "atoms/sinh.h" @@ -84,6 +85,8 @@ static PyMethodDef DNLPMethods[] = { "Create right matmul node (f(x) @ A)"}, {"make_quad_form", py_make_quad_form, METH_VARARGS, "Create quadratic form node (x' * Q * x)"}, + {"make_quad_over_lin", py_make_quad_over_lin, METH_VARARGS, + "Create quad_over_lin node (sum(x^2) / y)"}, {"make_problem", py_make_problem, METH_VARARGS, "Create problem from objective and constraints"}, {"problem_init_derivatives", py_problem_init_derivatives, METH_VARARGS, diff --git a/src/dnlp_diff_engine/__init__.py b/src/dnlp_diff_engine/__init__.py index f4e2dec..c9be053 100644 --- a/src/dnlp_diff_engine/__init__.py +++ b/src/dnlp_diff_engine/__init__.py @@ -184,6 +184,9 @@ def _convert_quad_form(expr, children): # Bivariate "multiply": _convert_multiply, "QuadForm": _convert_quad_form, + "quad_over_lin": lambda _expr, children: _diffengine.make_quad_over_lin( + children[0], children[1] + ), # Matrix multiplication "MulExpression": _convert_matmul, # Elementwise univariate with parameter From a8fa3dd111428ca6972cac438421d3c6e18b74d3 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Mon, 12 Jan 2026 19:09:09 -0500 Subject: [PATCH 4/6] Add rel_entr binding and converter Exposes the existing rel_entr C implementation (for equal-sized args) through Python bindings. The scalar argument variants declared in the header are not yet implemented in C. Co-Authored-By: Claude Opus 4.5 --- python/atoms/rel_entr.h | 36 ++++++++++++++++++++++++++++++++ python/bindings.c | 3 +++ src/dnlp_diff_engine/__init__.py | 6 ++++++ 3 files changed, 45 insertions(+) create mode 100644 python/atoms/rel_entr.h diff --git a/python/atoms/rel_entr.h b/python/atoms/rel_entr.h new file mode 100644 index 0000000..e233d17 --- /dev/null +++ b/python/atoms/rel_entr.h @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: Apache-2.0 + +#ifndef ATOM_REL_ENTR_H +#define ATOM_REL_ENTR_H + +#include "bivariate.h" +#include "common.h" + +/* rel_entr: rel_entr(x, y) = x * log(x/y) elementwise */ +static PyObject *py_make_rel_entr(PyObject *self, PyObject *args) +{ + (void)self; + PyObject *left_capsule, *right_capsule; + if (!PyArg_ParseTuple(args, "OO", &left_capsule, &right_capsule)) + { + return NULL; + } + expr *left = (expr *)PyCapsule_GetPointer(left_capsule, EXPR_CAPSULE_NAME); + expr *right = (expr *)PyCapsule_GetPointer(right_capsule, EXPR_CAPSULE_NAME); + if (!left || !right) + { + PyErr_SetString(PyExc_ValueError, "invalid child capsule"); + return NULL; + } + + expr *node = new_rel_entr_vector_args(left, right); + if (!node) + { + PyErr_SetString(PyExc_RuntimeError, "failed to create rel_entr node"); + return NULL; + } + expr_retain(node); /* Capsule owns a reference */ + return PyCapsule_New(node, EXPR_CAPSULE_NAME, expr_capsule_destructor); +} + +#endif /* ATOM_REL_ENTR_H */ diff --git a/python/bindings.c b/python/bindings.c index fda34d7..1042701 100644 --- a/python/bindings.c +++ b/python/bindings.c @@ -23,6 +23,7 @@ #include "atoms/promote.h" #include "atoms/quad_form.h" #include "atoms/quad_over_lin.h" +#include "atoms/rel_entr.h" #include "atoms/right_matmul.h" #include "atoms/sin.h" #include "atoms/sinh.h" @@ -87,6 +88,8 @@ static PyMethodDef DNLPMethods[] = { "Create quadratic form node (x' * Q * x)"}, {"make_quad_over_lin", py_make_quad_over_lin, METH_VARARGS, "Create quad_over_lin node (sum(x^2) / y)"}, + {"make_rel_entr", py_make_rel_entr, METH_VARARGS, + "Create rel_entr node: x * log(x/y) elementwise"}, {"make_problem", py_make_problem, METH_VARARGS, "Create problem from objective and constraints"}, {"problem_init_derivatives", py_problem_init_derivatives, METH_VARARGS, diff --git a/src/dnlp_diff_engine/__init__.py b/src/dnlp_diff_engine/__init__.py index c9be053..a8cda90 100644 --- a/src/dnlp_diff_engine/__init__.py +++ b/src/dnlp_diff_engine/__init__.py @@ -139,6 +139,11 @@ def _convert_reshape(expr, children): return children[0] +def _convert_rel_entr(_expr, children): + """Convert rel_entr(x, y) = x * log(x/y) elementwise.""" + return _diffengine.make_rel_entr(children[0], children[1]) + + def _convert_quad_form(expr, children): """Convert quadratic form x.T @ P @ x.""" @@ -187,6 +192,7 @@ def _convert_quad_form(expr, children): "quad_over_lin": lambda _expr, children: _diffengine.make_quad_over_lin( children[0], children[1] ), + "rel_entr": _convert_rel_entr, # Matrix multiplication "MulExpression": _convert_matmul, # Elementwise univariate with parameter From 89e7033ef52d5d45b43ead242e21cc1452f4bbe0 Mon Sep 17 00:00:00 2001 From: William Zijie Zhang Date: Tue, 13 Jan 2026 02:02:20 -0500 Subject: [PATCH 5/6] Optimize index atom: add duplicate detection, simplify Jacobian - Add has_duplicates flag to detect repeated indices at construction - Hessian eval: fast path (direct write) when no duplicates, avoiding O(child_size) memset; slow path (zero + accumulate) for duplicates - Remove jac_row_starts and jac_row_lengths arrays from index_expr - Simplify jacobian_init: use CSR p array directly instead of helpers - Simplify eval_jacobian: compute row lengths on the fly (trivial cost) - Convert &array[i] to array + i style for consistency with codebase - Update CLAUDE.md with improved build instructions and atom lists Co-Authored-By: Claude Opus 4.5 --- CLAUDE.md | 51 ++++++++++++----------- include/subexpr.h | 7 +--- src/affine/index.c | 101 +++++++++++++++++++++++---------------------- 3 files changed, 79 insertions(+), 80 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 544eccb..43b2bfd 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -11,17 +11,26 @@ DNLP-diff-engine is a C library with Python bindings that provides automatic dif ### Python Package (Recommended) ```bash -# Install in development mode (builds C library + Python bindings) -pip install -e . +# Install in development mode with uv (recommended) +uv pip install -e ".[test]" -# Run all Python tests +# Or with pip +pip install -e ".[test]" + +# Run all Python tests (tests are in python/tests/) pytest # Run specific test file -pytest tests/python/test_unconstrained.py +pytest python/tests/test_unconstrained.py # Run specific test -pytest tests/python/test_unconstrained.py::test_sum_log +pytest python/tests/test_unconstrained.py::test_sum_log + +# Lint with ruff +ruff check src/ + +# Auto-fix lint issues +ruff check --fix src/ ``` ### Standalone C Library @@ -35,16 +44,6 @@ cmake --build build ./build/all_tests ``` -### Legacy Python Build (without pip) - -```bash -# Build Python bindings manually (from project root) -cd python && cmake -B build -S . && cmake --build build && cd .. - -# Run tests from python/ directory -cd python && python tests/test_problem_native.py -``` - ## Architecture ### Expression Tree System @@ -59,10 +58,10 @@ The core abstraction is the `expr` struct (in `include/expr.h`) representing a n Atoms are organized by mathematical properties in `src/`: -- **`affine/`** - Linear operations: `variable`, `constant`, `add`, `neg`, `sum`, `promote`, `hstack`, `trace`, `linear_op` -- **`elementwise_univariate/`** - Scalar functions applied elementwise: `log`, `exp`, `entr`, `power`, `logistic`, trigonometric, hyperbolic -- **`bivariate/`** - Two-argument operations: `multiply`, `quad_over_lin`, `rel_entr` -- **`other/`** - Special atoms not fitting above categories +- **`affine/`** - Linear operations: `variable`, `constant`, `add`, `neg`, `sum`, `promote`, `hstack`, `trace`, `linear_op`, `index` +- **`elementwise_univariate/`** - Scalar functions applied elementwise: `log`, `exp`, `entr`, `power`, `logistic`, `xexp`, trigonometric (`sin`, `cos`, `tan`), hyperbolic (`sinh`, `tanh`, `asinh`, `atanh`) +- **`bivariate/`** - Two-argument operations: `multiply`, `quad_over_lin`, `rel_entr`, `const_scalar_mult`, `const_vector_mult`, `left_matmul`, `right_matmul` +- **`other/`** - Special atoms: `quad_form`, `prod` Each atom implements its own `forward`, `jacobian_init`, `eval_jacobian`, and `eval_wsum_hess` functions following a consistent pattern. @@ -87,7 +86,8 @@ The Python package `dnlp_diff_engine` (in `src/dnlp_diff_engine/`) provides: **High-level API** (`__init__.py`): - `C_problem` class wraps the C problem struct - `convert_problem()` builds expression trees from CVXPY Problem objects -- Atoms are mapped via `ATOM_CONVERTERS` dictionary +- Atoms are mapped via `ATOM_CONVERTERS` dictionary (maps CVXPY atom names → converter functions) +- Special converters handle: matrix multiplication (`_convert_matmul`), multiply with constants (`_convert_multiply`), indexing, reshape (Fortran order only) **Low-level C extension** (`_core` module, built from `python/bindings.c`): - Atom constructors: `make_variable`, `make_constant`, `make_log`, `make_exp`, `make_add`, etc. @@ -108,14 +108,15 @@ Hessian computes weighted sum: `obj_w * H_obj + sum(lambda_i * H_constraint_i)` ## Key Directories -- `include/` - Header files defining public API -- `src/` - C implementation files -- `src/dnlp_diff_engine/` - Python package (installed via pip) -- `python/` - Python bindings C code and binding headers +- `include/` - Header files defining public API (`expr.h`, `problem.h`, atom headers) +- `src/` - C implementation files organized by atom category +- `src/dnlp_diff_engine/` - Python package with high-level API +- `python/` - Python bindings C code (`bindings.c`) - `python/atoms/` - Python binding headers for each atom type - `python/problem/` - Python binding headers for problem interface +- `python/tests/` - Python integration tests (run via pytest) - `tests/` - C tests using minunit framework -- `tests/python/` - Python tests (run via pytest) +- `tests/forward_pass/` - Forward evaluation tests (C) - `tests/jacobian_tests/` - Jacobian correctness tests (C) - `tests/wsum_hess/` - Hessian correctness tests (C) diff --git a/include/subexpr.h b/include/subexpr.h index 8f1c081..ae5be93 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -115,11 +115,8 @@ typedef struct index_expr expr base; int *indices; /* Flattened indices to select (owned, copied) */ int n_selected; /* Number of selected elements */ - /* Pre-computed for fast jacobian eval: */ - int *jac_row_starts; /* Position in jacobian->x where each selected row starts */ - int *jac_row_lengths; /* Number of nonzeros in each selected row */ - /* Pre-allocated for wsum_hess: */ - double *parent_w; /* Scatter buffer (size = child->size), zeroed each eval */ + bool has_duplicates; /* True if indices contain duplicates (affects Hessian path) */ + double *parent_w; /* Scatter buffer for wsum_hess (size = child->size) */ } index_expr; #endif /* SUBEXPR_H */ diff --git a/src/affine/index.c b/src/affine/index.c index f909ab1..18aa120 100644 --- a/src/affine/index.c +++ b/src/affine/index.c @@ -7,6 +7,24 @@ /* Index/slicing: y = child[indices] where indices is a list of flattened positions */ +/* Check if indices array contains duplicates using a bitmap. + * Returns true if duplicates exist, false otherwise. */ +static bool check_for_duplicates(const int *indices, int n_selected, int max_idx) +{ + bool *seen = (bool *)calloc(max_idx, sizeof(bool)); + bool has_dup = false; + for (int i = 0; i < n_selected && !has_dup; i++) + { + if (seen[indices[i]]) + { + has_dup = true; + } + seen[indices[i]] = true; + } + free(seen); + return has_dup; +} + static void forward(expr *node, const double *u) { expr *child = node->left; @@ -27,41 +45,29 @@ static void jacobian_init(expr *node) expr *child = node->left; index_expr *idx = (index_expr *)node; - /* initialize child's jacobian */ child->jacobian_init(child); CSR_Matrix *J_child = child->jacobian; - /* allocate mapping arrays */ - idx->jac_row_starts = (int *)malloc(idx->n_selected * sizeof(int)); - idx->jac_row_lengths = (int *)malloc(idx->n_selected * sizeof(int)); - - /* count nnz and pre-compute row mapping */ + /* count nnz */ int nnz = 0; for (int i = 0; i < idx->n_selected; i++) { int row = idx->indices[i]; - int row_len = J_child->p[row + 1] - J_child->p[row]; - idx->jac_row_starts[i] = nnz; - idx->jac_row_lengths[i] = row_len; - nnz += row_len; + nnz += J_child->p[row + 1] - J_child->p[row]; } - /* allocate jacobian with computed sparsity */ node->jacobian = new_csr_matrix(idx->n_selected, node->n_vars, nnz); + CSR_Matrix *J = node->jacobian; - /* fill sparsity pattern (p and i arrays) */ - node->jacobian->p[0] = 0; + /* fill p and i arrays in one pass */ + J->p[0] = 0; for (int i = 0; i < idx->n_selected; i++) { int row = idx->indices[i]; - int row_len = idx->jac_row_lengths[i]; - /* copy column indices from child row */ - if (row_len > 0) - { - memcpy(&node->jacobian->i[idx->jac_row_starts[i]], - &J_child->i[J_child->p[row]], row_len * sizeof(int)); - } - node->jacobian->p[i + 1] = idx->jac_row_starts[i] + row_len; + int src = J_child->p[row]; + int len = J_child->p[row + 1] - src; + memcpy(J->i + J->p[i], J_child->i + src, len * sizeof(int)); + J->p[i + 1] = J->p[i] + len; } } @@ -69,22 +75,17 @@ static void eval_jacobian(expr *node) { expr *child = node->left; index_expr *idx = (index_expr *)node; - CSR_Matrix *J_child = child->jacobian; - /* evaluate child's jacobian */ child->eval_jacobian(child); - /* fast copy using pre-computed mapping */ + CSR_Matrix *J = node->jacobian; + CSR_Matrix *J_child = child->jacobian; + for (int i = 0; i < idx->n_selected; i++) { int row = idx->indices[i]; - int row_len = idx->jac_row_lengths[i]; - if (row_len > 0) - { - memcpy(&node->jacobian->x[idx->jac_row_starts[i]], - &J_child->x[J_child->p[row]], - row_len * sizeof(double)); - } + int len = J->p[i + 1] - J->p[i]; + memcpy(J->x + J->p[i], J_child->x + J_child->p[row], len * sizeof(double)); } } @@ -112,13 +113,22 @@ static void eval_wsum_hess(expr *node, const double *w) expr *child = node->left; index_expr *idx = (index_expr *)node; - /* zero the scatter buffer */ - memset(idx->parent_w, 0, child->size * sizeof(double)); - - /* scatter w to child size with accumulation (handles repeated indices) */ - for (int i = 0; i < idx->n_selected; i++) + if (idx->has_duplicates) + { + /* slow path: must zero and accumulate for repeated indices */ + memset(idx->parent_w, 0, child->size * sizeof(double)); + for (int i = 0; i < idx->n_selected; i++) + { + idx->parent_w[idx->indices[i]] += w[i]; + } + } + else { - idx->parent_w[idx->indices[i]] += w[i]; + /* fast path: direct write (no memset needed, no accumulation) */ + for (int i = 0; i < idx->n_selected; i++) + { + idx->parent_w[idx->indices[i]] = w[i]; + } } /* delegate to child */ @@ -141,16 +151,6 @@ static void free_type_data(expr *node) free(idx->indices); idx->indices = NULL; } - if (idx->jac_row_starts) - { - free(idx->jac_row_starts); - idx->jac_row_starts = NULL; - } - if (idx->jac_row_lengths) - { - free(idx->jac_row_lengths); - idx->jac_row_lengths = NULL; - } if (idx->parent_w) { free(idx->parent_w); @@ -178,9 +178,10 @@ expr *new_index(expr *child, const int *indices, int n_selected) memcpy(idx->indices, indices, n_selected * sizeof(int)); idx->n_selected = n_selected; - /* mapping arrays allocated lazily in jacobian_init */ - idx->jac_row_starts = NULL; - idx->jac_row_lengths = NULL; + /* detect duplicates for Hessian optimization */ + idx->has_duplicates = check_for_duplicates(indices, n_selected, child->size); + + /* parent_w allocated lazily in wsum_hess_init */ idx->parent_w = NULL; return node; From 2fcc71c7f14dcae4e81523eb8675deca7a5fd735 Mon Sep 17 00:00:00 2001 From: Daniel Date: Tue, 13 Jan 2026 08:54:34 -0800 Subject: [PATCH 6/6] small edits, most substantial thing removed unnecessary loop --- include/affine.h | 2 +- include/subexpr.h | 7 +- python/atoms/index.h | 15 +-- src/affine/index.c | 147 ++++++++++++++---------------- src/other/quad_form.c | 1 - tests/jacobian_tests/test_index.h | 14 ++- 6 files changed, 88 insertions(+), 98 deletions(-) diff --git a/include/affine.h b/include/affine.h index f7943eb..c1dd13b 100644 --- a/include/affine.h +++ b/include/affine.h @@ -18,6 +18,6 @@ expr *new_trace(expr *child); expr *new_constant(int d1, int d2, int n_vars, const double *values); expr *new_variable(int d1, int d2, int var_id, int n_vars); -expr *new_index(expr *child, const int *indices, int n_selected); +expr *new_index(expr *child, const int *indices, int n_idxs); #endif /* AFFINE_H */ diff --git a/include/subexpr.h b/include/subexpr.h index ae5be93..97a31bf 100644 --- a/include/subexpr.h +++ b/include/subexpr.h @@ -109,14 +109,13 @@ typedef struct const_vector_mult_expr double *a; /* length equals node->size */ } const_vector_mult_expr; -/* Index/slicing: y = child[indices] where indices is a list of flattened positions */ +/* Index/slicing: y = child[indices] where indices is a list of flat positions */ typedef struct index_expr { expr base; int *indices; /* Flattened indices to select (owned, copied) */ - int n_selected; /* Number of selected elements */ - bool has_duplicates; /* True if indices contain duplicates (affects Hessian path) */ - double *parent_w; /* Scatter buffer for wsum_hess (size = child->size) */ + int n_idxs; /* Number of selected elements */ + bool has_duplicates; /* True if indices have duplicates (affects Hessian path) */ } index_expr; #endif /* SUBEXPR_H */ diff --git a/python/atoms/index.h b/python/atoms/index.h index a4d1dfd..035ac56 100644 --- a/python/atoms/index.h +++ b/python/atoms/index.h @@ -6,7 +6,8 @@ #include "affine.h" #include "common.h" -/* Index/slicing: y = child[indices] where indices is a list of flattened positions */ +/* Index/slicing: y = child[indices] where indices is a list of flattened positions + */ static PyObject *py_make_index(PyObject *self, PyObject *args) { PyObject *child_capsule; @@ -17,7 +18,7 @@ static PyObject *py_make_index(PyObject *self, PyObject *args) return NULL; } - expr *child = (expr *)PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); + expr *child = (expr *) PyCapsule_GetPointer(child_capsule, EXPR_CAPSULE_NAME); if (!child) { PyErr_SetString(PyExc_ValueError, "invalid child capsule"); @@ -25,18 +26,18 @@ static PyObject *py_make_index(PyObject *self, PyObject *args) } /* Convert indices array to int32 */ - PyArrayObject *indices_array = - (PyArrayObject *)PyArray_FROM_OTF(indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY); + PyArrayObject *indices_array = (PyArrayObject *) PyArray_FROM_OTF( + indices_obj, NPY_INT32, NPY_ARRAY_IN_ARRAY); if (!indices_array) { return NULL; } - int n_selected = (int)PyArray_SIZE(indices_array); - int *indices_data = (int *)PyArray_DATA(indices_array); + int n_idxs = (int) PyArray_SIZE(indices_array); + int *indices_data = (int *) PyArray_DATA(indices_array); - expr *node = new_index(child, indices_data, n_selected); + expr *node = new_index(child, indices_data, n_idxs); Py_DECREF(indices_array); diff --git a/src/affine/index.c b/src/affine/index.c index 18aa120..c407daa 100644 --- a/src/affine/index.c +++ b/src/affine/index.c @@ -1,19 +1,17 @@ -// SPDX-License-Identifier: Apache-2.0 - #include "affine.h" #include "subexpr.h" #include #include -/* Index/slicing: y = child[indices] where indices is a list of flattened positions */ +/* Index/slicing: y = child[indices] where indices is a list of flat positions */ /* Check if indices array contains duplicates using a bitmap. * Returns true if duplicates exist, false otherwise. */ -static bool check_for_duplicates(const int *indices, int n_selected, int max_idx) +static bool check_for_duplicates(const int *indices, int n_idxs, int max_idx) { - bool *seen = (bool *)calloc(max_idx, sizeof(bool)); + bool *seen = (bool *) calloc(max_idx, sizeof(bool)); bool has_dup = false; - for (int i = 0; i < n_selected && !has_dup; i++) + for (int i = 0; i < n_idxs && !has_dup; i++) { if (seen[indices[i]]) { @@ -27,115 +25,109 @@ static bool check_for_duplicates(const int *indices, int n_selected, int max_idx static void forward(expr *node, const double *u) { - expr *child = node->left; - index_expr *idx = (index_expr *)node; + expr *x = node->left; + index_expr *idx = (index_expr *) node; /* child's forward pass */ - child->forward(child, u); + x->forward(x, u); /* gather selected elements */ - for (int i = 0; i < idx->n_selected; i++) + for (int i = 0; i < idx->n_idxs; i++) { - node->value[i] = child->value[idx->indices[i]]; + node->value[i] = x->value[idx->indices[i]]; } } static void jacobian_init(expr *node) { - expr *child = node->left; - index_expr *idx = (index_expr *)node; - - child->jacobian_init(child); - CSR_Matrix *J_child = child->jacobian; - - /* count nnz */ - int nnz = 0; - for (int i = 0; i < idx->n_selected; i++) - { - int row = idx->indices[i]; - nnz += J_child->p[row + 1] - J_child->p[row]; - } + expr *x = node->left; + index_expr *idx = (index_expr *) node; + x->jacobian_init(x); - node->jacobian = new_csr_matrix(idx->n_selected, node->n_vars, nnz); - CSR_Matrix *J = node->jacobian; + CSR_Matrix *Jx = x->jacobian; + CSR_Matrix *J = new_csr_matrix(node->size, node->n_vars, Jx->nnz); - /* fill p and i arrays in one pass */ + /* set sparsity pattern */ J->p[0] = 0; - for (int i = 0; i < idx->n_selected; i++) + for (int i = 0; i < idx->n_idxs; i++) { int row = idx->indices[i]; - int src = J_child->p[row]; - int len = J_child->p[row + 1] - src; - memcpy(J->i + J->p[i], J_child->i + src, len * sizeof(int)); + int len = Jx->p[row + 1] - Jx->p[row]; + memcpy(J->i + J->p[i], Jx->i + Jx->p[row], len * sizeof(int)); J->p[i + 1] = J->p[i] + len; } + + node->jacobian = J; } static void eval_jacobian(expr *node) { - expr *child = node->left; - index_expr *idx = (index_expr *)node; - - child->eval_jacobian(child); + expr *x = node->left; + index_expr *idx = (index_expr *) node; + x->eval_jacobian(x); CSR_Matrix *J = node->jacobian; - CSR_Matrix *J_child = child->jacobian; + CSR_Matrix *Jx = x->jacobian; - for (int i = 0; i < idx->n_selected; i++) + for (int i = 0; i < idx->n_idxs; i++) { int row = idx->indices[i]; int len = J->p[i + 1] - J->p[i]; - memcpy(J->x + J->p[i], J_child->x + J_child->p[row], len * sizeof(double)); + memcpy(J->x + J->p[i], Jx->x + Jx->p[row], len * sizeof(double)); } } static void wsum_hess_init(expr *node) { - expr *child = node->left; - index_expr *idx = (index_expr *)node; + expr *x = node->left; /* initialize child's wsum_hess */ - child->wsum_hess_init(child); - - /* allocate scatter buffer (zeroed) */ - idx->parent_w = (double *)calloc(child->size, sizeof(double)); - - /* wsum_hess inherits from child (affine has no local Hessian) */ - /* We need to allocate our own to avoid aliasing issues */ - CSR_Matrix *child_hess = child->wsum_hess; - node->wsum_hess = new_csr_matrix(child_hess->m, child_hess->n, child_hess->nnz); - memcpy(node->wsum_hess->p, child_hess->p, (child_hess->m + 1) * sizeof(int)); - memcpy(node->wsum_hess->i, child_hess->i, child_hess->nnz * sizeof(int)); + x->wsum_hess_init(x); + + /* for setting weight vector to evaluate hessian of child */ + node->dwork = (double *) calloc(x->size, sizeof(double)); + + /* in the implementation of eval_wsum_hess we evaluate the + child's hessian with a weight vector that has w[i] = 0 + if i is not included in idx->indices. This can lead to + many numerical zeros in child->wsum_hess that are actually + structural zeros, but we do not try to exploit that sparsity + right now. */ + CSR_Matrix *H_child = x->wsum_hess; + node->wsum_hess = new_csr_matrix(H_child->m, H_child->n, H_child->nnz); + memcpy(node->wsum_hess->p, H_child->p, (H_child->m + 1) * sizeof(int)); + memcpy(node->wsum_hess->i, H_child->i, H_child->nnz * sizeof(int)); } static void eval_wsum_hess(expr *node, const double *w) { expr *child = node->left; - index_expr *idx = (index_expr *)node; + index_expr *idx = (index_expr *) node; if (idx->has_duplicates) { - /* slow path: must zero and accumulate for repeated indices */ - memset(idx->parent_w, 0, child->size * sizeof(double)); - for (int i = 0; i < idx->n_selected; i++) + /* zero and accumulate for repeated indices */ + memset(node->dwork, 0, child->size * sizeof(double)); + for (int i = 0; i < idx->n_idxs; i++) { - idx->parent_w[idx->indices[i]] += w[i]; + node->dwork[idx->indices[i]] += w[i]; } } else { - /* fast path: direct write (no memset needed, no accumulation) */ - for (int i = 0; i < idx->n_selected; i++) + /* direct write (no memset needed, no accumulation) */ + for (int i = 0; i < idx->n_idxs; i++) { - idx->parent_w[idx->indices[i]] = w[i]; + node->dwork[idx->indices[i]] = w[i]; } } /* delegate to child */ - child->eval_wsum_hess(child, idx->parent_w); + child->eval_wsum_hess(child, node->dwork); /* copy values from child */ - memcpy(node->wsum_hess->x, child->wsum_hess->x, child->wsum_hess->nnz * sizeof(double)); + memcpy(node->wsum_hess->x, child->wsum_hess->x, + child->wsum_hess->nnz * sizeof(double)); } static bool is_affine(const expr *node) @@ -145,44 +137,37 @@ static bool is_affine(const expr *node) static void free_type_data(expr *node) { - index_expr *idx = (index_expr *)node; + index_expr *idx = (index_expr *) node; if (idx->indices) { free(idx->indices); idx->indices = NULL; } - if (idx->parent_w) - { - free(idx->parent_w); - idx->parent_w = NULL; - } } -expr *new_index(expr *child, const int *indices, int n_selected) +expr *new_index(expr *child, const int *indices, int n_idxs) { /* allocate type-specific struct */ - index_expr *idx = (index_expr *)calloc(1, sizeof(index_expr)); + index_expr *idx = (index_expr *) calloc(1, sizeof(index_expr)); expr *node = &idx->base; - /* output shape is (n_selected, 1) - flattened */ - init_expr(node, n_selected, 1, child->n_vars, forward, jacobian_init, - eval_jacobian, is_affine, free_type_data); + /* output shape is (n_idxs, 1) - flattened */ + init_expr(node, n_idxs, 1, child->n_vars, forward, jacobian_init, eval_jacobian, + is_affine, free_type_data); - node->wsum_hess_init = wsum_hess_init; - node->eval_wsum_hess = eval_wsum_hess; node->left = child; expr_retain(child); + node->wsum_hess_init = wsum_hess_init; + node->eval_wsum_hess = eval_wsum_hess; + /* copy indices */ - idx->indices = (int *)malloc(n_selected * sizeof(int)); - memcpy(idx->indices, indices, n_selected * sizeof(int)); - idx->n_selected = n_selected; + idx->indices = (int *) malloc(n_idxs * sizeof(int)); + memcpy(idx->indices, indices, n_idxs * sizeof(int)); + idx->n_idxs = n_idxs; /* detect duplicates for Hessian optimization */ - idx->has_duplicates = check_for_duplicates(indices, n_selected, child->size); - - /* parent_w allocated lazily in wsum_hess_init */ - idx->parent_w = NULL; + idx->has_duplicates = check_for_duplicates(indices, n_idxs, child->size); return node; } diff --git a/src/other/quad_form.c b/src/other/quad_form.c index 04badb8..1fb68a3 100644 --- a/src/other/quad_form.c +++ b/src/other/quad_form.c @@ -27,7 +27,6 @@ static void forward(expr *node, const double *u) static void jacobian_init(expr *node) { - CSR_Matrix *Q = ((quad_form_expr *) node)->Q; assert(node->left->var_id != NOT_A_VARIABLE); assert(node->left->d2 == 1); expr *x = node->left; diff --git a/tests/jacobian_tests/test_index.h b/tests/jacobian_tests/test_index.h index 5bf73bb..7f3c905 100644 --- a/tests/jacobian_tests/test_index.h +++ b/tests/jacobian_tests/test_index.h @@ -83,8 +83,10 @@ const char *test_index_jacobian_of_log(void) double expected_x[2] = {1.0, 0.25}; int expected_i[2] = {0, 2}; - mu_assert("index of log jac vals", cmp_double_array(idx->jacobian->x, expected_x, 2)); - mu_assert("index of log jac cols", cmp_int_array(idx->jacobian->i, expected_i, 2)); + mu_assert("index of log jac vals", + cmp_double_array(idx->jacobian->x, expected_x, 2)); + mu_assert("index of log jac cols", + cmp_int_array(idx->jacobian->i, expected_i, 2)); free_expr(idx); return 0; @@ -106,8 +108,12 @@ const char *test_index_jacobian_repeated(void) int expected_p[3] = {0, 1, 2}; int expected_i[2] = {0, 0}; /* Both reference col 0 */ - mu_assert("index repeated jac vals", cmp_double_array(idx->jacobian->x, expected_x, 2)); - mu_assert("index repeated jac i", cmp_int_array(idx->jacobian->i, expected_i, 2)); + mu_assert("index repeated jac vals", + cmp_double_array(idx->jacobian->x, expected_x, 2)); + mu_assert("index repeated row ptr", + cmp_int_array(idx->jacobian->p, expected_p, 4)); + mu_assert("index repeated jac i", + cmp_int_array(idx->jacobian->i, expected_i, 2)); free_expr(idx); return 0;